根据前面的 M 和后面的 N 个元素预测序列元素

数据挖掘 喀拉斯 lstm 火炬 顺序 深度学习
2022-02-24 16:05:03

我有一个长度相等的序列数组,每个序列包含 300 个数字(M=300)。序列中的每个元素都是从 1 到 9 的数字:

13571398...2455 # 300 numbers
33344467...1143 # 300 numbers
...
...
...
66118859...2121 # 300 numbers

我的任务是建立一个模型,根据序列中的前 179 个元素和最后 110 个元素,预测序列位置从 180 到 190 的元素(数字)。换句话说,给定位置从 0 到 179 和从 191 到 299 的元素预测序列中位置从 180 到 190 的元素。

我正在考虑使用 Keras BiLSTM 模型解决此任务的以下步骤:

  • 将所有序列拆分为训练/验证/测试集
  • 在训练集上训练 BiLSTM 以预测序列中任意位置的下一个数字
  • 在测试和验证集中,将 180 到 190 位置的 K 个元素随机替换为 0(原始序列中不存在的数字)。
  • 使用预训练的 BiLSTM 预测验证和测试集中“0”元素的真实值

请帮助解决以下问题:

  • 在这种情况下,我应该如何表示 BiLSTM 的数据和类?看起来我的数据和类是一回事。1...9 两个数字都是 BiLSTM 的数据和对应的类。
  • 在这种情况下,我应该创建哪些数据结构和编码来使用 Keras BiLSTM 进行训练和预测?
  • 如何在训练集和测试集上评估该模型的质量?

非常欢迎使用其他模型的任何其他想法,特别是变形金刚(PyTorch,Tesnsorflow),谢谢!

1个回答

您的问题的框架接近于所谓的语言建模任务。因为您的输入数据是固定长度的样本,您可以使用具有固定大小上下文嵌入的 seq2seq 模型。

这意味着您本质上将有一个编码器,例如 Bi-LSTM,它将您的输入编码为一个固定的表示(通过连接前向和后向 LSTM 的最终输出状态)和一个解码器,例如 LSTM,它顺序生成输出令牌。

您的目标函数可以是每个输出令牌的交叉熵损失的平均值,也可以是更复杂的损失,例如 CTC。您还可以通过仅预测掩码标记而不是整个句子作为神经网络的输出来简化它。

您的标记是整数这一事实没有任何区别,实际上简化了嵌入。您可以简单地将数据按原样提供给 Keras 或 PyTorch 中的嵌入层。如果你使用 PyTorch,我会推荐这个教程,使用转换器而不是 LSTM。