您的位置:首页 > 其它

CTC学习笔记(五) eesen训练源码

2016-11-24 20:58 323 查看
essen源码参考https://github.com/yajiemiao/eesen,这里简单说一下涉及到训练前后向的核心算法源码实现。

以单句训练为准(多句并行类似),用到的变量

变量含义
phones_num最后一层输出节点个数,对应于|phones|+1
labels_num一句话对应的标注扩展blank以后的个数,比如”123”扩展为”b1b2b3b”
frames_num一句话对应的总的帧数,对应于时间t
ytk最后一层输出
atksoftmax层的输入

CTC error

ctc.Eval(net_out, targets, &obj_diff);


涉及到的变量的维度:

变量维度
net_outframes_num*phones_num
alpha/betaframes_num*labes_num
ctc_errorframes_num*phones_num
本来可以使用最终的公式求出对atk的error,代码中却分成了两部求解,可能逻辑上能体现出error反向传播的过程,但是实际感觉没有必要。

计算关于ytk的error

ctc_err_.ComputeCtcError(alpha_, beta_, net_out, label_expand_, pzx);


参考[1]给出的公式(15)

计算关于utk的error

ctc_err_.MulElements(net_out);
CuVector<BaseFloat> row_sum(num_frames, kSetZero);
row_sum.AddColSumMat(1.0, ctc_err_, 0.0);
CuMatrix<BaseFloat> net_out_tmp(net_out);
net_out_tmp.MulRowsVec(row_sum);
diff->CopyFromMat(ctc_err_);
diff->AddMat(-1.0, net_out_tmp);


主要是ytk对atk进行求导,推导参考前面的博客,结论如下:

∂L∂atk=∑k′∂L∂ytk′ytk′δkk′−∑k′∂L∂ytk′ytk′ytk

=∂L∂ytkytk−∑k′∂L∂ytk′ytk′ytk

注意上式最后一项有一个求和的过程,即将t时刻对应的ytk的所有节点的error累加。

沿网络反向传播error

变量含义
x每一层的输入
y每一层的输出
d_x关于x的error
d_y关于y的error
dim_in输入维度
dim_out输出维度
W每一层对应的参数矩阵
error依次经过affine-trans-layer和多层lstm-layer,每一层有两个目的:

- 求d_x: 将error传递到每一层的输入,以往后继续传播

- 求ΔW: 计算当前层的参数的error,以根据error更新参数

affine layer

变量维度
x/d_xframes_num*dim_in
y/d_yframes_num*dim_out
Wdim_out*dim_in

前向

y=x∗WT

后向

ΔW(t)=d_yT∗x+momentum∗ΔW(t−1)

这里参数更新有一个求和的过程,把所有时刻对应的ΔW进行累加,相当于把所有时间的error数据进行了求和作为最终的error。

lstm layer



参考[2],eesen采用的lstm单元如上图,但是代码中变量的含义和论文中不一致。

前向

it=δ(xtWTix+mt−1WTim+ct−1WTic+bi)

ft=δ(xtWTfx+mt−1WTfm+ct−1WTfc+bi)

gt=δ(xtWTcx+mt−1WTcm+bc)

ct=ft⊙ct−1+it⊙gt

ot=δ(xtWTox+mt−1WTom+ctWToc+bo)

ht=ϕ(ct)

mt=ot⊙ht

有两方面的并行

- gifo合并成一个矩阵

- 批量计算输入x(不依赖于t),然后再分帧计算其他变量

后向

Di=∂L∂(xtWTix+mt−1WTim+ct−1WTic+bi)

Df=∂L∂(xtWTfx+mt−1WTfm+ct−1WTfc+bi)

Dg=∂L∂(xtWTcx+mt−1WTcm+bc)

Do=∂L∂(xtWTox+mt−1WTom+ctWToc+bo)

Dc=∂L∂(ft⊙ct−1+it⊙gt)

有两个注意点

- mt的error除了来自于t时刻的error,还有来至于t+1时刻的Di/Df/Do/Dg

- ct的error除了来自于t时刻的error,还有来自于t+1时刻的Di/Df/Dc

参考文献

[1].Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

[2].Essen: End-to-End Speech Recognition Using Deep Rnn Models and WFST-Based Decoding
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: