Keras回调示例,用于在每个时期后保存模型?

数据挖掘 Python 喀拉斯
2021-10-06 22:38:47

有人可以发布一个简单的 Keras 示例,使用回调在每个时期后保存模型吗?我可以找到保存权重的示例,但我希望能够在每个训练时期后保存一个功能齐全的模型。

3个回答

在 Keras回调'ModelCheckpoint' 中将'save_weights_only' 设置为 False将保存完整的模型;从上面的链接中获取的这个示例将在每个时期保存一个完整的模型,无论性能如何:

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

在这里可以找到更多示例,包括仅保存改进的模型和加载保存的模型。

确保在文件路径中包含 epoch 变量。否则,您保存的模型将在每个 epoch 后被替换。

filepath = "saved-model-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=False, mode='max')

有关更多示例,请查看此处

我写了自己的ModelCheckpoint类,因为我必须调用一个特殊的save_pretrained方法:

class ModelCheckpoint(Callback):

    def __init__(self, freq, directory):
        super().__init__()
        self.freq = freq
        self.directory = directory

    def on_epoch_begin(self, epoch, logs=None):
        if self.freq > 0 and epoch % self.freq == 0:
            self.model.save_pretrained(Path(directory, str(epoch)))

    def on_train_end(self, logs=None):
        self.model.save_pretrained(directory)

它总是在每个freq时期和训练结束时保存模型。