我应该如何根据误差按比例惩罚模型?

人工智能 机器学习 分类 目标函数 交叉熵 分类交叉熵
2021-10-21 11:21:09

我正在制作一个 MNIST 分类器。我使用分类交叉熵作为我的损失函数。我想这样做,如果正确的标签是 3,那么如果它对 4 进行分类而不是对 7 进行分类,它将对模型的惩罚更轻,因为 4 在数字上比 7 更接近 3。我该怎么做呢?

1个回答

我想这样做,如果正确的标签是 3,那么如果它对 4 进行分类而不是对 7 进行分类,它将对模型的惩罚更轻,因为 4 在数字上比 7 更接近 3。我该怎么做呢?

真的你不应该,因为使用的符号(阿拉伯数字)与数量没有直接关系,例如计数或点。它们是很好的分类候选者,尽管在阅读它们时通常映射到数量,但符号本身不是回归的候选者,因为例如符号34不要以任何直观的方式捕捉数量的方式有所不同。

但是,如果您热衷于这样做,在大多数自动微分框架中构建合适的损失函数相对简单。您将需要阅读如何执行此操作。例如,这是一个 Stack Overflow 答案,解释了从哪里开始在 Keras 中编写自定义损失函数

为了让您的损失函数发挥作用,它需要随着预测变得更好而可微分且平滑地变化。这排除了argmax对当前预测使用任何形式的 。如果你想在最后一层坚持使用 softmax,那么我建议对预期的预测使用均方误差,例如,如果di例如是数字iyi,j是表示为 one-hot 向量的基本事实,其中i是例子和j是数字类,那么y^i,j是您的模型预测的概率。你可以使用d^i=j=09jy^i,j对于期望值和 MSE 损失L(di,d^i)=12(d^idi)2

您还可以使用 MSE 损失和交叉熵损失的加权和作为最终损失,这两种损失之间的平衡是模型的新超参数。

请注意,此解决方案使0相近1但远离9. 如果您希望数字在循环中被视为接近(例如8更接近1而不是4) 你需要一些更有创意的东西。

虽然我认为这不会帮助您发现 MNIST 分类的任何改进,但结合两个或多个损失函数来实现更复杂的目标有时会非常有用,因此这是一项值得练习的技能。