Keras:稀疏分类交叉熵的维度错误

数据挖掘 喀拉斯
2022-03-07 05:46:56

我正在尝试制作一个 NN,根据时钟上的时间,它会尝试预测哪个类(在此示例中为 32 个)正在向系统发出请求。作为第一次尝试,我尝试使用categorical_crossentropy,但这显然行不通,因为目标非常稀疏,因此系统将通过始终预测非请求来获得丰厚回报。

现在我正在尝试使用sparse_categorical_crossentropy,但我不断收到尺寸不匹配错误(在这种情况下,训练集和测试集是相同的,因为我最初只是想评估训练集中的性能):

Error when checking target: expected dense_90 to have shape (1,) but got array with shape (32,)

DataFrame这里(一个简单的时钟和请求的另一列),代码是

from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
import tensorflow as tf

requests = df['requests'].values
requests_cat = to_categorical(requests, 32)

length = len(df['clock'])
train = np.reshape(df['clock'].values, (length, 1))
train = train.astype(np.int)
target = requests_cat

model = Sequential()
model.add(Dense(25, activation = 'relu', input_shape = (train.shape[1],)))
model.add(Dense(25, activation = 'relu'))
model.add(Dense(32, activation = 'softmax'))

model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy')
model.fit(x = train, y = target, epochs = 100, validation_data = (train, target))

在旁注中:

  1. 在这种情况下,这种架构似乎不是最好的。作为第二个原型,我正在考虑使用 LSTM 做一些事情,因为过去的请求会影响以后的请求。是否有用于调度的标准架构?
  2. 将稀疏集拆分为训练集和测试集的正确方法是什么?
1个回答

我觉得你误解了和之间的categorical_crossentropy区别sparse_categorical_crossentropy稀疏部分不是指数据的稀疏性,而是指标签的格式

  • 如果您的标签是一次性编码的:使用categorical_crossentropy

  • 如果您的标签被编码为整数:使用sparse_categorical_crossentropy