如何画一个简单的 LSTM 网络

数据挖掘 深度学习 lstm matlab
2021-10-14 16:32:10

我是深度学习的新手,我正在为我的博士工作学习 LSTM。这是一个用于序列分类的简单 LSTM 网络。此代码来自 MATLAB 教程: layers = [sequenceInputLayer(1) lstmLayer(5,'OutputMode','last') fullyConnectedLayer(3) softmaxLayer classificationLayer];

为简单起见,输入序列的维度为1,有3类。

我正在尝试绘制此网络的图表。这是我的尝试: 在此处输入图像描述

这是正确的吗?应该连接 LSTM 蓝色单元吗?橙色单元是 softmax 层,每个单元上是否应该有任何符号(如 ∑)?每一层都有同样的问题?应该有任何额外的层来表示“ classificationLayer”吗?隐含在最后一层的fullyConnectedLayer全连接中,我需要为此添加任何额外的层吗?请问还有什么意见吗?

1个回答

在考虑 rnn/lstm/gru 层时,请记住几点。

  1. 你的输入大小是多少?在这种情况下,我们有 5 个单词的句子,因此 5 个输入圆进入 lstm 层,每个单词的值都将乘以相同的权重值。
  2. lstm 层内将存在多少个 lstm 单元?为了简单起见,因为我们有 5 个单词,我们为每个单词保留 5 个 lstm 单元,然后将其内存转发到下一个 lstm 单元中。
  3. 我们有多少输出类别?最后一个 lstm 单元将连接到所有 3 个密集层神经元,并且在每个神经元上,softmax 操作将像正常的全连接神经网络密集层工作一样发生。

基于网络描述 lstm network