多标签分类 - 过度拟合?

数据挖掘 神经网络 深度学习 多标签分类 过拟合 概括
2022-02-15 15:24:49

我的任务如下:

输入药物组合并从药物组合中输出肾功能衰竭相关症状。

药物组合和肾功能衰竭相关症状都表示为 one-hot 编码(例如,某人在总共 4 种症状中出现症状 1 和症状 3 表示为[1,0,1,0])。

到目前为止,我已经通过以下模型运行了数据,并生成了这个有趣的图表。左图描述了模型在 epoch 上的训练和验证损失,右图描述了模型在 epoch 上的训练和验证准确度。

历代模型性能的损失和准确度概况

从左图中几乎可以清楚地看出,在基线模型中,随着训练损失的减少和验证损失的增加,随着时间的推移,出现了过度拟合。然而,图表上的准确度图表明,尽管验证准确度有所提高,但验证准确度仍在继续提高。

加入 dropout 和 L2 正则化(Baseline Reg)后,validation loss 并没有上升那么多,这似乎解决了过拟合问题,但准确率非常零星,平均比 Baseline 模型差。

问题如下:我的直觉是否正确,因为两个模型的结果都显示过拟合,我应该继续努力减少这种影响吗?

提前致谢!

模型架构如下:

Model: "baseline"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_6 (Dense)              (884, 800)                603200    
_________________________________________________________________
dense_7 (Dense)              (884, 800)                640800    
_________________________________________________________________
dense_8 (Dense)              (884, 4)                  3204      
=================================================================
Total params: 1,247,204
Trainable params: 1,247,204
Non-trainable params: 0
_________________________________________________________________

Model: "baseline_reg"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              (884, 800)                603200    
_________________________________________________________________
dropout (Dropout)            (884, 800)                0         
_________________________________________________________________
dense_10 (Dense)             (884, 800)                640800    
_________________________________________________________________
dense_11 (Dense)             (884, 4)                  3204      
=================================================================
Total params: 1,247,204
Trainable params: 1,247,204
Non-trainable params: 0
_________________________________________________________________
1个回答

准确性不是评估多标签分类的非常有用的方法,这可能会导致结果不一致。鉴于域的重要性,完全匹配可能是更合适的评估指标。