My training accuracy is 1.0, but the predictions on the training data are wrong

data-mining accuracy
2022-03-12 01:28:56

My neural network is not working properly, and I am trying to figure out what is going on.

I am feeding only three images into a transfer-learning (MobileNet) neural network. The classes of these three images are: array([[0., 0., 0., 1.], [0., 1., 0., 0.], [0., 1., 0., 0.]])

I trained on these images for 50 epochs, and from around epoch 20 onward the training accuracy stayed at 1.0:

Epoch 50/50 3/3 [===============================] - 6s 2s/step - loss: 1.3671 - acc: 1.0000 - val_loss: 1.3770 - val_acc: 0.0000e+00

Then I predict on the same three images, like this: predictions_test_2 = model_mn.predict(X, batch_size=1, verbose=1)

The predictions are: array([[[0.2473848, 0.25099277, 0.251868, 0.24975444], [0.24154082, 0.25245225, 0.25358915, 0.2524177]

If the training accuracy is 1.0, how is that possible?!

Here is the code:

def mobilenet(img_rows, img_cols, channel=1, num_classes=None):
    model = MobileNet(include_top=True, weights='imagenet')
    model.layers.pop()
    model.outputs = [model.layers[-1].output]
    model.layers[-1].outbound_nodes = []
    x = Dense(num_classes, activation='softmax')(model.output)
    model = Model(model.input, x)

    # To set the first 8 layers to non-trainable (weights will not be updated)
    for layer in model.layers[:8]:
        layer.trainable = False

    model_new = Sequential()
    for layer in model.layers[:-1]:  # just exclude last layer from copying
        model_new.add(layer)
    model = model_new
    model.add(Dense(256, activation='relu', input_shape=(1000,)))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(4, activation='softmax'))

    # Learning rate is changed to 0.001
    sgd = SGD(lr=1e-6, decay=1e-1, momentum=0.95, nesterov=True)
    adam = Adam(lr=1e-6, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0000001, amsgrad=True)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

    # checkpoint
    filepath = "weights-improvement-mn-{epoch:02d}-{val_acc:.2f}.hdf5"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]
    return model

model_mn = mobilenet(img_rows, img_cols, channel, num_classes)
model_mn.fit(X, Y, batch_size=3, epochs=50, shuffle=True, verbose=1, validation_data=(X_vall, Y_vall))

1 Answer

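Most likely your model is overfitting the data. You need to monitor the validation accuracy closely: if it diverges too far from the training accuracy, you have entered overfitting territory. In your log the training accuracy is 1.0 while the validation accuracy is 0, which is exactly that kind of gap.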
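One way to watch for that gap is to compare the two accuracy curves epoch by epoch. The following is a minimal sketch, not part of the original code: the helper name `accuracy_gaps`, the 0.2 threshold, and the example values are all assumptions for illustration. Only the metric names "acc" and "val_acc" and the final pair (1.0, 0.0) are taken from the training log in the question.

```python
def accuracy_gaps(history, threshold=0.2):
    """Return (epoch, train_acc, val_acc) for epochs where the gap exceeds threshold.

    `history` is assumed to be a dict of per-epoch lists keyed by "acc" and
    "val_acc", the metric names that appear in the training log above.
    """
    flagged = []
    for epoch, (train_acc, val_acc) in enumerate(
        zip(history["acc"], history["val_acc"]), start=1
    ):
        if train_acc - val_acc > threshold:
            flagged.append((epoch, train_acc, val_acc))
    return flagged


# Hypothetical per-epoch values; only the last pair (1.0, 0.0) comes from the log.
example_history = {"acc": [0.33, 0.67, 1.0], "val_acc": [0.25, 0.25, 0.0]}
flagged_epochs = accuracy_gaps(example_history)
for epoch, train_acc, val_acc in flagged_epochs:
    print(f"epoch {epoch}: acc={train_acc:.2f} val_acc={val_acc:.2f}")
```

In Keras, `fit` returns a History object whose `history` attribute is a dict of this shape, so something like `accuracy_gaps(model_mn.fit(...).history)` should work, though the exact key names ("acc" or "accuracy") depend on the Keras version.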

Also, three images is far too few to train on. Your network will not be able to generalize properly from such a small input.
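To make the mismatch in the question concrete, the sketch below compares the two prediction rows that are visible in the post against their one-hot labels, using plain Python. The helper names `argmax` and `cross_entropy` are illustrative, and the third prediction row is omitted because it is cut off in the post. It also computes the loss a model would get by outputting 0.25 for every class (ln 4, roughly 1.386) for comparison with the reported loss of 1.3671.

```python
import math

# Labels and the two prediction rows that are visible in the question.
# The third prediction row is cut off in the post, so it is left out here.
labels = [
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 1.0, 0.0, 0.0],
]
predictions = [
    [0.2473848, 0.25099277, 0.251868, 0.24975444],
    [0.24154082, 0.25245225, 0.25358915, 0.2524177],
]


def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=lambda i: row[i])


def cross_entropy(label_row, prediction_row):
    """Categorical cross-entropy for one one-hot label and one softmax output."""
    return -sum(t * math.log(p) for t, p in zip(label_row, prediction_row) if t > 0)


true_classes = [argmax(row) for row in labels]
predicted_classes = [argmax(row) for row in predictions]
losses = [cross_entropy(t, p) for t, p in zip(labels, predictions)]
uniform_loss = math.log(4)  # loss of a model that outputs 0.25 for every class

print("true classes:     ", true_classes)
print("predicted classes:", predicted_classes)
print("per-image loss:   ", [round(x, 4) for x in losses])
print("uniform baseline: ", round(uniform_loss, 4))
```

Both visible rows pick class 2 while the labels are classes 3 and 1, and every probability sits within about 0.01 of 0.25, so the outputs are barely distinguishable from a uniform guess.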