SoftmaxLayer and SoftmaxwithLossLayer 代码解读
2016-05-12 22:04
916 查看
SoftmaxLayer and SoftmaxwithLossLayer 代码解读
Wang Xiao
先来看看 SoftmaxWithLoss 在prototext文件中的定义:
再看SoftmaxWithLossLayer的.cpp文件:
接下来是对输入数据进行 reshape 操作:
Wang Xiao
先来看看 SoftmaxWithLoss 在prototext文件中的定义:
layer { name: "loss" type: "SoftmaxWithLoss" bottom: "fc8" bottom: "label" top: "loss" }
再看SoftmaxWithLossLayer的.cpp文件:
#include <algorithm> #include <cfloat> #include <vector> #include "caffe/layers/softmax_loss_layer.hpp" #include "caffe/util/math_functions.hpp" namespace caffe { template <typename Dtype> void SoftmaxWithLossLayer<Dtype>::LayerSetUp( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { LossLayer<Dtype>::LayerSetUp(bottom, top); LayerParameter softmax_param(this->layer_param_); softmax_param.set_type("Softmax"); softmax_layer_ = LayerRegistry<Dtype>::CreateLayer(softmax_param); softmax_bottom_vec_.clear(); softmax_bottom_vec_.push_back(bottom[0]); // 将bottom[0]存入softmax_bottom_vec_; softmax_top_vec_.clear(); softmax_top_vec_.push_back(&prob_); // 将 prob_ 存入 softmax_top_vec_;
softmax_layer_->SetUp(softmax_bottom_vec_, softmax_top_vec_); has_ignore_label_ = // draw the parameter from layer this->layer_param_.loss_param().has_ignore_label(); if (has_ignore_label_) { ignore_label_ = this->layer_param_.loss_param().ignore_label(); } if (!this->layer_param_.loss_param().has_normalization() && this->layer_param_.loss_param().has_normalize()) { normalization_ = this->layer_param_.loss_param().normalize() ? LossParameter_NormalizationMode_VALID : LossParameter_NormalizationMode_BATCH_SIZE; } else { normalization_ = this->layer_param_.loss_param().normalization(); } }
接下来是对输入数据进行 reshape 操作:
template <typename Dtype> void SoftmaxWithLossLayer<Dtype>::Reshape( const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { LossLayer<Dtype>::Reshape(bottom, top); softmax_layer_->Reshape(softmax_bottom_vec_, softmax_top_vec_); softmax_axis_ = bottom[0]->CanonicalAxisIndex(this->layer_param_.softmax_param().axis()); outer_num_ = bottom[0]->count(0, softmax_axis_); inner_num_ = bottom[0]->count(softmax_axis_ + 1); CHECK_EQ(outer_num_ * inner_num_, bottom[1]->count()) << "Number of labels must match number of predictions; " << "e.g., if softmax axis == 1 and prediction shape is (N, C, H, W), " << "label count (number of labels) must be N*H*W, " << "with integer values in {0, 1, ..., C-1}."; if (top.size() >= 2) { // softmax output top[1]->ReshapeLike(*bottom[0]); } }
相关文章推荐
- phpstudy 安装选择,iis+php组合,如何设置伪静态
- C++笔记之关键字explicit
- C++笔记之关键字explicit
- 新的框架,新的感觉ASP.NET MVC 分享一个简单快速适合新手的框架
- 【php安全】eval的禁止【原创】
- [团队项目] Scrum 项目 3.0 SCRUM 流程的步骤2: Spring 计划
- python 模拟126邮箱发送邮件
- 【C/C++】:如何获得一个float数的小数位数?
- Java中的继承
- java多线程-线程的实现
- Spring记录之Bean属性配置、依赖关系及生命周期
- java变量初始化
- C++中类和结构体的介绍
- Spring Mybatis整合
- 【putty】putty、psftp、pscp【原创】
- java--继承
- NSGA-ⅡMATLAB代码(转载)
- 【JDK】:ArrayList和LinkedList源码解析
- GPU 编程与CG 语言之阳春白雪下里巴人——CG学习读书笔记之数学函数(之一)。
- 第7周 C语言程序设计(新2版) 练习1-17 打印长度大于80个字符的所有输入行