用于BCE具有范围的输出[ 0 , 1 ] 不正确。
如您所知,二元交叉熵的损失计算为:
−(ylog(p)+(1−y)log(1−p))
如果 y=1,公式的第一部分将被激活,如果 y=0第二部分将被激活。但是,想象一下,如果我们没有完全y 作为 0 或者 1,但它们之间的任何数字。该公式仍然有效,并返回对真实标签的预测损失。在这种情况下,不同之处在于它不限于激活−ylog(p) 或者 −(1−y)log(1−p). 它们都将被部分激活。
PS:当然你也可以mean_squared_error作为自动编码器的损失函数。
更新:我认为使用哪一个没有优势。这取决于您的目标,在某些情况下,其中之一可能更适合您的需求。只有一件事可能很重要,BCE 比 MSE 返回更高的损失,并且在您想要更多地惩罚错误的情况下,首选 BCE。我将比较他们两者的结果与随机数据的损失。
# binary cross entropy
def bce(y_t,y_p):
epsilon = 1e-4
return -(y_t*np.log(y_p+epsilon)+(1-y_t)*(np.log(1-y_p+epsilon)))
# mean squared error
def mse(y_t,y_p):
return (y_p-y_t)**2
# random labels and logits
y_t_array = tf.random.uniform((1,10),minval=0,maxval=1).numpy()
y_p_array = tf.random.uniform((1,10),minval=0,maxval=1).numpy()
# loss for each pair of above arrays
loss_mse_array = [mse(i,j) for i,j in zip(y_t_array,y_p_array)]
loss_bce_array = [bce(i,j) for i,j in zip(y_t_array,y_p_array)]
# plot the losses for a better comparison
import matplotlib.pyplot as plt
plt.plot(range(len(loss_mse_array[0])) , loss_mse_array[0], 'bo')
plt.plot(range(len(loss_bce_array[0])) , loss_bce_array[0], 'b+')
plt.legend(["mse","bce"], loc="upper right")
plt.show()
