学习笔记——RNN相关知识

从RNN基本结构开始

RNN(Recurrent Neural Networks,循环神经网络)相较于传统神经网络,允许了信息的持久化,即神经元的输出可以在下一个时间戳直接作用到自身。

传统神经网络对下面两种情况无法很好的处理

  • 输入和输出的长度(维度)在不同的例子中不同(句子的长度不同)
  • 在输入的不同维度上学到的特征无法共享(比如文本的不同位置上)

RNN的链式结构决定了它适用于时序相关问题,更能够很好的应对上面的两种情况 Alt text

标准RNN神经元结构 Alt text

RNN的前向传播(Forward Propagation)

Alt text RNN比较常见的N对N结构图,后面的计算都根据这张图的结构来进行。

假设\(x^{(t)}\)\(t\)时刻的输入,\(h^{(t)}\)\(t\)时刻隐藏层输出,\(\hat y^{(t)}\)\(t\)时刻输出层输出,则 \[\begin{align*} &h^{(t)} = tanh(Ux^{(t)} + Wh^{(t-1)} + b) \\ &o^{(t)} = Vh^{(t)} + c \\ &\hat y^{(t)} = \sigma (o^{(t)}) \end{align*}\] 其中,\(\sigma\)是输出层的激活函数(一般根据输出类型选择sigmoid/softmax等),隐藏层的激活函数通常都使用tanh。

明确一下维度:
假设 \(x^{(t)}.shape=(300,1)\)\(h^{(t)}.shape=(100,1)\),则有 \[\begin{align*} &U.shape = (100,300)\\ &W.shape = (100,100)\\ \end{align*}\]

RNN的反向传播(Back Propagation)

RNN的后向传播也叫做BPTT(back-propagation through time)。

继续上面的计算,假设\(L^{(t)}\)\(t\)时刻的损失函数,则总损失\(L=\sum_{t=1}^TL^{(t)}\),令\(\delta^{(t)} = \frac{\partial L}{\partial h^{(t)}}\) \[\begin{align*} &\frac{\partial L}{\partial c} = \sum_{t=1}^{T}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial c}= \sum_{t=1}^{T}\frac{\partial L^{(t)}}{\partial o^{(t)}}\\ &\frac{\partial L}{\partial V} = \sum_{t=1}^{T}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial V}=\sum_{t=1}^{T}\frac{\partial L^{(t)}}{\partial o^{(t)}}(h^{(t)})^T\\ &\frac{\partial L}{\partial W} = \sum_{t=1}^{T}\frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial W}=\sum_{t=1}^{T}diag(1-(h^{(t-1)})^2)\delta^{(t)}(h^{(t-1)})^T\\ &\frac{\partial L}{\partial b}=\sum_{t=1}^{T}\frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial b}=\sum_{t=1}^{T}diag(1-(h^{(t-1)})^2)\delta^{(t)}\\ &\frac{\partial L}{\partial U}=\sum_{t=1}^{T}\frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial U}=\sum_{t=1}^{T}diag(1-(h^{(t-1)})^2)\delta^{(t)}(x^{(t)})^T\\ \end{align*}\]

其中,\(\frac{\partial L^{(t)}}{\partial o^{(t)}}\)根据输出层激活函数\(\sigma\)来进行计算,而\(\delta^{(t)} = \frac{\partial L}{\partial h^{(t)}}\)的计算比较复杂,会牵涉到\(h^{(t)}\)的下一状态\(h^{(t+1)}\),计算时需要从\(\delta^{(T)}\)开始从后向前计算 \[\begin{align*} &\delta^{(t)} =\frac{\partial L}{\partial h^{(t)}} = \frac{\partial L}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}} + \frac{\partial L}{\partial h^{(t + 1)}}\frac{\partial h^{(t + 1)}}{\partial h^{(t)}}=\sum_{t=1}^{T}V^T\frac{\partial L^{(t)}}{\partial o^{(t)}}+W^T\delta^{(t+1)}diag(1-(h^{(t+1)})^2)\\ &\delta^{(T)}=\sum_{t=1}^{T}V^T\frac{\partial L^{(t)}}{\partial o^{(t)}} \end{align*}\] 具体计算过程中要考虑求导过程中矩阵转置等问题,比较复杂。Reference里的第5篇的正向反向传播公式推导非常细致,另外还包括LSTM的公式推导。

RNN的几种常见结构

Many-to-many结构

Alt text 每一个时刻的输入都对应一个输出

  • 词性标注
  • Char RNN(输入为字符,输出为下一个字符的概率)

Many-to-one结构

Alt text 输入序列,输出单个值

  • 根据一段文字进行情感分类
  • 一段音频/视频的分类

One-to-many结构

Alt text Alt text 输入为单个值,输出序列 有两种常见方式,一种是只在第一个时刻输入X,一种是在每个时刻都输入同一个X

  • 从图像生成文字(image caption)
  • 从类别生成音乐

Encoder-Decoder结构

Alt text Alt text 又叫Seq2Seq模型,通过编码器部分获得c,解码器部分再通过c获得输出。有两种方式,一种是c只在后半部分网络的开头输入,一种是在后半部分网络的每个时刻都输入

  • 机器翻译
  • 文本摘要
  • 阅读理解
  • 语音识别

图片来自https://zhuanlan.zhihu.com/p/28054589

双向RNN

标准的RNN在t时刻的输出只能根据t时刻及t时刻之前的输入来决定,但是在很多任务中,t时刻的输出很可能也和t时刻之后的输入有关。

例如这个两个句子中的人名判断:
- Teddy Roosevelt was a great president.
- Teddy bears are on sale!

双向RNN(Bidirectional RNN, Bi-RNN, BRNN)就是为了解决这个问题而创造的。双向RNN相当于正向传播计算一次\(h^{(t)}\),再计算一次反向传播回去的\(g^{(t)}\),最后的输出\(\hat y^{(t)}\)同时取决于\(h^{(t)}\)\(g^{(t)}\)。相当于每个训练序列向前和向后分别训练一个RNN再接到同一个输出层上去。

双向RNN结构图 Alt text \(o^{(t)} = V[h^{(t)}, g^{(t)}] + c\)

深度RNN

Alt text 给RNN增加深度也就是使模型变得更强大,相当于增加网络复杂度来更贴合训练数据。 循环神经网络可以通过许多方式变得更深:

  • 隐藏循环状态可以被分解为具有层次的组。
  • 可以向输入到隐藏,隐藏到隐藏以及隐藏到输出的部分引入更深的计算 (如 MLP(多层感知器))。这可以延长链接不同时间步的最短路径。
  • 可以引入跳跃连接来缓解路径延长的效应。

一些关于RNN的思考

1. RNN的梯度爆炸和梯度消失

梯度的计算过程中,会需要多个激活函数偏导相乘,如果这些激活函数的偏导较小(小于1),那么就很可能会造成梯度消失;相反,如果这些激活函数的偏导较大(大于1),那么就很可能会造成梯度爆炸。

见上文反向传播中的计算,梯度消失/爆炸的主要原因就是\(\delta^{(t)}\),在计算偏导的时候需要\(T\)\(W\)连乘,当\(W\)的值较小时就会引起梯度消失,较大时就会引起梯度爆炸。

梯度消失就类似于随时间推移,越早的东西越记不住了。因为越早的\(h\)\(W\)相乘了越多次,在\(W\)小于1的情况下,这些早前的\(h\)几乎已经无法对现在造成影响了。

2. RNN中梯度爆炸和消失的解决方法

梯度爆炸(gradient explosion)不是个严重的问题,一般采用梯度裁剪(gradient clipping)来处理,即大于某个阈值的时候,对梯度向量进行整体的缩放。

梯度消失(gradient vanishing)就比较麻烦了,虽然也可以用梯度裁剪来处理,即梯度小于某个阈值的时候,更新的梯度为这个阈值,但是这个方法很难找到一个满意的阈值。(我猜:设置小了的话就下降很慢,设置大了呢又很难收敛到最优。)

  • 有效初始化+ReLU激活函数能够得到较好效果(?)
  • 算法上的优化,例如截断的BPTT算法。
  • 模型上的改进,例如LSTM、GRU单元都可以有效解决长期依赖问题。
  • 在BPTT算法中加入skip connection,此时误差可以间歇的向前传播。
  • 加入一些Leaky Units,思路类似于skip connection

3. 为什么RNN中多使用tanh而不是ReLU

tanh函数 Alt text

ReLU函数 Alt text

参考上文RNN的反向传播部分,RNN中使用ReLU会由于连乘导致非常大的输出值,造成爆炸。而使用tanh的话就会好很多。(这个问题的答案是有争议的)


LSTM

LSTM(Long Short Term Memory Networks,长短时记忆网络)是一种 RNN 特殊的类型,可以学习长期依赖信息。

RNN在处理长期依赖时会遇到巨大的困难,因为计算距离较远的节点之间的联系时会有梯度消失的问题(时间轴上的梯度消失),为了解决该问题,研究人员提出了许多解决办法,例如ESN(Echo State Network),增加有漏单元(Leaky Units)等等。其中最成功应用最广泛的就是门限RNN(Gated RNN),而LSTM就是门限RNN中最著名的一种。

LSTM神经元结构 Alt text

LSTM引入了细胞状态(Cell State)的概念,就是横穿神经元的传送带,只有少量的线性交互,信息在上面流传不会出现梯度消失的问题。 Alt text

LSTM的门

Alt text 门是一种让信息选择式通过的方法。一个门的结构如上图,包含一个sigmoid神经网络层(使用\(\sigma\)表示)和一个pointwise乘法操作。
Sigmoid层输出0到1之间的数值,描述每个部分有多少量可以通过。0代表“不许任何量通过”,1就指“允许任意量通过”。

Sigmoid函数 Alt text

相较于标准RNN,LSTM拥有三个门,来保护和控制细胞状态。

遗忘门(forget gate)

Alt text 遗忘门通过观察\(h_{t-1}\)\(x_t\),对于细胞状态\(C_{t-1}\)中的每一个元素,输出一个0~1之间的数。1表示“完全保留”,0表示“完全丢弃”。

输入门(input gate)

Alt text 输入门同样是通过观察\(h_{t-1}\)\(x_t\),决定我们要加入哪些新信息到细胞状态中。之后由一个tanh创造一个新的候选值向量。

Alt text 输入门和遗忘门都是用于改变细胞状态的门,将\(C_{t-1}\)变为\(C_{t}\)就需要用到这两个门的输出。
首先通过与遗忘门输出相乘丢弃不再需要的信息,然后再与输入门输出相加得到最新的细胞状态。

输出门(output gate)

Alt text 输出门首先通过观察\(h_{t-1}\)\(x_t\)来确定细胞状态的哪个部分将输出出去。然后把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和输出门的输出相乘,得到最终需要的输出。


GRU

GRU将遗忘门和输入门合成成了更新门,还将细胞状态\(C_t\)和隐藏状态\(h_t\)合并,即每个时刻的细胞状态直接等于\(h_t\) Alt text 图中\(r_t\)为输出门,\(z_t\)为更新门
输出门\(r_t\)决定候选值\(\tilde h_t\)中使用\(H_{t-1}\)中的哪些元素进行计算 更新门\(z_t\)决定细胞状态中的某个元素是否更新为候选值\(\tilde h_t\)中的值,如果不更新就保留为\(h_{t-1}\)中的值


Attention机制

Attention机制的出现主要是由于上文中提到的Encoder-Decoder结构的缺陷,Encoder-Decoder结构中的c必须包含原始系列中的所有信息,因此c的长度就成了限制模型性能的瓶颈。

Attention机制的实现是通过保留RNN编码器对输入序列的中间输出结果,然后训练一个模型来对这些输入进行选择性的学习并且在模型输出时将输出序列与之进行关联。

引入Attention的Encoder-Decoder结构 Alt text


Reference

  1. http://colah.github.io/posts/2015-08-Understanding-LSTMs/
  2. https://www.jianshu.com/p/9dc9f41f0b29
  3. https://blog.csdn.net/xuanyuansen/article/details/61913886
  4. https://zhuanlan.zhihu.com/p/27345523
  5. https://www.cnblogs.com/pinard/p/6509630.html
  6. https://zhuanlan.zhihu.com/p/28054589
  7. http://arunmallya.github.io/writeups/nn/lstm/index.html#/