您的位置:首页 > 理论基础 > 计算机网络

RNN网络结构及公式推导

2017-07-18 17:37 423 查看
RNN结构如图所示:



Xt∈Rx表示t时刻的输入(Xt是多少维,则这一层有多少个神经元,这里设为x维,图中画的是3维)

ht∈Rh表示t时刻隐层的输出(假设这一层有h个神经元)

yt∈Ry表示t时刻的预测输出

dt∈Ry表示t时刻的期望输出

V∈Rx×h表示从输入层到隐层的权值矩阵

U∈Rh×h表示上一个时刻到这个时刻的权值矩阵

bh∈Rh表示隐层的偏置,其中每一项对应某一神经元的偏置项

W∈Rh×y表示隐层到输出层的权值矩阵

by∈Ry表示输出层的偏置项

正向传播过程:

Xti表示t时刻某个样本第i维的输入,即输入层第i个神经元的输入

t时刻隐层第j个神经元的输入:cthj=∑xi=1XtiVij+∑hs=1ht−1sUsj+bhj

t时刻隐层第j个神经元的输出:htj=f(chj)

t时刻输出层第k个神经元的输入:ctyk=∑hj=1htjWjk+byk

t时刻输出层第k个神经元的输出:ytj=g(cyk)

矩阵表示(只有一个样本的情况):

t时刻隐层的输入,h*1向量,cth=VTXt+UTht−1+bh

t时刻隐层的输出,h*1向量,ht=f(cth)

t时刻输出层的输入,y*1向量,cty=WTht+by

t时刻输出层的输出,y*1向量,yt=g(cty)

反向求导过程:

假设共有p个样本,则t时刻的误差可以定义为:Et=∑p12∥dt−yt∥2,整个网络的误差为E=∑tEt=12∑p∑Tt=1∥dt−yt∥2,

∂E∂W=∑Tt=1∂E∂yt∂yt∂W

∂E∂yt=−(dt−yt)

∂yt∂W=∂yt∂cty∂cty∂W=g′(cty)ht

所以∂E∂W=−∑Tt=1(dt−yt)g′(cty)ht

∂E∂U=∑Tt=1∂E∂ht∂ht∂U,∂E∂V=∑Tt=1∂E∂ht∂ht∂V

由于ht一方面输到yt,一方面输到ht+1,所以它的误差来自两方面:

∂E∂ht=∂E∂yt∂yt∂ht+∂E∂ht+1∂ht+1∂ht=∂E∂yt∂yt∂cty∂cty∂ht+∂E∂ht+1∂ht+1∂ct+1h∂ct+1h∂ht=∂E∂ytg′(cty)W+∂E∂ht+1f′(ct+1h)U

∂ht∂U=∂ht∂cth∂cth∂U=f′(cth)ht−1,∂ht∂V=∂ht∂cth∂cth∂V=f′(cth)Xt

所以

∂E∂U=∑Tt=1[∂E∂ytg′(cty)W+∂E∂ht+1f′(ct+1h)U]f′(cth)ht−1

∂E∂V=∑Tt=1[∂E∂ytg′(cty)W+∂E∂ht+1f′(ct+1h)U]f′(cth)Xt

∂E∂by=∑Tt=1∂E∂yt∂yt∂cty∂cty∂by=−∑Tt=1(dt−yt)g′(cty)

∂E∂bh=∑Tt=1∂E∂ht∂ht∂cth∂cth∂bh=∑Tt=1[∂E∂ytg′(cty)W+∂E∂ht+1f′(ct+1h)U]f′(cth)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐