RNN 与 LSTM 的原理详解
原文地址:https://blog.csdn.net/happyrocking/article/details/83657993
什么是序列呢?序列是一串有顺序的数据,比如某一条数据为 [x1,x2,x3,x4][x1,x2,x3,x4] [x_1, x_2, x_3, x_4][x1,x2,x3,x4],其中每个元素可以是一个字符、一个单词、一个向量,甚至是一个声音。比如:
语音处理。此时,每个元素是每帧的声音信号。
时间序列问题。例如每天的股票价格等。
RNN 的结构
我们从基础的神经网络中知道,神经网络包含输入层、隐层、输出层,通过激活函数控制输出,层与层之间通过权值连接。激活函数是事先确定好的,那么神经网络模型通过训练“学“到的东西就蕴含在“权值“中。单层的神经网络如图:
基础的神经网络只在层与层之间建立了权连接,RNN最大的不同之处就是在层之间的神经元之间也建立的权连接。如图。
在展开结构中我们可以观察到,在标准的RNN结构中,隐层的神经元之间也是带有权值的。也就是说,随着序列的不断推进,前面的隐层将会影响后面的隐层。图中O代表输出,y代表样本给出的确定值,L代表损失函数,我们可以看到,“损失“也是随着序列的推荐而不断积累的。
除上述特点之外,标准RNN的还有以下特点:
每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。
然而在实际中这一种结构并不能解决所有问题,常见的变种有:
1、多输入单输出
有的时候,我们要处理的问题输入是一个序列,输出是一个单独的值而不是序列,应该怎样建模呢?实际上,我们只在最后一个h上进行输出变换就可以了:
2、单输入多输出
输入不是序列而输出为序列的情况怎么处理?我们可以只在序列开始进行输入计算,其余只需要隐层状态进行传递。
从类别生成语音或音乐等
实际中,还有另外一种多输入多输出的结构,其输入与输出并不是一一对应的,如图:
Encoder-Decoder结构先将输入数据编码成一个上下文向量c。得到c有多种方式,最简单的方法就是把Encoder的最后一个隐状态赋值给c,还可以对最后的隐状态做一个变换得到c,也可以对所有的隐状态做变换。
拿到c之后,就用另一个RNN网络对其进行解码,这部分RNN网络被称为Decoder。具体做法就是将c当做之前的初始状态h0输入到Decoder中。
还有另外一种 Decoder,是将c当做每一步的输入:
文本摘要。输入是一段文本序列,输出是这段文本序列的摘要序列。
阅读理解。将输入的文章和问题分别编码,再对其进行解码得到问题的答案。
语音识别。输入是语音信号序列,输出是文字序列。
下面对多输入多输出(一一对应)的经典结构作分析:
前向传播算法其实非常简单,对于t时刻,隐层单元为:
h(t)=f(Ux(t)+Wh(t−1)+b)h(t)=f(Ux(t)+Wh(t−1)+b) h^{(t)}=f(Ux^{(t)}+Wh^{(t-1)}+b)h(t)=f(Ux(t)+Wh(t−1)+b)
其中,f 为激活函数,如 sigmoid、tanh 等,b 为偏置。
t时刻的输出为:
o(t)=Vh(t)+co(t)=Vh(t)+c o^{(t)}=Vh^{(t)}+co(t)=Vh(t)+c
RNN的训练方法
BPTT(back-propagation through time)算法是常用的训练RNN的方法,其实本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。综上所述,BPTT算法本质还是BP算法,BP算法本质还是梯度下降法,那么求各个参数的梯度便成了此算法的核心。
∂L(t)∂V=∂L(t)∂o(t)∂o(t)∂V∂L(t)∂V=∂L(t)∂o(t)∂o(t)∂V \frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}∂V∂L(t)=∂o(t)∂L(t)∂V∂o(t)
RNN的损失也是会随着时间累加的,所以需要求出所有时刻的偏导然后求和:
L=∑nt=1L(t)L=∑t=1nL(t) L=\sum_{t=1}^n L^{(t)}L=t=1∑nL(t)
∂L∂V=∑nt=1∂L(t)∂o(t)∂o(t)∂V∂L∂V=∑t=1n∂L(t)∂o(t)∂o(t)∂V \frac{\partial L}{\partial V}=\sum_{t=1}^n\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}∂V∂L=t=1∑n∂o(t)∂L(t)∂V∂o(t)
W和U的偏导的求解由于需要涉及到历史数据,其偏导求起来相对复杂,我们先假设只有三个时刻,那么在第三个时刻 L对W的偏导数为:
∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂W+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂W \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂W∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂W∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂W∂h(1)
同理,对U的偏导为:
∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂U∂L(3)∂W=∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂U+∂L(3)∂o(3)∂o(3)∂h(3)∂h(3)∂h(2)∂h(2)∂h(1)∂h(1)∂U \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U}∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂U∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂U∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂U∂h(1)
可以看到,在某个时刻的对W或是U的偏导数,需要追溯这个时刻之前所有时刻的信息,这还仅仅是一个时刻的偏导数,上面说过损失也是会累加的,那么整个损失函数对W和U的偏导数将会非常繁琐。虽然如此但好在规律还是有迹可循,我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式:
∂L(t)∂W=∂L(t)∂o(t)∂o(t)∂h(t)∑tk=1(∏ti=k+1∂h(i)∂h(i−1))∂h(k)∂W∂L(t)∂W=∂L(t)∂o(t)∂o(t)∂h(t)∑k=1t(∏i=k+1t∂h(i)∂h(i−1))∂h(k)∂W \frac{\partial L^{(t)}}{\partial W}= \frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial h^{(t)}}\sum_{k=1}^t(\prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}})\frac{\partial h^{(k)}}{\partial W}∂W∂L(t)=∂o(t)∂L(t)∂h(t)∂o(t)k=1∑t(i=k+1∏t∂h(i−1)∂h(i))∂W∂h(k)
∂L(t)∂U=∂L(t)∂o(t)∂o(t)∂h(t)∑tk=1(∏ti=k+1∂h(i)∂h(i−1))∂h(k)∂U∂L(t)∂U=∂L(t)∂o(t)∂o(t)∂h(t)∑k=1t(∏i=k+1t∂h(i)∂h(i−1))∂h(k)∂U \frac{\partial L^{(t)}}{\partial U}= \frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial h^{(t)}}\sum_{k=1}^t(\prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}})\frac{\partial h^{(k)}}{\partial U}∂U∂L(t)=∂o(t)∂L(t)∂h(t)∂o(t)k=1∑t(i=k+1∏t∂h(i−1)∂h(i))∂U∂h(k)
与V相同,对W或U的整体偏导,也是将所有t时刻的偏导相加。
前面说过激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:
∏ti=k+1∂h(i)∂h(i−1)=∏ti=k+1tanh′⋅W∏i=k+1t∂h(i)∂h(i−1)=∏i=k+1ttanh′⋅W \prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}}=\prod_{i=k+1}^{t}tanh'·Wi=k+1∏t∂h(i−1)∂h(i)=i=k+1∏ttanh′⋅W
或
∏ti=k+1∂h(i)∂h(i−1)=∏ti=k+1sigmoid′⋅W∏i=k+1t∂h(i)∂h(i−1)=∏i=k+1tsigmoid′⋅W \prod_{i=k+1}^{t}\frac{\partial h^{(i)}}{\partial h^{(i-1)}}=\prod_{i=k+1}^{t}sigmoid'·Wi=k+1∏t∂h(i−1)∂h(i)=i=k+1∏tsigmoid′⋅W
我们会发现累乘会导致激活函数导数和权重矩阵的累乘,进而会导致“梯度消失“和“梯度爆炸“现象的发生。
为什么会出现“梯度消失“?我们先来看看这两个激活函数的图像:
可见,sigmoid函数的导数范围是(0,0.25],tach函数的导数范围是(0,1],他们的导数最大都不大于1。累乘的过程中,如果取sigmoid函数作为激活函数的话,那么必然是一堆小数在做乘法,结果就是越乘越小。随着时间序列的不断深入,小数的累乘就会导致梯度越来越小直到接近于0,这就是“梯度消失“现象。其实RNN的时间序列与深层神经网络很像,在较为深层的神经网络中使用sigmoid函数做激活函数也会导致反向传播时梯度消失,梯度消失就意味消失那一层的参数再也不更新,那么那一层隐层就变成了单纯的映射层,毫无意义了,所以在深层神经网络中,有时候多加神经元数量可能会比多家深度好。
同理,由于权重矩阵的累乘,可能会导致“梯度爆炸”的发生。
RNN的特点本来就是能“追根溯源“利用历史数据,现在告诉我可利用的历史数据竟然是有限的,这就令人非常难受,解决“梯度消失“是非常必要的。解决“梯度消失“的方法主要有:
改变传播结构
关于第二点,LSTM结构可以解决这个问题。
总结一下,sigmoid函数的缺点:
sigmoid函数不是0中心对称,tanh函数是,可以使网络收敛的更好。
下面来了解一下LSTM(long short-term memory)。长短期记忆网络是RNN的一种变体,RNN由于梯度消失的原因只能有短期记忆,LSTM网络通过精妙的门控制将短期记忆与长期记忆结合起来,并且一定程度上解决了梯度消失的问题。
长期依赖(Long-Term Dependencies)问题
RNN 的关键点之一就是他们可以用来连接先前的信息到当前的任务上,例如使用过去的视频段来推测对当前段的理解。如果 RNN 可以做到这个,他们就变得非常有用。但是真的可以么?答案是,还有很多依赖因素。
有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果我们试着预测 “the clouds are in the sky” 最后的词,我们并不需要任何其他的上下文 —— 因此下一个词很显然就应该是 sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。
不幸的是,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。
LSTM网络
Long Short Term 网络—— 一般就叫做 LSTM ——是一种 RNN 特殊的类型,可以学习长期依赖信息。LSTM 由Hochreiter & Schmidhuber (1997)提出,并在近期被Alex Graves进行了改良和推广。在很多问题,LSTM 都取得相当巨大的成功,并得到了广泛的使用。
LSTM 通过刻意的设计来避免长期依赖问题。记住长期的信息在实践中是 LSTM 的默认行为,而非需要付出很大代价才能获得的能力!
所有 RNN 都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh 层。
黄色的矩形是学习得到的神经网络层
粉色的圆形表示一些运算操作,诸如加法乘法
黑色的单箭头表示向量的传输
两个箭头合成一个表示向量的连接
一个箭头分开表示向量的复制
细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。
LSTM 拥有三个门,来保护和控制细胞状态。
理解LSTM的三个门
遗忘门
在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为遗忘门完成。该门会读取ht−1ht−1 h_{t-1}ht−1和xtxt x_txt,输出一个在 0 到 1 之间的数值给每个在细胞状态Ct−1Ct−1 C_{t-1}Ct−1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。
让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。
对于第一个问题,“遗忘“可以理解为“之前的内容记住多少“,其精髓在于只能输出(0,1)小数的sigmoid函数和粉色圆圈的乘法,LSTM网络经过学习决定让网络记住以前百分之多少的内容。对于第二个问题就更好理解,决定记住什么遗忘什么,其中新的输入肯定要产生影响。
输入门
下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分。第一,sigmoid 层称 “输入门层” 决定什么值我们将要更新。然后,一个 tanh 层创建一个新的候选值向量,Ct˜Ct~ \tilde{C_{t}}Ct~,会被加入到状态中。下一步,我们会讲这两个信息来产生对状态的更新。
在我们语言模型的例子中,我们希望增加新的主语的性别到细胞状态中,来替代旧的需要忘记的主语。
我们把旧状态与ftft f_tft相乘,丢弃掉我们确定需要丢弃的信息。接着加上it⋅Ct˜it⋅Ct~ i_t·\tilde{C_t}it⋅Ct~。这就是新的候选值,根据我们决定更新每个状态的程度进行变化。
有了上面的理解基础输入门,输入门理解起来就简单多了。tanh函数创建新的输入值,sigmoid函数决定可以输入进去的比例。
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。
LSTM的变体
我们到目前为止都还在介绍正常的 LSTM。但是不是所有的 LSTM 都长成一个样子的。实际上,几乎所有包含 LSTM 的论文都采用了微小的变体。差异非常小,但是也值得拿出来讲一下。
窥视孔连接
一个流行的LSTM变种,由Gers & Schmidhuber (2002提出,加入了“窥视孔连接(peephole connection)”。也就是说我们让各种门可以接受到细胞状态的输入。
对偶忘记门和输入门
另一个变体是通过使用对偶(coupled)忘记门和输入门。不同于之前是分开确定什么忘记和需要添加什么新的信息,这里是一同做出决定。我们仅仅会当我们将要输入在当前位置时忘记。我们仅仅输入新的值到那些我们已经忘记旧的信息的那些状态 。
另一个改动较大的变体是 Gated Recurrent Unit (GRU),这是由 Cho, et al. (2014) 提出。它将忘记门和输入门合成了一个单一的 更新门。同样还混合了细胞状态和隐藏状态,和其他一些改动。最终的模型比标准的 LSTM 模型要简单,也是非常流行的变体。
要问哪个变体是最好的?其中的差异性真的重要吗?Greff, et al. (2015) 给出了流行变体的比较,结论是他们基本上是一样的。Jozefowicz, et al. (2015) 则在超过 1 万种 RNN 架构上进行了测试,发现一些架构在某些任务上也取得了比 LSTM 更好的结果。
————————————————
版权声明:本文为CSDN博主「HappyRocking」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/happyrocking/article/details/83657993