交叉熵损失解释

数据挖掘 机器学习 神经网络 深度学习 软最大
2021-10-05 19:54:57

假设我建立了一个用于分类的神经网络。最后一层是具有 Softmax 激活的密集层。我有五个不同的类别要分类。假设对于单个训练示例,true label[1 0 0 0 0]而预测是[0.1 0.5 0.1 0.1 0.2]我将如何计算此示例的交叉熵损失?

4个回答

交叉熵公式有两种分布,p(x), 真实分布, 和 q(x),估计分布,定义在离散变量上 x 并且由

H(p,q)=xp(x)log(q(x))

对于神经网络,计算与以下内容无关:

  • 使用了什么样的层。

  • 使用了什么样的激活 - 尽管许多激活与计算不兼容,因为它们的输出不能解释为概率(即它们的输出为负数、大于 1 或总和不为 1)。Softmax 通常用于多类分类,因为它保证了良好的概率分布函数。

对于神经网络,您通常会看到写成以下形式的方程:y是地面实况向量和y^(或直接取自最后一层输出的其他值)是估计值。对于单个示例,它看起来像这样:

L=ylog(y^)

在哪里是内积。

您的示例基本事实y 给出第一个值的所有概率,其他值为零,因此我们可以忽略它们,只使用您估计中的匹配项 y^

L=(1×log(0.1)+0×log(0.5)+...)

L=log(0.1)2.303

评论中的一个重要观点

这意味着,无论预测是否正确,损失都是相同的 [0.1,0.5,0.1,0.1,0.2] 或者 [0.1,0.6,0.1,0.1,0.1]?

是的,这是多类 logloss 的一个关键特性,它只奖励/惩罚正确类的概率。该值与剩余概率在不正确的类之间的分配方式无关。

您经常会看到这个等式在所有示例上平均作为成本函数。它在描述中并不总是严格遵守,但通常损失函数是较低级别的,描述单个实例或组件如何确定错误值,而成本函数是较高级别,描述如何评估完整系统以进行优化。基于多类对数损失的大小数据集的成本函数N 可能看起来像这样:

J=1N(i=1Nyilog(y^i))

许多实现将要求您的基本事实值是一次性编码的(使用单个真实类),因为这允许进行一些额外的优化。然而,原则上,交叉熵损失可以被计算 - 并优化 - 当情况并非如此时。

尼尔的回答是正确的。但是我认为重要的是要指出,虽然损失不取决于不正确类之间的分布(仅取决于正确类和其余类之间的分布),但该损失函数的梯度确实会根据不同的方式影响不正确的类他们错了。因此,当您在机器学习中使用 cross-ent 时,您将为 [0.1 0.5 0.1 0.1 0.2] 和 [0.1 0.6 0.1 0.1 0.1] 更改不同的权重。这是因为正确类别的分数被所有其他类别的分数归一化以将其转化为概率。

让我们从理解信息论中的熵开始:假设你想传达一串字母“aaaaaaaa”。您可以轻松地将其作为 8*"a"。现在取另一个字符串“jteikfqa”。有没有一种压缩的方式来传达这个字符串?那里没有。我们可以说第二个字符串的熵更多,因为为了传达它,我们需要更多的“位”信息。

这个类比也适用于概率。如果您有一组项目,例如水果,这些水果的二进制编码将是log2(n)其中 n 是水果的数量。对于 8 个水果,您需要 3 个位,依此类推。另一种看待这个问题的方法是,假设某人随机选择一种水果的概率是 1/8,如果选择一种水果,则不确定性减少是log2(1/8)即 3。更具体地说,

i=1818log2(18)=3
这个熵告诉我们某些概率分布所涉及的不确定性;概率分布中的不确定性/变化越多,熵就越大(例如,对于 1024 个水果,熵为 10)。

在“交叉”熵中,顾名思义,我们关注解释两种不同概率分布差异所需的位数。最好的情况是两个分布是相同的,在这种情况下需要最少的比特,即简单熵。用数学术语来说,

H(y,y^)=iyiloge(y^i)

在哪里 y^ 是预测的概率向量(Softmax 输出),并且 y是ground-truth向量(例如one-hot)。我们使用自然对数的原因是因为它很容易区分(参考计算梯度),我们不采用地面实况向量对数的原因是因为它包含很多简化求和的 0。

底线:通俗地说,可以将交叉熵视为两个概率分布之间的距离,就解释该距离所需的信息量(比特)而言。这是一种定义损失的巧妙方法,随着概率向量越来越接近,损失会下降。

让我们看看损失的梯度如何表现......我们将交叉熵作为损失函数,由下式给出

H(p,q)=i=1np(xi)log(q(xi))=(p(x1)log(q(x1))++p(xn)log(q(xn))

从这里开始..我们想知道关于一些的导数 xi

xiH(p,q)=xip(xi)log(q(xi)).
由于所有其他术语由于差异而被取消。我们可以将这个等式更进一步
xiH(p,q)=p(xi)1q(xi)q(xi)xi.

从这里我们可以看到,我们仍然只惩罚真正的类(对于 p(xi))。否则,我们只有零梯度。

我确实想知道软件包如何处理预测值 0,而真实值大于零……因为在这种情况下我们除以零。