LSTM推导 源码分析
2015-08-24 21:36
260 查看
LSTM推导
说是推导,基本上没有一个公式。注重理解。神经网络后向推导
cs231上有一篇关于非常好的文章, 讲得非常好。一个例子:
f(x,y)=x+σ(y)σ(x)+(x+y)2f(x,y) = \frac{x + \sigma(y)}{\sigma(x) + (x+y)^2}
x = 3 # example values y = -4 # forward pass sigy = 1.0 / (1 + math.exp(-y)) # sigmoid in numerator #(1) num = x + sigy # numerator #(2) sigx = 1.0 / (1 + math.exp(-x)) # sigmoid in denominator #(3) xpy = x + y #(4) xpysqr = xpy**2 #(5) den = sigx + xpysqr # denominator #(6) invden = 1.0 / den #(7) f = num * invden # done! #(8)
对应的后向传播为:
# backprop f = num * invden dnum = invden # gradient on numerator #(8) dinvden = num #(8) # backprop invden = 1.0 / den dden = (-1.0 / (den**2)) * dinvden #(7) # backprop den = sigx + xpysqr dsigx = (1) * dden #(6) dxpysqr = (1) * dden #(6) # backprop xpysqr = xpy**2 dxpy = (2 * xpy) * dxpysqr #(5) # backprop xpy = x + y dx = (1) * dxpy #(4) dy = (1) * dxpy #(4) # backprop sigx = 1.0 / (1 + math.exp(-x)) dx += ((1 - sigx) * sigx) * dsigx # Notice += !! See notes below #(3) # backprop num = x + sigy dx += (1) * dnum #(2) dsigy = (1) * dnum #(2) # backprop sigy = 1.0 / (1 + math.exp(-y)) dy += ((1 - sigy) * sigy) * dsigy #(1) # done! phew
代码分析
这是karpathy的lstm源码分析。代码中#sooda是我的注释。
前向更新公式为:
依照上文的后向传播的推导方式,可以得到,
前向更新:
后向更新:
注意点
通过观察公式1到4, 发现所有的乘机因子为x、h,互相没有依赖,可以并行化。利用向量化进行加速IFOG指的是Input,Forget, Output, Cell Gate的计算值。IFOGf是IFOG经过激活函数后的激活值. 并以此为顺序。o-d表示input gate, d-2d表示forget gate, 2d-3d表示output gate, 3d-end 表示cell gate
WLSTM保存的实际上是所有这些门相对于输入+隐藏层+偏置的权值。
后向传播从最后一个开始求偏导, 按照从后向前,按部就班即可,不需要跨步骤考虑
cache是为了保存后向传播所需要的值
完整代码gist
相关文章推荐
- AndroidStudio NDK 学习之接受Java传入的字符串
- C语言字符串反转
- iOS开发之网络篇-CocoaPods的安装 EI Capitan 10.11 之前的方式
- msstdfmt.dll缺失报错
- Class Imbalance Problem
- LeetCode263——Ugly Number
- iOS中将汉字转换成拼音的方法
- 解释(n&(n-1))==0的具体含义
- C语言基础--循环 递归打印乘法表
- bootstrap-js(2)下拉菜单
- 可变参数的实现
- PhotoView的异常问题
- 内联函数详解--C++
- 使用Android Studio和Gradle编译NDK项目之Experimental Plugin User Guide
- getopt、getopt_long、getopt_long_only使用实例
- iOS图片轮播
- 大龄屌丝自学笔记--Java零基础到菜鸟--010
- 不等式估计
- 随机生成1-100之间的数,并无一重复的存入长度为100的数组中
- ubuntu 14.04安装opencv3.0.0