How does ModelCheckpoint find the best model during training?

data-mining python deep-learning keras tensorflow jupyter
2022-03-08 01:36:33
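  1. With filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5" I get KeyError: 'val_acc' (I am using TensorFlow 1.14.0).
  2. With filepath="weights.best.hdf5" and save_best_only=True, the callback is supposed to save the best model observed during training, but it never saved a model, apparently because the accuracy did not increase. Does this mean I need to increase the number of epochs? Also, why doesn't it look at the accuracy scores that are available and save the model with the highest score as the best one?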

Example code

# imports (assumes standalone Keras 2.x; with tf.keras, import from tensorflow.keras instead)
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import ModelCheckpoint
# load pima indians dataset
dataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# create model (kernel_initializer replaces the Keras 1 argument init)
model = Sequential()
model.add(Dense(12, input_dim=8, kernel_initializer="uniform", activation="relu"))
model.add(Dense(8, kernel_initializer="uniform", activation="relu"))
model.add(Dense(1, kernel_initializer="uniform", activation="sigmoid"))
# Compile model
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
# checkpoint ("val_acc" is the name that triggers the KeyError in question 1)
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor="val_acc", verbose=1, save_best_only=True,
                             mode="max")
callbacks_list = [checkpoint]
# Fit the model (epochs replaces the Keras 1 argument nb_epoch)
model.fit(X, Y, validation_split=0.33, epochs=50, batch_size=10,
          callbacks=callbacks_list, verbose=2)

Train on 514 samples, validate on 254 samples
Epoch 1/50
 - 0s - loss: 0.6846 - accuracy: 0.6401 - val_loss: 0.6691 - val_accuracy: 0.6732
Epoch 2/50
 - 0s - loss: 0.6672 - accuracy: 0.6401 - val_loss: 0.6517 - val_accuracy: 0.6732
Epoch 3/50
 - 0s - loss: 0.6600 - accuracy: 0.6498 - val_loss: 0.6503 - val_accuracy: 0.6772
Epoch 4/50
 - 0s - loss: 0.6529 - accuracy: 0.6440 - val_loss: 0.6401 - val_accuracy: 0.6811
Epoch 5/50
 - 0s - loss: 0.6476 - accuracy: 0.6654 - val_loss: 0.6346 - val_accuracy: 0.6772
Epoch 6/50
 - 0s - loss: 0.6385 - accuracy: 0.6459 - val_loss: 0.6448 - val_accuracy: 0.6299
Epoch 7/50
 - 0s - loss: 0.6334 - accuracy: 0.6615 - val_loss: 0.6242 - val_accuracy: 0.6772
Epoch 8/50
 - 0s - loss: 0.6261 - accuracy: 0.6498 - val_loss: 0.6166 - val_accuracy: 0.6693
Epoch 9/50
 - 0s - loss: 0.6216 - accuracy: 0.6673 - val_loss: 0.6057 - val_accuracy: 0.6969
Epoch 10/50
 - 0s - loss: 0.6192 - accuracy: 0.6673 - val_loss: 0.6059 - val_accuracy: 0.6654
Epoch 11/50
 - 0s - loss: 0.6247 - accuracy: 0.6595 - val_loss: 0.5972 - val_accuracy: 0.6811
Epoch 12/50
 - 0s - loss: 0.6139 - accuracy: 0.6518 - val_loss: 0.5936 - val_accuracy: 0.6811
Epoch 13/50
 - 0s - loss: 0.6107 - accuracy: 0.6732 - val_loss: 0.5908 - val_accuracy: 0.6772
Epoch 14/50
 - 0s - loss: 0.6093 - accuracy: 0.6770 - val_loss: 0.5848 - val_accuracy: 0.6929
Epoch 15/50
 - 0s - loss: 0.6001 - accuracy: 0.6829 - val_loss: 0.5866 - val_accuracy: 0.6772
Epoch 16/50
 - 0s - loss: 0.6022 - accuracy: 0.6829 - val_loss: 0.5804 - val_accuracy: 0.7008
Epoch 17/50
 - 0s - loss: 0.5957 - accuracy: 0.6790 - val_loss: 0.5990 - val_accuracy: 0.6811
Epoch 18/50
 - 0s - loss: 0.5911 - accuracy: 0.6887 - val_loss: 0.6046 - val_accuracy: 0.6575
Epoch 19/50
 - 0s - loss: 0.6028 - accuracy: 0.6770 - val_loss: 0.5706 - val_accuracy: 0.7008
Epoch 20/50
 - 0s - loss: 0.6086 - accuracy: 0.6673 - val_loss: 0.5790 - val_accuracy: 0.7205
Epoch 21/50
 - 0s - loss: 0.5904 - accuracy: 0.6965 - val_loss: 0.5636 - val_accuracy: 0.7008
Epoch 22/50
 - 0s - loss: 0.5931 - accuracy: 0.6965 - val_loss: 0.6001 - val_accuracy: 0.6654
Epoch 23/50
 - 0s - loss: 0.5895 - accuracy: 0.7023 - val_loss: 0.5647 - val_accuracy: 0.7087
Epoch 24/50
 - 0s - loss: 0.5837 - accuracy: 0.7101 - val_loss: 0.5628 - val_accuracy: 0.7283
Epoch 25/50
 - 0s - loss: 0.5837 - accuracy: 0.6965 - val_loss: 0.5584 - val_accuracy: 0.7047
Epoch 26/50
 - 0s - loss: 0.5828 - accuracy: 0.6887 - val_loss: 0.5593 - val_accuracy: 0.7362
Epoch 27/50
 - 0s - loss: 0.5913 - accuracy: 0.6965 - val_loss: 0.5580 - val_accuracy: 0.6890
Epoch 28/50
 - 0s - loss: 0.5861 - accuracy: 0.6965 - val_loss: 0.5597 - val_accuracy: 0.7441
Epoch 29/50
 - 0s - loss: 0.5846 - accuracy: 0.6887 - val_loss: 0.5584 - val_accuracy: 0.6850
Epoch 30/50
 - 0s - loss: 0.5780 - accuracy: 0.6946 - val_loss: 0.5553 - val_accuracy: 0.7165
Epoch 31/50
 - 0s - loss: 0.5802 - accuracy: 0.6887 - val_loss: 0.5619 - val_accuracy: 0.7323
Epoch 32/50
 - 0s - loss: 0.5816 - accuracy: 0.6965 - val_loss: 0.5574 - val_accuracy: 0.6929
Epoch 33/50
 - 0s - loss: 0.5740 - accuracy: 0.7140 - val_loss: 0.5540 - val_accuracy: 0.6850
Epoch 34/50
 - 0s - loss: 0.5723 - accuracy: 0.7004 - val_loss: 0.5523 - val_accuracy: 0.7323
Epoch 35/50
 - 0s - loss: 0.5746 - accuracy: 0.6965 - val_loss: 0.5645 - val_accuracy: 0.7323
Epoch 36/50
 - 0s - loss: 0.5684 - accuracy: 0.7004 - val_loss: 0.5664 - val_accuracy: 0.7323
Epoch 37/50
 - 0s - loss: 0.5726 - accuracy: 0.7140 - val_loss: 0.5492 - val_accuracy: 0.7087
Epoch 38/50
 - 0s - loss: 0.5628 - accuracy: 0.7218 - val_loss: 0.5933 - val_accuracy: 0.6850
Epoch 39/50
 - 0s - loss: 0.5832 - accuracy: 0.6965 - val_loss: 0.5487 - val_accuracy: 0.7087
Epoch 40/50
 - 0s - loss: 0.5604 - accuracy: 0.7140 - val_loss: 0.5780 - val_accuracy: 0.7087
Epoch 41/50
 - 0s - loss: 0.5714 - accuracy: 0.7082 - val_loss: 0.5590 - val_accuracy: 0.6929
Epoch 42/50
 - 0s - loss: 0.5745 - accuracy: 0.7004 - val_loss: 0.5550 - val_accuracy: 0.6929
Epoch 43/50
 - 0s - loss: 0.5623 - accuracy: 0.7140 - val_loss: 0.5497 - val_accuracy: 0.7047
Epoch 44/50
 - 0s - loss: 0.5697 - accuracy: 0.7062 - val_loss: 0.5497 - val_accuracy: 0.7205
Epoch 45/50
 - 0s - loss: 0.5618 - accuracy: 0.7140 - val_loss: 0.5485 - val_accuracy: 0.7205
Epoch 46/50
 - 0s - loss: 0.5614 - accuracy: 0.7121 - val_loss: 0.5457 - val_accuracy: 0.7126
Epoch 47/50
 - 0s - loss: 0.5587 - accuracy: 0.7140 - val_loss: 0.5548 - val_accuracy: 0.7205
Epoch 48/50
 - 0s - loss: 0.5584 - accuracy: 0.6984 - val_loss: 0.5489 - val_accuracy: 0.7323
Epoch 49/50
 - 0s - loss: 0.5619 - accuracy: 0.6984 - val_loss: 0.5561 - val_accuracy: 0.6969
Epoch 50/50
 - 0s - loss: 0.5813 - accuracy: 0.7043 - val_loss: 0.5551 - val_accuracy: 0.7165
2 Answers

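For question 1, replace val_acc with val_accuracy: because the metric was passed as "accuracy", it is logged as accuracy / val_accuracy, as your training output shows. This may also solve your second problem.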

...

filepath="weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor="val_accuracy", verbose=1, save_best_only=True,
                             mode="max")

...
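To see why the key name matters, here is a minimal pure-Python sketch of how the checkpoint filename is produced. It assumes that ModelCheckpoint fills the template with filepath.format(epoch=epoch + 1, **logs), where logs holds the metrics printed for that epoch; checkpoint_filename is a hypothetical helper, and the logs dict is copied from epoch 1 of the training output in the question.

```python
# Hypothetical helper: assumes Keras builds the checkpoint path with
# filepath.format(epoch=epoch + 1, **logs), where epoch is 0-based.
def checkpoint_filename(template, epoch, logs):
    return template.format(epoch=epoch + 1, **logs)

# Metrics logged for epoch 1 in the training output of the question.
logs = {"loss": 0.6846, "accuracy": 0.6401,
        "val_loss": 0.6691, "val_accuracy": 0.6732}

old = "weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
new = "weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"

try:
    checkpoint_filename(old, 0, logs)
except KeyError as err:
    # There is no "val_acc" key in logs, so str.format raises KeyError.
    print("KeyError:", err)  # KeyError: 'val_acc'

print(checkpoint_filename(new, 0, logs))  # weights-improvement-01-0.67.hdf5
```

In other words, every placeholder in the template has to match a key printed in the training log exactly.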

For question 1, you might be using an f-string without putting the f before the quotes:

filepath=f"weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
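One point worth checking before adding the f: an f-string is evaluated on the line where it is written, so epoch and val_acc would have to exist as Python variables at that moment, and the result is one fixed filename. ModelCheckpoint, assuming it fills the placeholders with str.format at the end of every epoch, needs them to still be present in a plain string. The sketch below contrasts the two; the variable values are illustrative only.

```python
# Plain string: the placeholders survive, so they can be filled later, per epoch
# (assumption: this is what ModelCheckpoint does with str.format).
template = "weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
print(template.format(epoch=7, val_accuracy=0.6772))  # weights-improvement-07-0.68.hdf5

# f-string: evaluated immediately, using whatever variables exist right now.
# Without these two (illustrative) assignments, the next line raises NameError.
epoch, val_accuracy = 1, 0.6732
eager = f"weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"
print(eager)  # weights-improvement-01-0.67.hdf5, the same name for every epoch
```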

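For question 2, judging from your output, I agree with @spb: if val_accuracy is the metric you want to use, you need to change what the callback monitors to monitor='val_accuracy'.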
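Once monitor names a key that actually appears in the logs, save_best_only does keep the highest score: it compares each epoch's value with the best seen so far and writes a checkpoint only on a strict improvement. The pure-Python simulation below replays the val_accuracy values from the training log in the question; simulate_save_best_only is a hypothetical stand-in for the callback's logic, not Keras code. It also covers monitor="val_acc": as far as I know, Keras then only issues a RuntimeWarning that the monitored value is unavailable and skips saving, which would explain question 2.

```python
import math

# val_accuracy for epochs 1-50, copied from the training log in the question.
val_accuracy = [
    0.6732, 0.6732, 0.6772, 0.6811, 0.6772, 0.6299, 0.6772, 0.6693, 0.6969, 0.6654,
    0.6811, 0.6811, 0.6772, 0.6929, 0.6772, 0.7008, 0.6811, 0.6575, 0.7008, 0.7205,
    0.7008, 0.6654, 0.7087, 0.7283, 0.7047, 0.7362, 0.6890, 0.7441, 0.6850, 0.7165,
    0.7323, 0.6929, 0.6850, 0.7323, 0.7323, 0.7323, 0.7087, 0.6850, 0.7087, 0.7087,
    0.6929, 0.6929, 0.7047, 0.7205, 0.7205, 0.7126, 0.7205, 0.7323, 0.6969, 0.7165,
]

def simulate_save_best_only(history, monitor="val_accuracy", mode="max"):
    """Hypothetical re-creation of save_best_only: returns the 1-based epochs
    at which a checkpoint would be written."""
    best = -math.inf if mode == "max" else math.inf
    saved = []
    for epoch, logs in enumerate(history, start=1):
        current = logs.get(monitor)
        if current is None:
            # Monitored key missing (e.g. "val_acc"): nothing is saved this epoch.
            continue
        improved = current > best if mode == "max" else current < best
        if improved:  # strict improvement only; ties are not saved
            best = current
            saved.append(epoch)
    return saved

history = [{"val_accuracy": v} for v in val_accuracy]
print(simulate_save_best_only(history))                     # [1, 3, 4, 9, 16, 20, 24, 26, 28]
print(simulate_save_best_only(history, monitor="val_acc"))  # []
```

With a fixed filepath such as "weights.best.hdf5", each save overwrites the previous file, so the file left at the end would come from epoch 28 (val_accuracy 0.7441), the highest value in the log. More epochs are not needed just to get a checkpoint written.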

If you need to improve the model itself, there are many existing answers covering ways to try to improve an existing model.