```python
import torch
from torch import nn
import d2l.torch

# Per-sample cross-entropy, as expected by d2l.torch.train_batch_ch13,
# which sums the per-sample losses before backpropagating.
loss = nn.CrossEntropyLoss(reduction='none')

def train(net, train_iter, valid_iter, num_epochs, lr, weight_decay,
          lr_period, lr_decay, devices):
    """Train with SGD plus cosine annealing and keep the best checkpoint.

    lr_period and lr_decay are kept for signature compatibility with the
    StepLR variant; they are unused by the cosine schedule below.
    """
    optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                            weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optim, T_max=num_epochs, eta_min=1e-6)
    timer, num_batches = d2l.torch.Timer(), len(train_iter)
    legend = ['train loss', 'train acc']
    if valid_iter is not None:
        legend.append('valid acc')
    animator = d2l.torch.Animator(xlabel='epoch', xlim=[1, num_epochs],
                                  legend=legend)
    net = nn.DataParallel(module=net, device_ids=devices).to(devices[0])
    timer_epoch = d2l.torch.Timer()
    best_valid_acc = 0.0
    for epoch in range(num_epochs):
        timer_epoch.start()
        # Accumulates: sum of losses, sum of correct predictions, #examples.
        accumulator = d2l.torch.Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            timer.start()
            batch_loss, batch_acc = d2l.torch.train_batch_ch13(
                net, X, y, loss, optim, devices)
            accumulator.add(batch_loss, batch_acc, y.shape[0])
            timer.stop()
            # Refresh the curves about five times per epoch and on the
            # final batch; max(1, ...) guards against tiny loaders.
            if i % max(1, num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (accumulator[0] / accumulator[2],
                              accumulator[1] / accumulator[2], None))
        timer_epoch.stop()
        net.eval()
        measures = (f'train loss {accumulator[0] / accumulator[2]:.3f}, '
                    f'train acc {accumulator[1] / accumulator[2]:.3f}')
        if valid_iter is not None:
            valid_acc = d2l.torch.evaluate_accuracy_gpu(net, valid_iter,
                                                        devices[0])
            animator.add(epoch + 1, (None, None, valid_acc))
            measures += f', valid acc {valid_acc:.3f}'
            # Checkpoint whenever validation accuracy improves; note the
            # state dict carries a 'module.' prefix from nn.DataParallel.
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                torch.save(net.state_dict(), 'best_model.pth')
                print(f'Saved the best model at epoch {epoch + 1}, '
                      f'valid acc {valid_acc:.3f}')
        lr_scheduler.step()
    print(measures + f'\n{num_epochs * accumulator[2] / timer.sum():.1f} '
          f'examples/sec, {timer_epoch.avg():.1f} sec/epoch '
          f'on {str(devices[0])}')
```
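For reference, a minimal sketch of how this function might be exercised. Everything below is illustrative rather than from the original text: the synthetic tensors stand in for real CIFAR-10 loaders, the ResNet-18 backbone and the hyperparameter values are placeholders, and at least one GPU is assumed since `nn.DataParallel` targets CUDA devices.

```python
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import d2l.torch

# Synthetic stand-in data so the sketch runs end to end; swap in the real
# CIFAR-10 DataLoaders in practice.
X = torch.randn(256, 3, 32, 32)
y = torch.randint(0, 10, (256,))
train_iter = DataLoader(TensorDataset(X[:192], y[:192]),
                        batch_size=32, shuffle=True)
valid_iter = DataLoader(TensorDataset(X[192:], y[192:]), batch_size=32)

devices = d2l.torch.try_all_gpus()  # assumes at least one CUDA device
net = torchvision.models.resnet18(num_classes=10)  # placeholder backbone
train(net, train_iter, valid_iter, num_epochs=5, lr=0.1,
      weight_decay=5e-4, lr_period=4, lr_decay=0.9, devices=devices)

# Reload the checkpoint written by train(); wrap in DataParallel first so
# the 'module.'-prefixed keys in the saved state dict line up.
best = nn.DataParallel(torchvision.models.resnet18(num_classes=10),
                       device_ids=devices).to(devices[0])
best.load_state_dict(torch.load('best_model.pth'))
```

Note that because the model is saved while wrapped in `nn.DataParallel`, loading it into an unwrapped network would require stripping the `module.` prefix from the state dict keys.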