如果我对属于某个类别的样本的概率感兴趣,交叉熵是一个很好的成本函数吗?

数据挖掘 机器学习 神经网络 分类 张量流 成本函数
2022-02-25 12:07:30

我正在训练一个神经网络,对于六个类别中的每一个,都试图预测样本属于它的概率。之后,我想将这些概率用作属于该类的样本的一部分。我的网络提供 softmax 输出并使用交叉熵成本进行训练(实际上是线性输出,然后由 tf.nn.softmax_cross_entropy_with_logits 转换为 softmax)

当我想训练网络使所有 6 个概率都正确而不是仅仅将每个样本分类为 6 个类之一时,这是正确的成本函数吗?我开始犹豫了,因为 tensorflow 文档中提到了 tf.nn.softmax_cross_entropy_with_logits:

测量类别互斥(每个条目恰好属于一个类别)的离散分类任务中的概率误差。

更新 尽管听起来很奇怪,但交叉熵成本函数似乎在这里效果最好。在我的理解中,这是因为离散概率分布(维基百科)的交叉熵是

H(p,q)=xp(x)logq(x).

这个函数是最小的p(x)=q(x)对全部x. 这解释了为什么最小化交叉熵误差会强制输出分布q到目标分布p,即使我滥用 tf.nn.softmax_cross_entropy_with_logits 而不是给它标签,我给它一个离散的概率分布。

为什么它比 MSE 工作得更好可能是因为 6 类概率的大小非常不同。第 1 类可能有 80% 的概率,而第 2 类有 0.5% 的概率,因此 MSE 更注重让 80% 的类正确。

这是否意味着 CE 仍然是最佳选择,还是有办法缩放输出或权衡它们,以便网络注意让每个类都正确?

1个回答

交叉熵确实适用于多类分类。

当 tensorflow 文档说明 cross_entropy 时:

“测量类别互斥的离散分类任务中的概率误差(每个条目都在一个类别中)。”

它指的是在概念层面上,类预计不会重叠。非重叠类的一个示例是“cat”和“dog”。重叠类的一个例子是“Elephant”和“African Elephant”;这种类型的问题称为多标签分类,因为你有标签而不是类,每个人都可以分配许多标签。

更新:使用有关问题的新信息,我们可以看出所面临的问题不是分类问题,因为所需的输出是特定的概率值。这意味着问题是回归问题。由于输出是概率,因此使用 softmax 来确保它们加起来为 1 是合适的。因为它是回归,所以使用 MSE 作为损失函数是合适的。