您的位置:首页 > Web前端

Caffe 代码解读之拼接层(concat layer)

2016-03-13 11:43 429 查看
今天,我们看一下 Caffe 的拼接层(Concat layer),它把两个或多个输入 blob 沿某一个轴拼接成一个输出 blob。

首先,看一下caffe官方文档。



同其他 layer 一样,它的实现分为 LayerSetUp、Reshape、Forward_cpu、Backward_cpu 四个函数。

// ConcatLayer joins two or more bottom blobs into a single top blob along one
// axis (multiple inputs, one output). All inputs must match on every other
// axis; Reshape enforces this. For N x C x H x W inputs, for example:
//   concat axis 0: (n1 + n2 + ... + nk) x c x h x w
//   concat axis 1: n x (c1 + c2 + ... + ck) x h x w
// Reshape resolves the axis through CanonicalAxisIndex (or the legacy
// concat_dim), so it is not limited to axes 0 and 1.
//
// LayerSetUp only validates the ConcatParameter: it fails (glog CHECK) when
// both `axis` and the legacy `concat_dim` are set. Setting neither is
// accepted here; Reshape then takes the concat_param.axis() branch.
template <typename Dtype>
void ConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const ConcatParameter& concat_param = this->layer_param_.concat_param();
CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim()))
<< "Either axis or concat_dim should be specified; not both.";
}

// Resolves the concatenation axis, checks that all bottoms agree on every
// other axis, and reshapes top[0] so its extent on the concat axis is the sum
// of the bottoms' extents. Also caches num_concats_ and concat_input_size_,
// which Forward_cpu / Backward_cpu use to compute copy offsets.
template <typename Dtype>
void ConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  const int num_axes = bottom[0]->num_axes();
  const ConcatParameter& concat_param = this->layer_param_.concat_param();
  // Pick the axis from the legacy concat_dim field if present, otherwise from
  // the axis field (LayerSetUp already rejects setting both).
  if (concat_param.has_concat_dim()) {
    concat_axis_ = static_cast<int>(concat_param.concat_dim());
    // Don't allow negative indexing for concat_dim, a uint32 -- almost
    // certainly unintended.
    CHECK_GE(concat_axis_, 0) << "casting concat_dim from uint32 to int32 "
        << "produced negative result; concat_dim must satisfy "
        << "0 <= concat_dim < " << kMaxBlobAxes;
    CHECK_LT(concat_axis_, num_axes) << "concat_dim out of range.";
  } else {
    // NOTE(review): CanonicalAxisIndex is not shown here; presumably it
    // validates `axis` and maps negative values to a non-negative index.
    concat_axis_ = bottom[0]->CanonicalAxisIndex(concat_param.axis());
  }
  // Start from the first bottom's shape; only the concat-axis extent grows
  // as the remaining bottoms are folded in below.
  vector<int> top_shape = bottom[0]->shape();
  // Number of copy runs per bottom: the product of the dimensions *before*
  // the concat axis. For an N x C x H x W blob this is 1 for axis 0 (empty
  // product) and N for axis 1.
  num_concats_ = bottom[0]->count(0, concat_axis_);
  // Elements per unit step along the concat axis: the product of the
  // dimensions *after* it (C*H*W for axis 0, H*W for axis 1).
  concat_input_size_ = bottom[0]->count(concat_axis_ + 1);

  int bottom_count_sum = bottom[0]->count();
  // Compare against an int bound to avoid a signed/unsigned comparison.
  const int num_bottoms = static_cast<int>(bottom.size());
  for (int i = 1; i < num_bottoms; ++i) {
    CHECK_EQ(num_axes, bottom[i]->num_axes())
        << "All inputs must have the same #axes.";
    for (int j = 0; j < num_axes; ++j) {
      if (j == concat_axis_) { continue; }
      CHECK_EQ(top_shape[j], bottom[i]->shape(j))
          << "All inputs must have the same shape, except at concat_axis.";
    }
    bottom_count_sum += bottom[i]->count();
    top_shape[concat_axis_] += bottom[i]->shape(concat_axis_);
  }
  top[0]->Reshape(top_shape);
  // Sanity check: top holds exactly as many elements as all bottoms combined.
  CHECK_EQ(bottom_count_sum, top[0]->count());
}


1、这里有一点需要解释:bottom 的类型为 `vector<Blob<Dtype>*>`,即一个 Blob 指针的 vector,每个输入对应一个元素(例如拼接两个输入时就有 bottom[0] 和 bottom[1])。所以只需用 bottom[0] 的形状初始化 top_shape,再把其余 bottom 在拼接轴上的长度累加上去即可。

2、CHECK_EQ、CHECK_LT、CHECK_LE、CHECK_GE 等是 glog 提供的断言宏,分别检查相等、小于、小于等于、大于等于;条件不成立时会打印错误信息并终止程序。



3、count:这里用到了好几个 count 函数,它们都定义在 Blob 类(blob.hpp)中,解释如下:

inline int count() const { return count_; }

/**
 * @brief Multiply together the dimensions on axes [start_axis, end_axis).
 *
 * An empty range (start_axis == end_axis) yields 1.
 *
 * @param start_axis First axis included in the product.
 * @param end_axis   First axis excluded from the product.
 */
inline int count(int start_axis, int end_axis) const {
  CHECK_LE(start_axis, end_axis);
  CHECK_GE(start_axis, 0);
  CHECK_GE(end_axis, 0);
  CHECK_LE(start_axis, num_axes());
  CHECK_LE(end_axis, num_axes());
  int volume = 1;
  int axis = start_axis;
  while (axis < end_axis) {
    volume *= shape(axis);
    ++axis;
  }
  return volume;
}
/**
* @brief Compute the volume of a slice spanning from a particular first
*        axis to the final axis.
*
* @param start_axis The first axis to include in the slice.
*/
inline int count(int start_axis) const {
return count(start_axis, num_axes());
}


前向传播就是把各个 bottom 的数据按拼接轴依次拷贝到 top 中:

// Forward pass: copy every bottom blob into its slice of top[0].
// Each bottom is copied as num_concats_ runs of
// (bottom extent on the concat axis) * concat_input_size_ elements. Run n of
// bottom i goes into run n of top[0], placed after the extents of bottoms
// 0..i-1 along the concat axis.
template <typename Dtype>
void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  Dtype* top_data = top[0]->mutable_cpu_data();
  const int top_concat_axis = top[0]->shape(concat_axis_);
  // Position along the concat axis where the current bottom begins in top.
  int offset_concat_axis = 0;
  // Compare against an int bound to avoid a signed/unsigned comparison.
  const int num_bottoms = static_cast<int>(bottom.size());
  for (int i = 0; i < num_bottoms; ++i) {
    const Dtype* bottom_data = bottom[i]->cpu_data();
    const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
    // Elements per run for this bottom; invariant across n.
    const int bottom_run = bottom_concat_axis * concat_input_size_;
    for (int n = 0; n < num_concats_; ++n) {
      // Source: run n of this bottom. Destination: run n of top, shifted
      // past the bottoms already written.
      caffe_copy(bottom_run,
          bottom_data + n * bottom_run,
          top_data + (n * top_concat_axis + offset_concat_axis)
              * concat_input_size_);
    }
    offset_concat_axis += bottom_concat_axis;
  }
}


反向传播只传递梯度(diff):把 top 的 diff 按拼接轴切开,拷贝回对应 bottom 的 diff(不涉及 data)。

// Backward pass: copy each slice of top[0]'s diff back into the diff of the
// bottom it came from, using the same run layout as Forward_cpu.
// Bottoms with propagate_down[i] == false are not written. Their extent on
// the concat axis is still counted so that later bottoms read their own
// slice of top_diff.
template <typename Dtype>
void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  const Dtype* top_diff = top[0]->cpu_diff();
  const int top_concat_axis = top[0]->shape(concat_axis_);
  // Position along the concat axis where the current bottom begins in top.
  int offset_concat_axis = 0;
  // Compare against an int bound to avoid a signed/unsigned comparison.
  const int num_bottoms = static_cast<int>(bottom.size());
  for (int i = 0; i < num_bottoms; ++i) {
    const int bottom_concat_axis = bottom[i]->shape(concat_axis_);
    if (propagate_down[i]) {
      Dtype* bottom_diff = bottom[i]->mutable_cpu_diff();
      // Elements per run for this bottom; invariant across n.
      const int bottom_run = bottom_concat_axis * concat_input_size_;
      for (int n = 0; n < num_concats_; ++n) {
        caffe_copy(bottom_run,
            top_diff + (n * top_concat_axis + offset_concat_axis)
                * concat_input_size_,
            bottom_diff + n * bottom_run);
      }
    }
    // Advance past this bottom's slice whether or not its gradient was
    // copied; skipping this would misalign every later bottom's gradient.
    offset_concat_axis += bottom_concat_axis;
  }
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: