LSTM的公式推导详解
2015-07-31 11:44
337 查看
导言
在Alex Graves的这篇论文《Supervised Sequence Labelling with Recurrent Neural Networks》中对LSTM进行了综述性的介绍,并对LSTM的Forward Pass和Backward Pass进行了公式推导。这篇文章将用更简洁的图示和公式一步步对Forward和Backward进行推导,相信读者看完之后能对LSTM有更深入的理解。
如果读者对LSTM的由来和原理存在困惑,推荐DarkScope的这篇博客:《RNN以及LSTM的介绍和公式梳理》
一、LSTM的基础结构
LSTM的结构中每个时刻的隐层包含了多个memory blocks(一般我们采用一个block),每个block包含了多个memory cell,每个memory cell包含一个Cell和三个gate,一个基础的结构示例如下图:一个memory cell只能产出一个标量值,一个block能产出一个向量。
二、LSTM的前向传播(Forward Pass)
1. 引入
首先我们在上述LSTM的基础结构之上构造时序结构,这样让读者更清晰地看到Recurrent的结构:这里我们有几个约定:
每个时刻的隐层包含一个block
每个block包含一个memory cell
下面前向传播我们则从Input开始,逐个求解Input Gate、Forget Gate、Cells Gate、Ouput Gate和最终的Output
这里需要申明的一点,推导过程严格按照上述图示LSTM的结构;论文中对相较于该文章的推导过程会有增加一些项,在每一个公式不一致的地方我都会有相应说明。
2. Input Gate(ι\iota) 的计算
Input Gate接受两个输入:当前时刻的Input作为输入:xtx^t
上一时刻同一block内所有Cell作为输入:st−1cs_c^{t-1}
该案例中每层仅有单个Block、单个cemory cell,可以忽略∑Cc=1\sum_{c=1}^{C},以下Forget Gate和Output Gate做相同处理。
最终Input Gate的输出为:
atι=∑i=1Iωiιxti+∑c=1Cωcιst−1c a_\iota^t = \sum_{i=1}^{I} \omega_{i\iota} x_i^t + \sum_{c=1}^{C} \omega_{c\iota} s_c^{t-1}
btι=f(atι) b_\iota^t = f(a_\iota^t)
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hb_h^{t-1}作为输入,论文中atιa_\iota^t会增加一项∑Hh=1ωhιbt−1h\sum_{h=1}^{H} \omega_{h\iota} b_h^{t-1}。
3. Forget Gate(ϕ\phi) 的计算
Forget Gate接受两个输入:当前时刻的Input作为输入:xtx^t
上一时刻同一block内所有Cell作为输入:st−1cs_c^{t-1}
最终Forget Gate的输出为:
atϕ=∑i=1Iωiϕxti+∑c=1Cωcϕst−1c a_\phi^t = \sum_{i=1}^{I} \omega_{i\phi} x_i^t + \sum_{c=1}^{C} \omega_{c\phi} s_c^{t-1}
btϕ=f(atϕ) b_\phi^t = f(a_\phi^t)
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hb_h^{t-1}作为输入,论文中atϕa_\phi^t会增加一项∑Hh=1ωhϕbt−1h\sum_{h=1}^{H} \omega_{h\phi} b_h^{t-1}。
4. Cell(cc) 的计算
Cell的计算稍有些复杂,接受两个输入:Input Gate和Input输入的乘积
Forget Gate和上一时刻对应Cell输出的乘积
最终Cell的输出为:
atc=∑i=1Iωicxti a_c^t = \sum_{i=1}^{I} \omega_{ic} x_i^t
stc=btϕst−1c+btιg(atc) s_c^t = b_\phi^t s_c^{t-1} + b_\iota^t g(a_c^t)
这里Input Gate还可以接受上一个时刻中不同block的输出bt−1hb_h^{t-1}作为输入,论文中atca_c^t会增加一项∑Hh=1ωhcbt−1h\sum_{h=1}^{H} \omega_{hc} b_h^{t-1}。
5. Output Gate(ω\omega) 的计算
Output Gate接受两个输入:当前时刻的Input作为输入:xtx^t
当前时刻同一block内所有Cell作为输入:stcs_c^t
这里Output Gate接受“当前时刻Cell的输出”而不是“上一时刻Cell的输出”,是由于此时Cell的结果已经产出,我们控制Output Gate的输出直接采用Cell当前的结果就行了,无须使用上一时刻。
最终Output Gate的输出为:
atω=∑i=1Iωiωxti+∑c=1Cωcωstc a_\omega ^t = \sum_{i=1}^{I} \omega_{i\omega} x_i^t + \sum_{c=1}^{C} \omega_{c\omega} s_c^t
btω=f(atω) b_\omega^t = f(a_\omega^t)
这里Cell还可以接受上一个时刻中其他gate链接过来的边,论文中atϕa_\phi^t会增加一项∑Hh=1ωhϕbt−1h\sum_{h=1}^{H} \omega_{h\phi} b_h^{t-1},这里HH是泛指t-1时刻的Cell或三个Gate。
6. Cell Output(cc) 的计算
Cell Output的计算即将Output Gate和Cell做乘积即可。最终Cell Output为:
btc=btωh(stc) b_c^t = b_\omega^t h(s_c^t)
7. 小结
至此,整个Block从Input到Output整个Forward Pass已经结束,其中涉及三个Gate和中间Cell的计算,需要注意的是三个Gate使用的激活函数是ff,而Input的激活函数是gg、Cell输出的激活函数是hh。这里读者需要注意,在整个计算过程中,当前时刻的三个Gate均可以从上一时刻的任意Gate中接受输入,在公式中存在体现,但是在图示中并未画出相应的边。我们可以认为只有上一时刻的Cell才和当前时刻的Cell或三个Gate相连。
三、LSTM的反向传播(Backward Pass)
1. 引入
此处在论文中使用“Backward Pass”一词,但其实即Back Propagation过程,利用链式求导求解整个LSTM中每个权重的梯度。2. 损失函数的选择
为了通用起见,在此我们仅展示多分类问题的损失函数的选择,对于网络的最终输出我们利用softmaxsoftmax方程计算结果属于某一类的概率(此时结果属于k个类别的概率和为1)。p(Ck|x)=yk=eak∑Kk′=1eak′ p(C_k|x) = y_k = \frac{e^a_k}{\sum_{k' =1}^{K} e^a_{k'}}
注意,yky_k对aka_k的偏导为∂yk′∂ak=ykδkk′−ykyk′ \frac{\partial y_{k'}}{\partial a_k}=y_k\delta_{kk'} - y_ky_{k'}(δkk′ \delta_{kk'}当k==k′k==k'时为1,其他为0)
其中,对于网络输出a1,a2,...a_1, a_2,...对应我们可以得到p(C1|x),p(C2|x),...p(C_1|x), p(C_2|x),...,即给定输入xx输出类别为C1,C2,...C_1, C_2,...的概率。
这样损失函数(Loss Function)就很好定义了:对于k∈1,2,...,Kk\in{1,2,...,K},网络输出的类别为k概率为yky_k,而真实值zkz_k:
(x,z)=−lnp(z|x)=−∑k=1Kzklnyk \mathcal{L}(x, z) = -lnp(z|x) = -\sum_{k=1}^{K} z_klny_k
3. 权重的更新
对于神经网络中的每一个权重,我们都需要找到对应的梯度,从而通过不断地用训练样本进行随机梯度下降找到全局最优解,那么首先我们需要知道哪些权重需要更新。一般层次分明的神经网络有input层、hidden层和output层,层与层之间的权重比较直观;但在LSTM中通过公式才能找到对应的权重,和图示中的边并不是一一对应,下面我将LSTM的单个Block中需要更新的权重在图示上标示了出来:
为了方便起见,这里需要申明的是:我们仅考虑上一时刻的Cell仅和当前时刻的Cell和三个Gate相连。
2. Cell Output的梯度
首先我们计算每一个输出类别的梯度:δtk========∂(x,z)∂atk∂(−∑Kk′=1zk′lnyk′)atk−∑k′=1Kzk′∂lnyk′∂atk−∑k′=1Kzk′yk′∂yk′∂atk−∑k′=1Kzk′yk′(ykδkk′−ykyk′)−∑k′=1Kzk′yk′ykδkk′+∑k′=1Kzk′yk′ykyk′−zk+yk∑k′=1Kzk′yk−zk
\begin{align}
\delta_k^t =& \frac{\partial \mathcal{L}(x,z)}{\partial a_k^t}\\
=& \frac{\partial (-\sum_{{k'}=1}^{K} z_{k'}lny_{k'})}{a_k^t}\\
=& -\sum_{k'=1}^{K} z_{k'} \frac{\partial lny_{k'}}{\partial a_k^t}\\
=& -\sum_{k'=1}^{K} \frac{z_{k'}}{y_{k'}} \frac{\partial y_{k'}}{\partial a_k^t}\\
=& -\sum_{k'=1}^{K} \frac{z_{k'}}{y_{k'}} (y_k\delta_{kk'} - y_ky_{k'})\\
=& -\sum_{k'=1}^{K} \frac{z_{k'}}{y_{k'}} y_k\delta_{kk'} + \sum_{k'=1}^{K} \frac{z_{k'}}{y_{k'}} y_ky_{k'}\\
=& -z_k + y_k\sum_{{k'}=1}^K z_{k'}\\
=& y_k - z_k
\end{align}
也即每一个输出类别的梯度仅和其预测值和真实值相关,这样对于Cell Output的梯度则可以通过链式求导法则推导出来:
ϵtc=∂(x,z)∂btc=∑k=1K∂(x,z)∂atk∂atk∂btc=∑k=1Kδtkωck\epsilon_c^t = \frac{\partial \mathcal{L}(x,z)}{\partial b_c^t} = \sum_{k=1}^{K}\frac{\partial \mathcal{L}(x,z)}{\partial a_k^t} \frac{\partial a_k^t}{\partial b_c^t} = \sum_{k=1}^{K} \delta_k^t \omega_{ck}
由于Output还可以连接下一个时刻的一个Cell、三个Gate,那么下一个时刻的一个Cell、三个Gate的梯度则可以传递回当前时刻Output,所以在论文中存在额外项∑Gg=1ωcgδt+1g\sum_{g=1}^G\omega_{cg}\delta_g^{t+1},为简便起见,公式和图示中未包含。
3. Output Gate的梯度
根据链式求导法则,Output Gate的梯度可以由以下公式推导出来:δtω=∂(x,z)∂atω=∂(x,z)∂btc∂btc∂btω∂btω∂atω=ϵtch(stc)f′(atw)\delta_\omega^t = \frac{\partial \mathcal{L}(x,z)}{\partial a_\omega^t} = \frac{\partial \mathcal{L}(x,z)}{\partial b_c^t} \frac{\partial b_c^t}{\partial b_\omega^t} \frac{\partial b_\omega^t}{\partial a_\omega^t}=\epsilon_c^t h(s_c^t)f'(a_w^t)
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Output Gate的梯度写成了f′(atw)∑Cc=1ϵtch(stc)f'(a_w^t) \sum_{c=1}^{C} \epsilon_c^t h(s_c^t),但推导过程一致。推导过程见下图,说明梯度汇总到单个Gate中:
4. Cell的梯度
细心的读者在这里会发现,Cell的计算结构和普遍的神经网络不太一样,让我们首先来回顾一下Cell部分的Forward计算过程:atc=∑i=1Iωicxti a_c^t = \sum_{i=1}^{I} \omega_{ic} x_i^t
stc=btϕst−1c+btιg(atc) s_c^t = b_\phi^t s_c^{t-1} + b_\iota^t g(a_c^t)
输入数据贡献给atca_c^t,而Cell同时能够接受Input Gate和Forget Gate的输入。
这样梯度就直接从Cell向下传递:
δtc=∂(x,z)∂atc=∂(x,z)∂stc∂stc∂atc=∂(x,z)∂stcbtιg′(atc)\delta_c^t = \frac{\partial \mathcal{L}(x,z)}{\partial a_c^t} = \frac{\partial \mathcal{L}(x,z)}{\partial s_c^t} \frac{\partial s_c^t}{\partial a_c^t} =\frac{\partial \mathcal{L}(x,z)}{\partial s_c^t} b_\iota^tg'(a_c^t)
在这里,我们定义States,由于Cell的梯度可以由以下几个计算单元传递回来:
当前时刻的Cell Output
下一个时刻的Cell
下一个时刻的Input Gate
下一个时刻的Output Gate
那么States可以这样求解,上面1~4个能够回传梯度的计算单元和下面公式中一一对应:
ϵts====∂(x,z)∂stc∂t(x,z)∂stc+∂t+1(x,z)∂st+1c∂st+1c∂stc+∂t+1(x,z)∂at+1ι∂at+1ι∂stc+∂t+1(x,z)∂at+1ϕ∂at+1ϕ∂stc(∂(x,z)∂atw∂atw∂stc+∂(x,z)∂btc∂btc∂stc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕδtωωcω+ϵtcbtωh′(stc)+bt+1ϕϵt+1s+ωcιδt+1ι+ωcϕδt+1ϕ
\begin{align}
\epsilon_s^t =& \frac{\partial \mathcal{L}(x,z)}{\partial s_c^t}\\
=& \frac{\partial \mathcal{L}^t(x,z)}{\partial s_c^t} + \frac{\partial \mathcal{L}^{t+1}(x,z)}{\partial s_c^{t+1}}\frac{\partial s_c^{t+1}}{\partial s_c^t} + \frac{\partial \mathcal{L}^{t+1}(x,z)}{\partial a_\iota^{t+1}}\frac{\partial a_\iota^{t+1}}{\partial s_c^t} + \frac{\partial \mathcal{L}^{t+1}(x,z)}{\partial a_\phi^{t+1}}\frac{\partial a_\phi^{t+1}}{\partial s_c^t}\\
=& (\frac{\partial \mathcal{L}(x,z)}{\partial a_w^t}\frac{\partial a_w^t}{\partial s_c^t} + \frac{\partial \mathcal{L}(x,z)}{\partial b_c^t}\frac{\partial b_c^t}{\partial s_c^t}) + b_\phi^{t+1}\epsilon_s^{t+1} + \omega_{c\iota}\delta_\iota^{t+1} + \omega_{c\phi}\delta_\phi^{t+1}\\
=& \delta_\omega^t \omega_{c\omega} + \epsilon_c^t b_\omega^t h'(s_c^t) + b_\phi^{t+1}\epsilon_s^{t+1} + \omega_{c\iota}\delta_\iota^{t+1} + \omega_{c\phi}\delta_\phi^{t+1}
\end{align}
那么:
δtc=ϵtsbtιg′(atc)\delta_c^t = \epsilon_s^t b_\iota^tg'(a_c^t)
细心的读者会发现,论文中∂(x,z)∂btc\frac{\partial \mathcal{L}(x,z)}{\partial b_c^t}并没有求和,这里作者持保留态度,应该存在求和项。
同时由于Cell可以连接到下一个时刻的Forget Gate、Output Gate和Input Gate,那么下一时刻的这三个Gate则可以将梯度传播回来,所以在论文中我们会发现ϵts\epsilon_s^t拥有这三项:bt+1ϕϵt+1sb_\phi^{t+1} \epsilon_s^{t+1}、ωclδt+1ι\omega_{cl}\delta_\iota^{t+1}和ωcϕδt+1ϕ\omega_{c\phi}\delta_\phi^{t+1}。
5. Forget Gate的梯度
Forget Gate的梯度计算就比较简单明了:δtϕ=∂(x,z)∂atϕ=∂(x,z)∂stc∂stc∂btϕ∂btϕ∂atϕ=ϵtsst−1cf′(atϕ)\delta_\phi^t = \frac{\partial \mathcal{L}(x,z)}{\partial a_\phi^t} = \frac{\partial \mathcal{L}(x,z)}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\phi^t} \frac{\partial b_\phi^t}{\partial a_\phi^t}=\epsilon_s^t s_c^{t-1} f'(a_\phi^t)
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Forget Gate的梯度写成了f′(atϕ)∑Cc=1st−1cϵtsf'(a_\phi^t) \sum_{c=1}^{C} s_c^{t-1} \epsilon_s^t,但推导过程一致,说明梯度汇总到单个Gate中。
6. Input Gate的梯度
Input Gate的梯度计算如下:δtι=∂(x,z)∂atι=∂(x,z)∂stc∂stc∂btι∂btι∂atι=ϵtsg(atc)f′(atι)\delta_\iota^t = \frac{\partial \mathcal{L}(x,z)}{\partial a_\iota^t} = \frac{\partial \mathcal{L}(x,z)}{\partial s_c^t} \frac{\partial s_c^t}{\partial b_\iota^t} \frac{\partial b_\iota^t}{\partial a_\iota^t}=\epsilon_s^t g(a_c^t) f'(a_\iota^t)
另外,由于单个Block内可以存在多个memory cell、一个Forget Gate、一个Input Gate和一个Output Gate,论文中将Input Gate的梯度写成了f′(atι)∑Cc=1g(atc)ϵtsf'(a_\iota^t) \sum_{c=1}^{C} g(a_c^t)\epsilon_s^t,但推导过程一致,说明梯度汇总到单个Gate中。
7. 小结
至此,所有的梯度求解已经结束,同样我们将这个Backward Pass的所有公式列出来:剩下的事情即利用梯度去更新每个权重:
Δωn=mΔωn−1−α∂∂ωn\Delta\omega^n = m \Delta\omega^{n-1} - \alpha \frac{\partial\mathcal{L}}{\partial\omega^n}
其中mΔωn−1m \Delta\omega^{n-1}为上一次权重的更新值,且m∈[0,1]m\in[0, 1];而∂∂ωn\frac{\partial\mathcal{L}}{\partial\omega^n}即上面我们求到的每一个梯度。
例如每次更新ωiϕ\omega_{i\phi}的Δ\Delta量即:
Δωniϕ=mΔωn−1iϕ−αxiδtϕ\Delta\omega_{i\phi}^n = m \Delta\omega_{i\phi}^{n-1} - \alpha x_i \delta_\phi^t
其中δtϕ\delta_\phi^t即Forget Gate的梯度。
三、总结
以上就是LSTM中的前向和反向传播的公式推导,在这里作者仅以最简单的单个Cell的场景进行示例。在实际工程实践中,常常会涉及到同一时刻多个Cell且互相之间的Gate存在连接,同时上一个时刻或下一个时刻的Cell和三个Gate之间同样存在复杂的连接关系。
但如果读者能够明晰上述的推导过程,那么无论多复杂都能够迎刃而解了。
毛仁歆
2015年7月31日
相关文章推荐
- serv-u设置被动模式注意的问题
- Convert Sorted List to Binary Search Tree
- 【Ajax技术】Ajax技术概述
- iOS- 如何集成支付宝
- 黑马程序员——14,String相关知识点
- Thinking In Linux C/C++字节对齐详解
- spring零总
- UVa 455 - Periodic Strings
- 输出文本Log
- Javascript实现网络监测的方法
- 空指针nullptr
- 在vc中的调用chm文件的方法
- 连连看
- 一个命令让Win10立即推送升级Win7/Win8.1
- placeholder颜色变化
- 转:RAC中比较replay, replayLast, and replayLazily
- 表单模型+安装目录+侵入表单模型
- 抓包分析TCP三次握手
- autolayout
- 重大校长周绪红寄语毕业生:做好平凡人