Is my model overfitting (LSTM, GRU)?

data-mining machine-learning classification keras nlp lstm
2022-03-17 03:35:15

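I have a small corpus of at most 150 text utterances, distributed across 5 classes. For testing, I started with a basic deep learning model: I use word2vec embeddings, add a 1D convolution layer, and follow it with 150 GRU units: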

from keras import regularizers
from keras.layers import Conv1D, Dropout, Embedding, GRU, Input

# word2vec-initialised embeddings, fine-tuned during training
embedding_layer = Embedding(vocab_size, 300, weights=[embedding_matrix], input_length=max_length, trainable=True)
sequence_input = Input(shape=(max_length,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)

x = Conv1D(24, 4, activation='relu')(embedded_sequences)
x = GRU(150, kernel_regularizer=regularizers.l1(0.01), activation='tanh')(x)
x = Dropout(0.4)(x)

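Training and validation loss:

[Figure: training and validation loss]

According to this plot, the model fits very well, but when I give it an utterance to predict, the utterance is assigned to some other class.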

Code for predicting an utterance:

encoded_doc = [[16, 7, 49, 50, 51]]  # hard-coded token ids for one utterance
max_length = 25
padded_doc = pad_sequences(encoded_doc, maxlen=max_length, padding='post')
predictions = model.predict(padded_doc)
pre_class = predictions[0]  # class probabilities for the single utterance
classes = np.argmax(pre_class)

# inverse_transform expects a 1-D array-like, so wrap the index in a list
print('Predicted class: '+str(label_encoder.inverse_transform([classes])[0])+' ## score: '+str(pre_class[classes]))
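
When a single prediction lands in an unexpected class, it can help to look at the whole probability vector rather than only the argmax. A minimal sketch of ranking the classes by score follows; the probabilities and class names are made up for illustration and are not the real model output, and `top_k_classes` is a hypothetical helper:

```python
def top_k_classes(probabilities, labels, k=3):
    """Return the k most probable (label, score) pairs, best first."""
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Hypothetical softmax output for one utterance and hypothetical class names.
example_probs = [0.05, 0.40, 0.35, 0.15, 0.05]
example_labels = ["greeting", "order", "complaint", "thanks", "other"]

for label, score in top_k_classes(example_probs, example_labels):
    print(f"{label}: {score:.2f}")
```

If the top two scores are close, the "wrong" prediction may simply be a near tie rather than a sign of a broken pipeline.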

So I changed it to an LSTM and increased the number of LSTM units to 250:

seed = 7
np.random.seed(seed)

embedding_layer = Embedding(vocab_size, 300,weights=[embedding_matrix], input_length=max_length,trainable=True)
sequence_input = Input(shape=(max_length,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)

x = Conv1D(24, 4,activation='relu')(embedded_sequences)
x = LSTM(250,kernel_regularizer=regularizers.l1(0.01),activation='tanh')(x)
x = Dropout(0.4)(x)

Training and validation loss:

[Figure: training and validation loss]

_________________________________________________________________ 
Layer (type)                 Output Shape              Param #   
================================================================= 
input_7 (InputLayer)         (None, 25)                0         
_________________________________________________________________                     
embedding_7 (Embedding)      (None, 25, 300)           31500     
_________________________________________________________________     
conv1d_7 (Conv1D)            (None, 22, 64)            76864     
_________________________________________________________________ 
lstm_4 (LSTM)                (None, 250)               315000    
_________________________________________________________________ 
dropout_7 (Dropout)          (None, 250)               0         
_________________________________________________________________ 
dense_7 (Dense)              (None, 5)                 1255      
================================================================= 
Total params: 424,619
Trainable params: 424,619
Non-trainable params: 0
_________________________________________________________________ 
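
The parameter counts in this summary can be checked by hand with the standard formulas. In the sketch below, the vocabulary size of 105 and the 64 convolution filters are inferred from the summary itself (31,500 embedding parameters and an output shape of (None, 22, 64)), not from the code above, which shows Conv1D(24, 4); treat both values as assumptions.

```python
def embedding_params(vocab_size, dim):
    # one dim-sized vector per vocabulary entry
    return vocab_size * dim

def conv1d_params(kernel_size, in_channels, filters):
    # one kernel of shape (kernel_size, in_channels) plus one bias per filter
    return kernel_size * in_channels * filters + filters

def lstm_params(input_dim, units):
    # 4 gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # weight matrix plus one bias per output unit
    return input_dim * units + units

counts = [
    embedding_params(105, 300),  # assumed vocabulary size of 105
    conv1d_params(4, 300, 64),   # assumed 64 filters, kernel size 4
    lstm_params(64, 250),
    dense_params(250, 5),
]
print(counts, sum(counts))
```

The four values reproduce the per-layer counts in the summary and add up to the reported total of 424,619, which suggests the trained model used 64 filters rather than 24.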
Train on 123 samples, validate on 31 samples
Epoch 1/35
123/123 [==============================] - 1s 8ms/step - loss: 22.8336 - acc: 0.2439 - val_loss: 18.5322 - val_acc: 0.3871
Epoch 2/35
123/123 [==============================] - 1s 7ms/step - loss: 16.5612 - acc: 0.2846 - val_loss: 12.7986 - val_acc: 0.4194
Epoch 3/35
123/123 [==============================] - 1s 7ms/step - loss: 11.0305 - acc: 0.3171 - val_loss: 7.7372 - val_acc: 0.5806
Epoch 4/35
123/123 [==============================] - 1s 7ms/step - loss: 6.5548 - acc: 0.4472 - val_loss: 4.1708 - val_acc: 0.5806
Epoch 5/35
123/123 [==============================] - 1s 7ms/step - loss: 3.2714 - acc: 0.5447 - val_loss: 1.9674 - val_acc: 0.5484
Epoch 6/35
123/123 [==============================] - 1s 7ms/step - loss: 1.6011 - acc: 0.5528 - val_loss: 1.2261 - val_acc: 0.5806
Epoch 7/35
123/123 [==============================] - 1s 7ms/step - loss: 2.0814 - acc: 0.4878 - val_loss: 1.9198 - val_acc: 0.5484
Epoch 8/35
123/123 [==============================] - 1s 7ms/step - loss: 1.7965 - acc: 0.4634 - val_loss: 1.2227 - val_acc: 0.5806
Epoch 9/35
123/123 [==============================] - 1s 7ms/step - loss: 1.4348 - acc: 0.5447 - val_loss: 1.2684 - val_acc: 0.6129
Epoch 10/35
123/123 [==============================] - 1s 7ms/step - loss: 1.3092 - acc: 0.5772 - val_loss: 1.0482 - val_acc: 0.7742
Epoch 11/35
123/123 [==============================] - 1s 7ms/step - loss: 1.2495 - acc: 0.6341 - val_loss: 1.0036 - val_acc: 0.7419
Epoch 12/35
123/123 [==============================] - 1s 7ms/step - loss: 1.1438 - acc: 0.7073 - val_loss: 0.9640 - val_acc: 0.7419
Epoch 13/35
123/123 [==============================] - 1s 7ms/step - loss: 0.8768 - acc: 0.9024 - val_loss: 0.4931 - val_acc: 1.0000
Epoch 14/35
123/123 [==============================] - 1s 7ms/step - loss: 0.5908 - acc: 0.9512 - val_loss: 2.2134 - val_acc: 0.6129
Epoch 15/35
123/123 [==============================] - 1s 7ms/step - loss: 1.4977 - acc: 0.6423 - val_loss: 1.1113 - val_acc: 0.5806
Epoch 16/35
123/123 [==============================] - 1s 7ms/step - loss: 1.1936 - acc: 0.5691 - val_loss: 1.0105 - val_acc: 0.5806
Epoch 17/35
123/123 [==============================] - 1s 7ms/step - loss: 1.1522 - acc: 0.4634 - val_loss: 1.0109 - val_acc: 0.5806
Epoch 18/35
123/123 [==============================] - 1s 7ms/step - loss: 1.0135 - acc: 0.6260 - val_loss: 2.5970 - val_acc: 0.1935
Epoch 19/35
123/123 [==============================] - 1s 7ms/step - loss: 1.4423 - acc: 0.6992 - val_loss: 1.3537 - val_acc: 0.6774
Epoch 20/35
123/123 [==============================] - 1s 7ms/step - loss: 1.4512 - acc: 0.5935 - val_loss: 1.1868 - val_acc: 0.6774
Epoch 21/35
123/123 [==============================] - 1s 7ms/step - loss: 1.2350 - acc: 0.5447 - val_loss: 1.0781 - val_acc: 0.5806
Epoch 22/35
123/123 [==============================] - 1s 7ms/step - loss: 1.1008 - acc: 0.6260 - val_loss: 0.9849 - val_acc: 0.9677
Epoch 23/35
123/123 [==============================] - 1s 7ms/step - loss: 0.9986 - acc: 0.6504 - val_loss: 0.8684 - val_acc: 0.9355
Epoch 24/35
123/123 [==============================] - 1s 7ms/step - loss: 1.3619 - acc: 0.7154 - val_loss: 1.4444 - val_acc: 0.6774
Epoch 25/35
123/123 [==============================] - 1s 7ms/step - loss: 1.5590 - acc: 0.7398 - val_loss: 1.5238 - val_acc: 0.5806
Epoch 26/35
123/123 [==============================] - 1s 7ms/step - loss: 1.1659 - acc: 0.8862 - val_loss: 0.8608 - val_acc: 1.0000
Epoch 27/35
123/123 [==============================] - 1s 7ms/step - loss: 0.8432 - acc: 0.9756 - val_loss: 0.6919 - val_acc: 1.0000
Epoch 28/35
123/123 [==============================] - 1s 8ms/step - loss: 0.7218 - acc: 0.9675 - val_loss: 0.6103 - val_acc: 1.0000
Epoch 29/35
123/123 [==============================] - 1s 7ms/step - loss: 0.6496 - acc: 0.9756 - val_loss: 0.5566 - val_acc: 1.0000
Epoch 30/35
123/123 [==============================] - 1s 7ms/step - loss: 0.5961 - acc: 0.9675 - val_loss: 0.5152 - val_acc: 1.0000
Epoch 31/35
123/123 [==============================] - 1s 7ms/step - loss: 0.5590 - acc: 0.9675 - val_loss: 0.4832 - val_acc: 1.0000
Epoch 32/35
123/123 [==============================] - 1s 9ms/step - loss: 0.5188 - acc: 0.9593 - val_loss: 0.4564 - val_acc: 1.0000
Epoch 33/35
123/123 [==============================] - 1s 10ms/step - loss: 0.4987 - acc: 0.9675 - val_loss: 0.4355 - val_acc: 1.0000
Epoch 34/35
123/123 [==============================] - 1s 7ms/step - loss: 0.4634 - acc: 0.9919 - val_loss: 0.4177 - val_acc: 1.0000
Epoch 35/35
123/123 [==============================] - 1s 7ms/step - loss: 0.4372 - acc: 1.0000 - val_loss: 0.4043 - val_acc: 1.0000
31/31 [==============================] - 0s 1ms/step
Accuracy: 100.000000

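I have no clue where I am going wrong. To rule out preprocessing mistakes, I even hard-coded the sentence tokens for the prediction. I would at least like to understand from the training and validation loss plots whether the model is overfitting; I don't think it is.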
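
One way to read the curves numerically is to look at the gap between validation and training loss per epoch. The sketch below uses a handful of (epoch, loss, val_loss) values copied from the log above; the 0.5 threshold and the helper names are arbitrary assumptions for illustration, not an established criterion. Note that training loss is computed with dropout active, so a validation loss slightly below the training loss is not unusual by itself.

```python
# (epoch, loss, val_loss) copied from a few epochs of the training log.
history = [
    (12, 1.1438, 0.9640),
    (13, 0.8768, 0.4931),
    (14, 0.5908, 2.2134),
    (17, 1.1522, 1.0109),
    (18, 1.0135, 2.5970),
    (34, 0.4634, 0.4177),
    (35, 0.4372, 0.4043),
]

def generalization_gaps(rows):
    """Return (epoch, val_loss - loss) for each row."""
    return [(epoch, round(val_loss - loss, 4)) for epoch, loss, val_loss in rows]

def overfit_suspects(rows, threshold=0.5):
    """Epochs where validation loss exceeds training loss by more than threshold."""
    return [epoch for epoch, gap in generalization_gaps(rows) if gap > threshold]

print(generalization_gaps(history))
print(overfit_suspects(history))
```

On these values only epochs 14 and 18 stand out, and both are isolated spikes rather than a steadily widening gap, which is the pattern usually associated with overfitting.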

Let me know if you need more information.


  • Epochs: 35
  • Batch size: 20
  • shuffle = False
  • Validation split: 20%
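
With these settings it may be worth checking which samples end up in the validation set: Keras takes the `validation_split` fraction from the end of the arrays before any shuffling, so if the data is ordered by class, the 31 validation samples may not cover all 5 classes. A minimal sketch of a stratified split in plain Python follows; `stratified_split` is a hypothetical helper and the label list is made up (5 classes of 30 utterances each), not the real data.

```python
import random
from collections import defaultdict

def stratified_split(labels, val_fraction=0.2, seed=7):
    """Return (train_indices, val_indices), sampling val_fraction of each class."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train_idx, val_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # keep at least one validation sample per class
        n_val = max(1, round(len(indices) * val_fraction))
        val_idx.extend(indices[:n_val])
        train_idx.extend(indices[n_val:])
    return sorted(train_idx), sorted(val_idx)

# Hypothetical labels sorted by class: 5 classes, 30 utterances each.
labels = [c for c in range(5) for _ in range(30)]
train_idx, val_idx = stratified_split(labels)
print(len(train_idx), len(val_idx))
print(sorted({labels[i] for i in val_idx}))
```

The resulting index lists could then be used to build explicit arrays for the `validation_data` argument of `fit` instead of relying on `validation_split`.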

Thanks in advance.

1 Answer

Try increasing the batch size.

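I think this may be because you specified a batch size of 35, while validation is by default evaluated with a batch size of 32. Evaluating a batch of 32 on the weights reached at the end of the epoch may indeed give slightly better average performance, because every sample in that batch is scored with the latest and best current weights.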

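If the samples in your training set and validation set are very similar, I would expect the validation curve to always sit slightly above the training score, since you are using similar batch sizes (35 vs. 32).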

You can see that the training and validation curves do level off over time.
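
To put the batch-size suggestion in numbers: with the 123 training samples reported in the log, a larger batch means fewer gradient updates per epoch, each averaged over more samples. A small sketch, where the batch sizes other than the 20 listed in the question are arbitrary examples:

```python
import math

def updates_per_epoch(n_samples, batch_size):
    """Number of gradient updates in one pass over the training data."""
    return math.ceil(n_samples / batch_size)

for batch_size in (20, 32, 64):
    print(batch_size, updates_per_epoch(123, batch_size))
```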