从检查点恢复,精度下降一个周期

数据挖掘 火炬
2022-03-08 10:23:41

我在尝试使用 PyTorch 微调模型时遇到了一个奇怪的问题。我改编了一个类似于imagenet.pyPyTorch 存储库中的训练脚本的脚本。每次我停止训练并尝试从检查点恢复时,我都会看到准确性急剧下降。在一个“保存周期”(迷你时代?)之后,准确性似乎恢复了,有时甚至做得更好。

由于我正在运行的数据集很大,我已经更改了日志并在更短的周期内保存,因此我的训练循环与原始 imagenet.py 脚本有点不同。这可能是这个错误的原因,但我无法弄清楚这可能是什么。

恢复训练时准确率下降

    import os 
    import shutil 
    import time

    import torch from tensorboard_logger 
    import log_value

    def train(train_dataset, train_loader, model, criterion, optimizer, val_loader, best_prec1, best_train_prec1, samples, checkpoint_directory, args, scheduler):
        """Train for one epoch on the training set"""
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        # switch to train mode
        model.train()

        end = time.time()
        for i, (input, target) in enumerate(train_loader):
            if scheduler is not None:
                scheduler.batch_step()
            target = target.cuda(async=True)
            input = input.cuda()
            input_var = torch.autograd.Variable(input)
            target_var = torch.autograd.Variable(target)

            # compute output
            output = model(input_var)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target, topk=(1,))[0]
            losses.update(loss.data[0], input.size(0))
            top1.update(prec1[0], input.size(0))
            samples += input.size(0)

            # compute gradient and do SGD step
            loss.backward()
            if i % args.accum == 0:
                optimizer.step()
                optimizer.zero_grad()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Epoch: {0:.4f}\t'
                      'Step: {1}/{2}\t'
                      'Samples: [{samples}]\t'
                      'LR: {lr}\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Samples/s {samples_per_sec:.0f}\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    samples / len(train_dataset), i, len(train_loader), samples=samples, batch_time=batch_time,
                    samples_per_sec=input.size(0)/batch_time.avg,
                    lr=get_learning_rate(optimizer)[0],# *iter_accum ???
                loss=losses, top1=top1))

            if i % args.save_steps_freq == 0:
                if i>0:
                    # evaluate on validation set
                    prec1 = validate(val_loader, model, criterion, samples, args)

                    # remember best prec@1 and save checkpoint
                    print('Checkpoint')
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    is_best_train = top1.avg > best_train_prec1
                    best_train_prec1 = max(top1.avg, best_train_prec1)
                    save_checkpoint({
                        'samples': samples,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'best_train_prec1': best_train_prec1,
                        'train_prec1': top1.avg,
                    }, is_best, is_best_train,
                        directory=checkpoint_directory
                    )

                    # log to TensorBoard
                    log_value('train_loss', losses.avg, samples)
                    log_value('train_acc', top1.avg, samples)
                    log_value('learning_rate', get_learning_rate(optimizer)[0], samples)
                    log_value('batch_size', input.size(0), samples)
                    log_value('effective_batch_size', input.size(0)*args.accum, samples)
                    log_value('accum', args.accum, samples)

                batch_time.reset()
                losses.reset()
                top1.reset()
        return best_prec1, best_train_prec1, samples


    def validate(val_loader, model, criterion, samples, args):
        """Perform validation on the validation set"""
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()

        # switch to evaluate mode
        model.eval()

        end = time.time()
        for i, (input, target) in enumerate(val_loader):
            # print("input={}", input.size())
            target = target.cuda(async=True)
            input = input.cuda()
            input_var = torch.autograd.Variable(input, volatile=True)
            target_var = torch.autograd.Variable(target, volatile=True)

            # compute output
            output = model(input_var)
            # print("validate vars input={} target={} output={}".format(input_var.size(), target_var.size(), output.size()))
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1 = accuracy(output.data, target, topk=(1,))[0]
            losses.update(loss.data[0], input.size(0))
            top1.update(prec1[0], input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    i, len(val_loader), batch_time=batch_time, loss=losses,
                    top1=top1))

        print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))

        # log to TensorBoard
        log_value('val_loss', losses.avg, samples)
        log_value('val_acc', top1.avg, samples)

        return top1.avg


    def get_learning_rate(optimizer):
        if optimizer is None:
            return [0.0]
        lr=[]
        for param_group in optimizer.param_groups:
            lr +=[ param_group['lr'] ]
        return lr


    def save_checkpoint(state, is_best, is_best_train, directory, filename='checkpoint.pth.tar'):
        """Saves checkpoint to disk"""
        if not os.path.exists(directory):
            os.makedirs(directory)
        filename = directory + filename
        torch.save(state, filename)
        if is_best:
            shutil.copyfile(filename, directory + 'model_best.pth.tar')
        if is_best_train:
            shutil.copyfile(filename, directory + 'model_best_train.pth.tar')


    class AverageMeter(object):
        """Computes and stores the average and current value"""

        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count


    def adjust_learning_rate(optimizer, epoch, lr):
        """Sets the learning rate to the initial LR decayed by 10 after 150 and 225 epochs"""
        lr = lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
        # log to TensorBoard
        log_value('learning_rate', lr, epoch)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


    def accuracy(output, target, topk=(1,)):
        """Computes the precision@k for the specified values of k"""
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
1个回答

我的直觉是,这是由于失去了优化器状态。您只保留模型权重,优化器必须重新学习动量。