从Caffe源码分析训练过程
2015-09-09 09:57
399 查看
Caffe库从大体上分为4大类,即Blob,Layer,Net,Solver。本文先从总体上概括Caffe源码的训练过程。
下面开始来具体进行介绍。
先从Caffe.cpp文件中的train()函数开始说起。
1、创建一个SolverParameter solver_param用来保存求解(优化)的一些参数,SolverParameter这个数据结构具体被定义在caffe.proto文件中。
2、Caffe::readProtoFromTextFileDie(“solver.proto”,&solver_param)//读取solver.proto中的求解参数,并将其保存在solver_param中。
3、设置Solver模式—设置GPU/CPU进行训练求解。
4、创建shared_ptr<caffe::Solver<float>> solver(caffe::GetSolver<float>(solver_param))对象,这里是一种简单工厂的设计模式,caffe::GetSolver<float>(solver_param),根据solver_param参数的不同,创建不同的对象实例。在这里caffe默认创建一个SGD的对象实例,即SGD(SGD派生类)求解方式。
在这里具体实现有:
1、调用Solver类构造函数,调用Init()函数;
2、Init()函数中调用InitTrainNet(),这个函数创建训练网络Net,首先创建一个Net对象net_,这个net_对象会根据自己配置的网络结构文件创建不同类型的layer. InitTestNet()与InitTrainNet()类似。
3、调用PreSolve函数,这个函数功能是将参数push_back到history_,update_,temp_这三个vector中。
5、solver->Solve求解,主要是反复调用Net::ForwardBackward函数,因为误差来自于loss层,因此Net::ForwardBackward()只有一个参数。
Net::Forward函数函数,调用Net::ForwardPrefilled(), Net::ForwardPrefilled()调用Net::ForwardFromTo—这个函数将调用layers::Forward函数,因为在Net创建时,layers_对象就保存了每一层类型,因此每一个layers_[i]调用属于它自己类型的Forward和Backward函数(虚函数,多态)。
第一次写博客,首先先把caffe的整体训练的过程理清楚,接下来将写blob,layers,net,solver的具体实现啦。
下面开始来具体进行介绍。
先从Caffe.cpp文件中的train()函数开始说起。
1、创建一个SolverParameter solver_param用来保存求解(优化)的一些参数,SolverParameter这个数据结构具体被定义在caffe.proto文件中。
2、Caffe::readProtoFromTextFileDie(“solver.proto”,&solver_param)//读取solver.proto中的求解参数,并将其保存在solver_param中。
3、设置Solver模式—设置GPU/CPU进行训练求解。
4、创建shared_ptr<caffe::Solver<float>> solver(caffe::GetSolver<float>(solver_param))对象,这里是一种简单工厂的设计模式,caffe::GetSolver<float>(solver_param),根据solver_param参数的不同,创建不同的对象实例。在这里caffe默认创建一个SGD的对象实例,即SGD(SGD派生类)求解方式。
在这里具体实现有:
1、调用Solver类构造函数,调用Init()函数;
2、Init()函数中调用InitTrainNet(),这个函数创建训练网络Net,首先创建一个Net对象net_,这个net_对象会根据自己配置的网络结构文件创建不同类型的layer. InitTestNet()与InitTrainNet()类似。
3、调用PreSolve函数,这个函数功能是将参数push_back到history_,update_,temp_这三个vector中。
5、solver->Solve求解,主要是反复调用Net::ForwardBackward函数,因为误差来自于loss层,因此Net::ForwardBackward()只有一个参数。
<code class="hljs cpp has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: 'Source Code Pro', monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">Dtype Net::ForwardBackward(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">const</span> <span class="hljs-stl_container" style="box-sizing: border-box;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">vector</span><Blob<Dtype></span>* > & bottom) { Dtype loss; Forward(bottom, &loss);<span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">//误差来自于loss层</span> Backward(); <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">return</span> loss; } </code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>
Net::Forward函数函数,调用Net::ForwardPrefilled(), Net::ForwardPrefilled()调用Net::ForwardFromTo—这个函数将调用layers::Forward函数,因为在Net创建时,layers_对象就保存了每一层类型,因此每一个layers_[i]调用属于它自己类型的Forward和Backward函数(虚函数,多态)。
第一次写博客,首先先把caffe的整体训练的过程理清楚,接下来将写blob,layers,net,solver的具体实现啦。
相关文章推荐
- 从源码安装Mysql/Percona 5.5
- 浅析Ruby的源代码布局及其编程风格
- asp.net 抓取网页源码三种实现方法
- JS小游戏之仙剑翻牌源码详解
- JS小游戏之宇宙战机源码详解
- jQuery源码分析之jQuery中的循环技巧详解
- 本人自用的global.js库源码分享
- java中原码、反码与补码的问题分析
- PHP网页游戏学习之Xnova(ogame)源码解读(六)
- C#获取网页HTML源码实例
- PHP网页游戏学习之Xnova(ogame)源码解读(八)
- PHP网页游戏学习之Xnova(ogame)源码解读(四)
- JS小游戏之极速快跑源码详解
- JS小游戏之象棋暗棋源码详解
- 基于Android设计模式之--SDK源码之策略模式的详解
- Android游戏源码分享之2048
- C语言借助EasyX实现的生命游戏源码
- C实现的非阻塞方式命令行端口扫描器源码
- PHP网页游戏学习之Xnova(ogame)源码解读(七)
- PHP网页游戏学习之Xnova(ogame)源码解读(一)