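I have a small corpus of at most 150 text utterances, distributed across 5 classes. As a first test I started with a basic deep learning model: word2vec embeddings, followed by a 1D convolutional layer and a GRU layer with 150 units: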
from keras import regularizers
from keras.layers import Input, Embedding, Conv1D, GRU, Dropout
# word2vec-initialised embeddings, fine-tuned during training
embedding_layer = Embedding(vocab_size, 300, weights=[embedding_matrix], input_length=max_length, trainable=True)
sequence_input = Input(shape=(max_length,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(24, 4, activation='relu')(embedded_sequences)
x = GRU(150, kernel_regularizer=regularizers.l1(0.01), activation='tanh')(x)
x = Dropout(0.4)(x)
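For reference, the `embedding_matrix` passed as `weights=[embedding_matrix]` is usually built by looking up each tokenizer index in the word2vec vectors. A minimal NumPy sketch, assuming a Keras-style `word_index` (word to id, ids starting at 1 so that 0 stays free for padding) and a plain dict `word_vectors` of 300-d vectors; both are hypothetical stand-ins for the real tokenizer and word2vec model:

```python
import numpy as np

def build_embedding_matrix(word_index, word_vectors, dim=300):
    """Row i holds the vector of the word whose token id is i.

    Row 0 stays all-zero because id 0 is reserved for padding; words that are
    missing from word_vectors (out-of-vocabulary) also stay zero.
    """
    vocab_size = len(word_index) + 1  # +1 for the padding id 0
    matrix = np.zeros((vocab_size, dim), dtype=np.float32)
    for word, idx in word_index.items():
        vec = word_vectors.get(word)
        if vec is not None:
            matrix[idx] = vec
    return matrix

# Hypothetical toy data: 3 words, 300-d vectors, "a" has no vector.
word_index = {"book": 1, "a": 2, "flight": 3}
word_vectors = {"book": np.full(300, 0.5), "flight": np.full(300, -0.25)}
embedding_matrix = build_embedding_matrix(word_index, word_vectors)
print(embedding_matrix.shape)  # (4, 300)
```

Because `Embedding(vocab_size, ...)` looks rows up by token id, any ids fed in at prediction time have to come from this same `word_index`.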
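Training and validation loss:
Judging by this, the model fits very well, but when I run a prediction the utterance is assigned to the wrong class.
Code for predicting an utterance: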
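import numpy as np
from keras.preprocessing.sequence import pad_sequences

encoded_doc = [[16, 7, 49, 50, 51]]  # hard-coded token ids of the test sentence
max_length = 25  # must match the max_length used in training
padded_doc = pad_sequences(encoded_doc, maxlen=max_length, padding='post')
pre_class = model.predict(padded_doc)[0]  # class probabilities for the single utterance
classes = np.argmax(pre_class)  # index of the most probable class
# inverse_transform expects an array-like, so wrap the index in a list
print('Predicted class: ' + str(label_encoder.inverse_transform([classes])[0]) + ' ## score: ' + str(pre_class[classes]))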
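Two things in the prediction snippet are worth comparing against the training code: the padding side (Keras `pad_sequences` pads at the front, `padding='pre'`, by default, so a model trained on default-padded inputs sees a very different sequence if prediction uses `padding='post'`), and the mapping from the argmax index back to a label. A NumPy-only sketch of both, where `pad` is a simplified, hypothetical re-implementation of the padding behaviour for sequences shorter than `maxlen`, and `probs` / `classes_` are made-up values standing in for `model.predict` output and `label_encoder.classes_`:

```python
import numpy as np

def pad(seq, maxlen, padding="pre", value=0):
    """Simplified stand-in for pad_sequences on one short sequence (no truncation)."""
    filler = [value] * (maxlen - len(seq))
    return filler + seq if padding == "pre" else seq + filler

tokens = [16, 7, 49, 50, 51]
pre = pad(tokens, 8)                   # [0, 0, 0, 16, 7, 49, 50, 51]
post = pad(tokens, 8, padding="post")  # [16, 7, 49, 50, 51, 0, 0, 0]

# Hypothetical softmax output for one utterance, shape (1, n_classes).
probs = np.array([[0.05, 0.10, 0.70, 0.10, 0.05]])
# Hypothetical label names, sorted as LabelEncoder.classes_ would be.
classes_ = np.array(["cancel", "greet", "order", "refund", "status"])

row = probs[0]
best = int(np.argmax(row))  # index within this row, not into the flattened batch
print(classes_[best], row[best])  # order 0.7
```

For a fitted `LabelEncoder`, `inverse_transform([i])[0]` and `classes_[i]` give the same label, so either form works once the index is taken from a single row.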
So I switched to an LSTM and increased the number of units to 250:
import numpy as np
from keras import regularizers
from keras.layers import Input, Embedding, Conv1D, LSTM, Dropout

seed = 7
np.random.seed(seed)  # seeds NumPy's generator
embedding_layer = Embedding(vocab_size, 300, weights=[embedding_matrix], input_length=max_length, trainable=True)
sequence_input = Input(shape=(max_length,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(24, 4, activation='relu')(embedded_sequences)
x = LSTM(250, kernel_regularizer=regularizers.l1(0.01), activation='tanh')(x)
x = Dropout(0.4)(x)
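`np.random.seed(seed)` only seeds NumPy; with a TensorFlow backend, weight initialisation and dropout draw from TensorFlow's own generator, so two runs can still differ. A sketch that seeds Python, NumPy and TensorFlow together, assuming TensorFlow 2.x (`tf.random.set_seed`) and skipping TensorFlow if it is not installed; `set_all_seeds` is a hypothetical helper name:

```python
import random

import numpy as np

def set_all_seeds(seed):
    """Seed Python's, NumPy's and (if installed) TensorFlow's random generators."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import tensorflow as tf
        tf.random.set_seed(seed)  # TF 2.x API
    except ImportError:
        pass  # TensorFlow not available; Python and NumPy are still seeded

set_all_seeds(7)
a = np.random.rand(3)
set_all_seeds(7)
b = np.random.rand(3)
print(np.array_equal(a, b))  # True
```

Even with every seed fixed, some GPU kernels may not be bit-for-bit deterministic, so small run-to-run differences can remain.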
Training and validation loss:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_7 (InputLayer)         (None, 25)                0
_________________________________________________________________
embedding_7 (Embedding)      (None, 25, 300)           31500
_________________________________________________________________
conv1d_7 (Conv1D)            (None, 22, 64)            76864
_________________________________________________________________
lstm_4 (LSTM)                (None, 250)               315000
_________________________________________________________________
dropout_7 (Dropout)          (None, 250)               0
_________________________________________________________________
dense_7 (Dense)              (None, 5)                 1255
=================================================================
Total params: 424,619
Trainable params: 424,619
Non-trainable params: 0
_________________________________________________________________
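The parameter counts in this summary can be reproduced by hand with the standard formulas (Embedding: vocab × dim; Conv1D: kernel × in_channels × filters + filters; LSTM: 4 × (units × (input_dim + units) + units); Dense: in × out + out), which shows which layer sizes were actually trained. The vocabulary size of 105 is inferred from 31,500 / 300 and is not stated in the question:

```python
def embedding_params(vocab, dim):
    return vocab * dim

def conv1d_params(kernel, in_ch, filters):
    return kernel * in_ch * filters + filters  # weights + one bias per filter

def lstm_params(units, input_dim):
    # 4 gates, each with input weights, recurrent weights and a bias
    return 4 * (units * (input_dim + units) + units)

def dense_params(n_in, n_out):
    return n_in * n_out + n_out

emb = embedding_params(105, 300)  # 31500
conv = conv1d_params(4, 300, 64)  # 76864
lstm = lstm_params(250, 64)       # 315000
dense = dense_params(250, 5)      # 1255
print(emb + conv + lstm + dense)  # 424619
```

With `Conv1D(24, 4)` as written in the code, the convolution would have 4 × 300 × 24 + 24 = 28,824 parameters, so this summary appears to come from a variant with 64 filters.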
> Train on 123 samples, validate on 31 samples
> Epoch 1/35 123/123 [==============================] - 1s 8ms/step - loss: 22.8336 - acc: 0.2439 - val_loss: 18.5322 - val_acc: 0.3871
> Epoch 2/35 123/123 [==============================] - 1s 7ms/step - loss: 16.5612 - acc: 0.2846 - val_loss: 12.7986 - val_acc: 0.4194
> Epoch 3/35 123/123 [==============================] - 1s 7ms/step - loss: 11.0305 - acc: 0.3171 - val_loss: 7.7372 - val_acc: 0.5806
> Epoch 4/35 123/123 [==============================] - 1s 7ms/step - loss: 6.5548 - acc: 0.4472 - val_loss: 4.1708 - val_acc: 0.5806
> Epoch 5/35 123/123 [==============================] - 1s 7ms/step - loss: 3.2714 - acc: 0.5447 - val_loss: 1.9674 - val_acc: 0.5484
> Epoch 6/35 123/123 [==============================] - 1s 7ms/step - loss: 1.6011 - acc: 0.5528 - val_loss: 1.2261 - val_acc: 0.5806
> Epoch 7/35 123/123 [==============================] - 1s 7ms/step - loss: 2.0814 - acc: 0.4878 - val_loss: 1.9198 - val_acc: 0.5484
> Epoch 8/35 123/123 [==============================] - 1s 7ms/step - loss: 1.7965 - acc: 0.4634 - val_loss: 1.2227 - val_acc: 0.5806
> Epoch 9/35 123/123 [==============================] - 1s 7ms/step - loss: 1.4348 - acc: 0.5447 - val_loss: 1.2684 - val_acc: 0.6129
> Epoch 10/35 123/123 [==============================] - 1s 7ms/step - loss: 1.3092 - acc: 0.5772 - val_loss: 1.0482 - val_acc: 0.7742
> Epoch 11/35 123/123 [==============================] - 1s 7ms/step - loss: 1.2495 - acc: 0.6341 - val_loss: 1.0036 - val_acc: 0.7419
> Epoch 12/35 123/123 [==============================] - 1s 7ms/step - loss: 1.1438 - acc: 0.7073 - val_loss: 0.9640 - val_acc: 0.7419
> Epoch 13/35 123/123 [==============================] - 1s 7ms/step - loss: 0.8768 - acc: 0.9024 - val_loss: 0.4931 - val_acc: 1.0000
> Epoch 14/35 123/123 [==============================] - 1s 7ms/step - loss: 0.5908 - acc: 0.9512 - val_loss: 2.2134 - val_acc: 0.6129
> Epoch 15/35 123/123 [==============================] - 1s 7ms/step - loss: 1.4977 - acc: 0.6423 - val_loss: 1.1113 - val_acc: 0.5806
> Epoch 16/35 123/123 [==============================] - 1s 7ms/step - loss: 1.1936 - acc: 0.5691 - val_loss: 1.0105 - val_acc: 0.5806
> Epoch 17/35 123/123 [==============================] - 1s 7ms/step - loss: 1.1522 - acc: 0.4634 - val_loss: 1.0109 - val_acc: 0.5806
> Epoch 18/35 123/123 [==============================] - 1s 7ms/step - loss: 1.0135 - acc: 0.6260 - val_loss: 2.5970 - val_acc: 0.1935
> Epoch 19/35 123/123 [==============================] - 1s 7ms/step - loss: 1.4423 - acc: 0.6992 - val_loss: 1.3537 - val_acc: 0.6774
> Epoch 20/35 123/123 [==============================] - 1s 7ms/step - loss: 1.4512 - acc: 0.5935 - val_loss: 1.1868 - val_acc: 0.6774
> Epoch 21/35 123/123 [==============================] - 1s 7ms/step - loss: 1.2350 - acc: 0.5447 - val_loss: 1.0781 - val_acc: 0.5806
> Epoch 22/35 123/123 [==============================] - 1s 7ms/step - loss: 1.1008 - acc: 0.6260 - val_loss: 0.9849 - val_acc: 0.9677
> Epoch 23/35 123/123 [==============================] - 1s 7ms/step - loss: 0.9986 - acc: 0.6504 - val_loss: 0.8684 - val_acc: 0.9355
> Epoch 24/35 123/123 [==============================] - 1s 7ms/step - loss: 1.3619 - acc: 0.7154 - val_loss: 1.4444 - val_acc: 0.6774
> Epoch 25/35 123/123 [==============================] - 1s 7ms/step - loss: 1.5590 - acc: 0.7398 - val_loss: 1.5238 - val_acc: 0.5806
> Epoch 26/35 123/123 [==============================] - 1s 7ms/step - loss: 1.1659 - acc: 0.8862 - val_loss: 0.8608 - val_acc: 1.0000
> Epoch 27/35 123/123 [==============================] - 1s 7ms/step - loss: 0.8432 - acc: 0.9756 - val_loss: 0.6919 - val_acc: 1.0000
> Epoch 28/35 123/123 [==============================] - 1s 8ms/step - loss: 0.7218 - acc: 0.9675 - val_loss: 0.6103 - val_acc: 1.0000
> Epoch 29/35 123/123 [==============================] - 1s 7ms/step - loss: 0.6496 - acc: 0.9756 - val_loss: 0.5566 - val_acc: 1.0000
> Epoch 30/35 123/123 [==============================] - 1s 7ms/step - loss: 0.5961 - acc: 0.9675 - val_loss: 0.5152 - val_acc: 1.0000
> Epoch 31/35 123/123 [==============================] - 1s 7ms/step - loss: 0.5590 - acc: 0.9675 - val_loss: 0.4832 - val_acc: 1.0000
> Epoch 32/35 123/123 [==============================] - 1s 9ms/step - loss: 0.5188 - acc: 0.9593 - val_loss: 0.4564 - val_acc: 1.0000
> Epoch 33/35 123/123 [==============================] - 1s 10ms/step - loss: 0.4987 - acc: 0.9675 - val_loss: 0.4355 - val_acc: 1.0000
> Epoch 34/35 123/123 [==============================] - 1s 7ms/step - loss: 0.4634 - acc: 0.9919 - val_loss: 0.4177 - val_acc: 1.0000
> Epoch 35/35 123/123 [==============================] - 1s 7ms/step - loss: 0.4372 - acc: 1.0000 - val_loss: 0.4043 - val_acc: 1.0000
> 31/31 [==============================] - 0s 1ms/step
> Accuracy: 100.000000
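I can't find any clue as to where I'm going wrong, even though I hard-coded the sentence's tokens for the prediction to rule out mistakes on that side. I at least wanted to use the training and validation loss curves to judge whether the model is overfitting, and I don't think it is.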
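When reading these curves, note that with only 31 validation utterances each one moves `val_acc` by 1/31, about 3.2 percentage points, so the logged values are small integer counts. A short check converting a few `val_acc` values copied from the log back into "correct out of 31":

```python
val_samples = 31
# val_acc values copied from the log (epochs 1, 3, 18, 22 and 35)
logged_val_acc = [0.3871, 0.5806, 0.1935, 0.9677, 1.0000]

counts = [round(acc * val_samples) for acc in logged_val_acc]
for acc, correct in zip(logged_val_acc, counts):
    print(f"{acc:.4f} -> {correct}/{val_samples}")
```

So the swing from 0.5806 to 0.1935 at epoch 18 is 18 versus 6 correct utterances, and the step from 0.9677 to 1.0000 is a single utterance; on a set this small, `val_acc: 1.0000` says little about how the model will behave on new sentences.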
Let me know if you need more information.
- Epochs: 35
- Batch size: 20
- shuffle = False
- Validation split: 20%
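One detail of this setup worth checking: Keras takes the `validation_split` fraction from the end of the training arrays, before any shuffling, so if the utterances are stored grouped by class, the 31 validation samples may cover only one or two classes. The arithmetic matches the log: int(154 × 0.8) = 123 training and 31 validation samples. A NumPy sketch, assuming (hypothetically) 154 utterances sorted by class, comparing that tail split with a shuffled, per-class (stratified) split:

```python
import numpy as np

# Hypothetical labels: 154 utterances stored grouped by class (0..4).
labels = np.repeat(np.arange(5), [31, 31, 31, 31, 30])
validation_split = 0.2

# Keras-style split: the last 20% of the arrays, taken before shuffling.
split_at = int(len(labels) * (1.0 - validation_split))  # 123
tail_val = labels[split_at:]
print(split_at, len(tail_val), np.unique(tail_val))  # 123 31 [3 4]

# Alternative: shuffle within each class and hold out ~20% of every class.
rng = np.random.default_rng(7)
val_idx = []
for c in np.unique(labels):
    idx = np.flatnonzero(labels == c)
    rng.shuffle(idx)
    val_idx.extend(idx[: round(len(idx) * validation_split)])
val_idx = np.array(val_idx)
train_idx = np.setdiff1d(np.arange(len(labels)), val_idx)
print(len(train_idx), len(val_idx), np.unique(labels[val_idx]))  # 124 30 [0 1 2 3 4]
```

The resulting indices could then be passed to `fit` as `validation_data=(X[val_idx], y[val_idx])` instead of `validation_split`, where `X` and `y` stand for the padded sequences and one-hot labels.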
Thanks in advance.
