我应该如何在 Pytorch 中保存模型权重?

数据挖掘 张量流 火炬
2022-03-01 12:35:03
### Create / load model

# Faster - RCNN Model - pretrained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# replace the pre-trained head with a new one
model.roi_heads.box_predictor =  FastRCNNPredictor(in_features, num_classes)

训练模型不使用model.fit()函数,它使用循环。

    # let's train it for 10 epochs
    num_epochs = 10

    for epoch in range(num_epochs):
        # train for one epoch, printing every 10 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        evaluate(model, data_loader_test, device=device)

如何保存模型?

1个回答

PyTorch 有一个state_dict存储模型(在本例中为神经网络)在任何时间点的状态。保存它需要将这些状态转储到一个文件中,这很容易完成:

torch.save(model.state_dict(), PATH)

重新加载模型时,请记住首先使用其默认权重创建模型类并从文件中加载状态字典。

这是保存/加载 pytorch 模块的链接(https://pytorch.org/tutorials/beginner/saving_loading_models.html)。