CTC学习笔记(五) eesen训练源码
2016-11-24 20:58
323 查看
essen源码参考https://github.com/yajiemiao/eesen,这里简单说一下涉及到训练前后向的核心算法源码实现。
以单句训练为准(多句并行类似),用到的变量
涉及到的变量的维度:
本来可以使用最终的公式求出对atk的error,代码中却分成了两部求解,可能逻辑上能体现出error反向传播的过程,但是实际感觉没有必要。
参考[1]给出的公式(15)
主要是ytk对atk进行求导,推导参考前面的博客,结论如下:
∂L∂atk=∑k′∂L∂ytk′ytk′δkk′−∑k′∂L∂ytk′ytk′ytk
=∂L∂ytkytk−∑k′∂L∂ytk′ytk′ytk
注意上式最后一项有一个求和的过程,即将t时刻对应的ytk的所有节点的error累加。
error依次经过affine-trans-layer和多层lstm-layer,每一层有两个目的:
- 求d_x: 将error传递到每一层的输入,以往后继续传播
- 求ΔW: 计算当前层的参数的error,以根据error更新参数
这里参数更新有一个求和的过程,把所有时刻对应的ΔW进行累加,相当于把所有时间的error数据进行了求和作为最终的error。
参考[2],eesen采用的lstm单元如上图,但是代码中变量的含义和论文中不一致。
ft=δ(xtWTfx+mt−1WTfm+ct−1WTfc+bi)
gt=δ(xtWTcx+mt−1WTcm+bc)
ct=ft⊙ct−1+it⊙gt
ot=δ(xtWTox+mt−1WTom+ctWToc+bo)
ht=ϕ(ct)
mt=ot⊙ht
有两方面的并行
- gifo合并成一个矩阵
- 批量计算输入x(不依赖于t),然后再分帧计算其他变量
Df=∂L∂(xtWTfx+mt−1WTfm+ct−1WTfc+bi)
Dg=∂L∂(xtWTcx+mt−1WTcm+bc)
Do=∂L∂(xtWTox+mt−1WTom+ctWToc+bo)
Dc=∂L∂(ft⊙ct−1+it⊙gt)
有两个注意点
- mt的error除了来自于t时刻的error,还有来至于t+1时刻的Di/Df/Do/Dg
- ct的error除了来自于t时刻的error,还有来自于t+1时刻的Di/Df/Dc
[2].Essen: End-to-End Speech Recognition Using Deep Rnn Models and WFST-Based Decoding
以单句训练为准(多句并行类似),用到的变量
变量 | 含义 |
---|---|
phones_num | 最后一层输出节点个数,对应于|phones|+1 |
labels_num | 一句话对应的标注扩展blank以后的个数,比如”123”扩展为”b1b2b3b” |
frames_num | 一句话对应的总的帧数,对应于时间t |
ytk | 最后一层输出 |
atk | softmax层的输入 |
CTC error
ctc.Eval(net_out, targets, &obj_diff);
涉及到的变量的维度:
变量 | 维度 |
---|---|
net_out | frames_num*phones_num |
alpha/beta | frames_num*labes_num |
ctc_error | frames_num*phones_num |
计算关于ytk的error
ctc_err_.ComputeCtcError(alpha_, beta_, net_out, label_expand_, pzx);
参考[1]给出的公式(15)
计算关于utk的error
ctc_err_.MulElements(net_out); CuVector<BaseFloat> row_sum(num_frames, kSetZero); row_sum.AddColSumMat(1.0, ctc_err_, 0.0); CuMatrix<BaseFloat> net_out_tmp(net_out); net_out_tmp.MulRowsVec(row_sum); diff->CopyFromMat(ctc_err_); diff->AddMat(-1.0, net_out_tmp);
主要是ytk对atk进行求导,推导参考前面的博客,结论如下:
∂L∂atk=∑k′∂L∂ytk′ytk′δkk′−∑k′∂L∂ytk′ytk′ytk
=∂L∂ytkytk−∑k′∂L∂ytk′ytk′ytk
注意上式最后一项有一个求和的过程,即将t时刻对应的ytk的所有节点的error累加。
沿网络反向传播error
变量 | 含义 |
---|---|
x | 每一层的输入 |
y | 每一层的输出 |
d_x | 关于x的error |
d_y | 关于y的error |
dim_in | 输入维度 |
dim_out | 输出维度 |
W | 每一层对应的参数矩阵 |
- 求d_x: 将error传递到每一层的输入,以往后继续传播
- 求ΔW: 计算当前层的参数的error,以根据error更新参数
affine layer
变量 | 维度 |
---|---|
x/d_x | frames_num*dim_in |
y/d_y | frames_num*dim_out |
W | dim_out*dim_in |
前向
y=x∗WT后向
ΔW(t)=d_yT∗x+momentum∗ΔW(t−1)这里参数更新有一个求和的过程,把所有时刻对应的ΔW进行累加,相当于把所有时间的error数据进行了求和作为最终的error。
lstm layer
参考[2],eesen采用的lstm单元如上图,但是代码中变量的含义和论文中不一致。
前向
it=δ(xtWTix+mt−1WTim+ct−1WTic+bi)ft=δ(xtWTfx+mt−1WTfm+ct−1WTfc+bi)
gt=δ(xtWTcx+mt−1WTcm+bc)
ct=ft⊙ct−1+it⊙gt
ot=δ(xtWTox+mt−1WTom+ctWToc+bo)
ht=ϕ(ct)
mt=ot⊙ht
有两方面的并行
- gifo合并成一个矩阵
- 批量计算输入x(不依赖于t),然后再分帧计算其他变量
后向
Di=∂L∂(xtWTix+mt−1WTim+ct−1WTic+bi)Df=∂L∂(xtWTfx+mt−1WTfm+ct−1WTfc+bi)
Dg=∂L∂(xtWTcx+mt−1WTcm+bc)
Do=∂L∂(xtWTox+mt−1WTom+ctWToc+bo)
Dc=∂L∂(ft⊙ct−1+it⊙gt)
有两个注意点
- mt的error除了来自于t时刻的error,还有来至于t+1时刻的Di/Df/Do/Dg
- ct的error除了来自于t时刻的error,还有来自于t+1时刻的Di/Df/Dc
参考文献
[1].Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks[2].Essen: End-to-End Speech Recognition Using Deep Rnn Models and WFST-Based Decoding
相关文章推荐
- 【深度学习】笔记15 微软官方源码caffe的第一个测例Mnist训练运行配置
- CTC学习笔记(二) 训练和公式推导
- Faster R-CNN 训练源码学习笔记
- CTC学习笔记(二) 训练和公式推导
- prototype.js 源码学习笔记(一)
- LDD3源码学习笔记之scull_pipe转
- 学习笔记:解读CppUnit源码3
- Ubuntu学习笔记(1)---编译源码包
- 2410 TFTP源码 学习笔记414757749
- jQuery源码学习笔记一
- 学习笔记:解读CppUnit源码2
- JAVA虚拟机源码学习笔记之二
- jQuery源码学习笔记二(转)
- jQuery源码学习笔记三(转)
- jQuery源码学习笔记五(转)
- 学习笔记:解读CppUnit源码7
- (源码实例)通过层DIV实现,当鼠标放在链接上面,显示图片及文字 - 流星絮语 JAVA学习笔记 - CSDNBlog
- shell脚本学习笔记(一)闹钟的源码
- spring学习笔记之DispatcherServlet源码解读
- jQuery源码学习笔记一(转)