# Base class: owns the network, its device placement and checkpoint I/O.
class BaseModel:
    """Wrap a network together with its run arguments and checkpoint path.

    Args:
        args: run configuration; ``args.device`` is used for placement.
        net: the ``torch.nn.Module`` to manage; moved to ``args.device``.
        weights: file path used by :meth:`load` and :meth:`save`.
    """

    def __init__(self, args, net, weights):
        self.args = args
        self.weights = weights
        self.net = net.to(args.device)

    def load(self):
        """Load the state dict from ``self.weights`` if that file exists.

        Loading is best-effort: a missing file is silently skipped, and a
        corrupt or incompatible checkpoint is reported but does not raise.
        """
        if not os.path.exists(self.weights):
            return
        print('Loading classifier...')
        try:
            # map_location lets a checkpoint saved on GPU be restored on CPU
            # (and vice versa).
            state = torch.load(self.weights, map_location=self.args.device)
            self.net.load_state_dict(state)
        except Exception as err:  # deliberate best-effort; surface the cause
            print('Failed to load pre-trained network: {}'.format(err))

    def save(self):
        """Write the network's state dict to ``self.weights``."""
        print('Saving model to "{}"'.format(self.weights))
        torch.save(self.net.state_dict(), self.weights)


# Concrete model built on BaseModel; checkpoints go to ``args.net_name``.
class Model(BaseModel):
    """Trainable/evaluable classifier whose checkpoint path is ``args.net_name``."""

    def __init__(self, args, net):
        super(Model, self).__init__(args, net, args.net_name)

    def train(self):
        """Train for ``args.epochs`` epochs, checkpointing on best validation accuracy.

        At the end, the per-epoch train/validation loss curves are plotted and
        saved to ``args.net_name + '.jpg'``.

        NOTE(review): ``train_loader``, ``optimizer``, ``criterion`` and
        ``scheduler`` are not created in this method; they are read as free
        (presumably module-level) names -- confirm where they are defined.
        """
        # Loss curves, one slot per epoch for both training and validation.
        train_loss = np.zeros(self.args.epochs)
        validate_loss = np.zeros(self.args.epochs)
        # Validation accuracy lies in [0, 1], so the first epoch always checkpoints.
        best_acc = -1.0

        for epoch in range(self.args.epochs):
            self.net.train()
            # Running training statistics are reset every epoch.
            correct = 0
            total = 0
            for index, (images, labels) in enumerate(train_loader):
                images = images.to(self.args.device)
                labels = labels.to(self.args.device)
                optimizer.zero_grad()
                output = self.net(images)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()

                _, predicted = torch.max(output.detach(), 1)
                correct += (predicted == labels).sum().item()
                # Weight by the actual batch length: the last batch may be
                # smaller than args.batch_size (assumes criterion averages
                # over the batch -- confirm).
                train_loss[epoch] += loss.item() * len(images)
                total += len(images)

                if index > 0 and index % self.args.log_interval == 0:
                    print('Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {:.2f}%'.format(
                        epoch + 1,
                        self.args.epochs,
                        index * len(images),
                        len(train_loader.dataset),
                        100. * index / len(train_loader),
                        loss.item(),
                        100 * correct / total
                    ))

            # Learning-rate schedule step: once per epoch, after the optimizer
            # updates (required ordering since PyTorch 1.1).
            scheduler.step()

            train_loss[epoch] = train_loss[epoch] / max(total, 1)
            validate_loss[epoch], validate_acc = self.validate()
            if validate_acc > best_acc:
                best_acc = validate_acc
                self.save()

        print('\nEnd of training, best class accuracy: {}'.format(best_acc))
        # Draw on a fresh figure and close it so repeated calls don't overlay curves.
        plt.figure()
        plt.plot(train_loss, 'b-', label='train_loss')
        plt.plot(validate_loss, 'r--', label='validate_loss')
        plt.xlabel('epoches')
        plt.title("Loss curve")
        plt.legend()
        plt.savefig(self.args.net_name + '.jpg')
        plt.close()

    def validate(self, verbose=False):
        """Evaluate the network on the validation split.

        Args:
            verbose: accepted for compatibility; currently unused.

        Returns:
            ``(loss, acc)``: summed cross-entropy divided by the dataset size,
            and the fraction of correctly classified samples.
        """
        self.net.eval()
        loss = 0
        correct = 0
        dataset = TrainDataset(path=self.args.dataset_path + VALIDATION_PATH,
                               stride=self.args.validate_stride)
        validate_loader = DataLoader(
            dataset=dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=0
        )
        # reduction='sum' is the non-deprecated spelling of size_average=False:
        # batch sums are accumulated, then divided by the dataset size below.
        criterion = nn.CrossEntropyLoss(reduction='sum')
        print('\nEvaluating classifier')

        with torch.no_grad():
            for images, labels in validate_loader:
                images = images.to(self.args.device)
                # Labels stay on CPU, so bring the logits back once per batch.
                output = self.net(images).cpu()
                loss += criterion(output, labels).item()
                _, predicted = torch.max(output, 1)
                correct += (predicted == labels).sum().item()

        denom = max(len(dataset), 1)  # avoid ZeroDivisionError on an empty split
        loss /= denom
        acc = correct / denom

        print('Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            loss, correct, len(dataset), 100. * correct / denom
        ))
        return loss, acc

    def ctest(self, path, augment=False):
        """Classify each test image from its patches and report majority-vote accuracy.

        For every image, three image-level predictions are printed: argmax of the
        summed patch probabilities, argmax of the per-class maximum, and the
        majority vote over patch predictions. Ties resolve to the highest class
        index.

        Args:
            path: dataset location passed to ``cTestDataset``.
            augment: forwarded to ``cTestDataset``.

        Returns:
            ``(prop_pred, labels)``: per-image fraction of patches voting for each
            class, shape ``(len(dataset), len(LABELS))``, and the true labels.

        NOTE(review): assumes ``LABELS`` lists class names in the network's
        output order, and that the network emits log-probabilities (``np.exp``
        is applied to its output) -- confirm.
        """
        self.load()
        self.net.eval()
        dataset = cTestDataset(path=path, stride=self.args.test_stride, augment=augment)
        data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
        stime = datetime.datetime.now()
        print('\t sum\t\t max\t\t maj')

        n_classes = len(LABELS)
        last = n_classes - 1
        one_hot = np.eye(n_classes)
        labels = []
        prop_pred = np.zeros((len(dataset), n_classes), dtype=float)
        maj_correct = 0
        with torch.no_grad():
            for index, (image, label) in enumerate(data_loader):
                # Drop only the DataLoader batch axis (batch_size=1); a bare
                # squeeze() would also collapse a single-patch axis.
                image = image.squeeze(0).to(self.args.device)
                output = self.net(image)
                _, predicted = torch.max(output, 1)
                # Move to host before NumPy indexing (np.array fails on CUDA tensors).
                predicted = predicted.cpu().numpy().reshape(-1)

                probs = np.exp(output.cpu().numpy())
                votes = one_hot[predicted].sum(axis=0)
                # Reversing before argmax and mapping back picks the highest
                # class index among ties.
                sum_prob = last - np.argmax(probs.sum(axis=0)[::-1])
                max_prob = last - np.argmax(probs.max(axis=0)[::-1])
                maj_prob = last - np.argmax(votes[::-1])

                labels.append(label.item())
                prop_pred[index, :] = votes / len(image)

                print('{}) \t {} \t {} \t {} \t {}'.format(
                    str(index + 1).rjust(2, '0'),
                    LABELS[sum_prob].ljust(8),
                    LABELS[max_prob].ljust(8),
                    LABELS[maj_prob].ljust(8),
                    str(label.data)))
                maj_correct += int(maj_prob == label.item())

        maj_acc = maj_correct / max(len(dataset), 1)
        print('Maj_image accuracy: {}/{} ({:.2f}%)'.format(
            maj_correct, len(data_loader.dataset), 100. * maj_acc
        ))
        print('\nInference time: {}\n'.format(datetime.datetime.now() - stime))
        return prop_pred, labels