caffe源码分析--Blob类
2016-08-25 15:58
316 查看
转自:http://blog.csdn.net/lingerlanlan/article/details/24379689
数据成员
构造函数
其它函数
数据成员
protected: shared_ptr<SyncedMemory> data_; //data数据,指向SyncedMemory类的智能指针 shared_ptr<SyncedMemory> diff_; //参数更新量 shared_ptr<SyncedMemory> shape_data_; //数据维度 vector<int> shape_; //数据维度 int count_; //数据量 int capacity_; //数据量
构造函数
// Default constructor: an empty blob with no shape and no storage.
Blob() : data_(), diff_(), count_(0), capacity_(0) {}

// 4-D constructor taking (num, channels, height, width).
explicit Blob(const int num, const int channels, const int height,
              const int width);

// Shape constructor -- the one generally used.
// capacity_ must start at 0 before Reshape() compares count_ against it, so
// Reshape() allocates data_/diff_ whenever the resulting count_ is positive.
template <typename Dtype>
Blob<Dtype>::Blob(const vector<int>& shape)
    : capacity_(0) {
  Reshape(shape);
}
// Reshape the blob to `shape`, updating shape_, shape_data_ and count_.
// data_/diff_ are reallocated only when the new element count exceeds
// capacity_; shrinking keeps the existing allocation.
template <typename Dtype>
void Blob<Dtype>::Reshape(const vector<int>& shape) {
  CHECK_LE(shape.size(), kMaxBlobAxes);
  count_ = 1;
  shape_.resize(shape.size());
  // Keep an int copy of the shape in SyncedMemory; grow that buffer if needed.
  if (!shape_data_ || shape_data_->size() < shape.size() * sizeof(int)) {
    shape_data_.reset(new SyncedMemory(shape.size() * sizeof(int)));
  }
  int* shape_data = static_cast<int*>(shape_data_->mutable_cpu_data());
  for (size_t i = 0; i < shape.size(); ++i) {
    CHECK_GE(shape[i], 0);
    // Overflow guard: count_ * shape[i] must stay <= INT_MAX. Once a
    // zero-sized axis has made count_ 0, skip it -- INT_MAX / count_ would
    // divide by zero, and the product stays 0 regardless.
    if (count_ != 0) {
      CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";
    }
    count_ *= shape[i];
    shape_[i] = shape[i];
    shape_data[i] = shape[i];
  }
  if (count_ > capacity_) {
    capacity_ = count_;
    data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
    diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
  }
}
// Reshape using a BlobShape message.
void Reshape(const BlobShape& shape);
// Reshape this blob to match `other` (used by CopyFrom when reshaping).
void ReshapeLike(const Blob& other);
其它函数
inline const vector<int>& shape() const { returnshape_; } inline int shape(int index) const { return shape_[CanonicalAxisIndex(index)]; } inline int num_axes() const { return shape_.size(); } inline int count() const { return count_; }
inline int count(int start_axis, int end_axis) const {} //返回start轴到end轴的数据量,区间左闭右开
// These four functions are deprecated -- use shape(i) instead:
//   inline int num() const
//   inline int channels() const
//   inline int height() const
//   inline int width() const
// Return the offset of an element (presumably its linear index into the
// flattened data -- confirm against blob.hpp):
//   inline int offset(const int n, const int c = 0, const int h = 0, const int w = 0)
//   inline int offset(const vector<int>& indices) const  // use this one
//拷贝source数据 template <typename Dtype> void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) { if (source.count() != count_ || source.shape() != shape_) { if (reshape) { ReshapeLike(source); } else { LOG(FATAL) << "Trying to copy blobs of different sizes."; } } switch (Caffe::mode()) { case Caffe::GPU: if (copy_diff) {//copy_diff为真,则拷贝diff;否则拷贝data caffe_copy(count_, source.gpu_diff(), static_cast<Dtype*>(diff_->mutable_gpu_data())); } else { caffe_copy(count_, source.gpu_data(), static_cast<Dtype*>(data_->mutable_gpu_data())); } break; case Caffe::CPU: if (copy_diff) { caffe_copy(count_, source.cpu_diff(), static_cast<Dtype*>(diff_->mutable_cpu_data())); } else { caffe_copy(count_, source.cpu_data(), static_cast<Dtype*>(data_->mutable_cpu_data())); } break; default: LOG(FATAL) << "Unknown caffe mode."; } }
<pre name="code" class="cpp">//写入bolb template <> void Blob<double>::ToProto(BlobProto* proto, bool write_diff) const { proto->clear_shape(); for (int i = 0; i < shape_.size(); ++i) { proto->mutable_shape()->add_dim(shape_[i]); } proto->clear_double_data(); proto->clear_double_diff(); const double* data_vec = cpu_data(); for (int i = 0; i < count_; ++i) { proto->add_double_data(data_vec[i]); } if (write_diff) { const double* diff_vec = cpu_diff(); for (int i = 0; i < count_; ++i) { proto->add_double_diff(diff_vec[i]); } }
// Element accessors (read a single value):
//   inline Dtype data_at(const int n, const int c, const int h, const int w) const
//   inline Dtype diff_at(const int n, const int c, const int h, const int w) const
//   inline Dtype data_at(const vector<int>& index) const
//   inline Dtype diff_at(const vector<int>& index)
// NOTE(review): the last overload is listed without `const`, unlike the other
// three -- presumably a transcription slip; confirm against blob.hpp.
// Return the shared_ptr holding the data storage.
inline const shared_ptr<SyncedMemory>& data() const { return data_; }
// Return the shared_ptr holding the diff (gradient) storage.
inline const shared_ptr<SyncedMemory>& diff() const { return diff_; }
// Raw pointer accessors. The const versions return read-only pointers via
// SyncedMemory::cpu_data()/gpu_data(); the mutable_* versions return
// writable pointers via mutable_cpu_data()/mutable_gpu_data().
// gpu_data()/gpu_diff() are const so they can be called through a
// const Blob& (as CopyFrom does with its source).
template <typename Dtype>
const int* Blob<Dtype>::gpu_shape() const {
  return static_cast<const int*>(shape_data_->gpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_data() const {
  return static_cast<const Dtype*>(data_->cpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_data() const {
  return static_cast<const Dtype*>(data_->gpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::cpu_diff() const {
  return static_cast<const Dtype*>(diff_->cpu_data());
}

template <typename Dtype>
const Dtype* Blob<Dtype>::gpu_diff() const {
  return static_cast<const Dtype*>(diff_->gpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_data() {
  return static_cast<Dtype*>(data_->mutable_cpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_data() {
  return static_cast<Dtype*>(data_->mutable_gpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_cpu_diff() {
  return static_cast<Dtype*>(diff_->mutable_cpu_data());
}

template <typename Dtype>
Dtype* Blob<Dtype>::mutable_gpu_diff() {
  return static_cast<Dtype*>(diff_->mutable_gpu_data());
}
// Assign other's data_ / diff_ to this blob. Because these members are
// shared_ptrs, the storage is shared rather than copied. Presumably ShareData
// takes data_ and ShareDiff takes diff_ -- confirm against blob.cpp.
//   void Blob<Dtype>::ShareData(const Blob& other)
//   void Blob<Dtype>::ShareDiff(const Blob& other)
<pre name="code" class="cpp">//更新权重
<pre name="code" class="cpp">template <typename Dtype> void Blob<Dtype>::Update() { // We will perform update based on where the data is located. switch (data_->head()) { case SyncedMemory::HEAD_AT_CPU: // perform computation on CPU caffe_axpy<Dtype>(count_, Dtype(-1), static_cast<const Dtype*>(diff_->cpu_data()), static_cast<Dtype*>(data_->mutable_cpu_data())); break; case SyncedMemory::HEAD_AT_GPU: case SyncedMemory::SYNCED: #ifndef CPU_ONLY // perform computation on GPU caffe_gpu_axpy<Dtype>(count_, Dtype(-1), static_cast<const Dtype*>(diff_->gpu_data()), static_cast<Dtype*>(data_->mutable_gpu_data())); #else NO_GPU; #endif break; default: LOG(FATAL) << "Syncedmem not initialized."; } }
Dtype asum_data() const; //返回data的第一范数 Dtype asum_diff() const; //返回diff的第一范数 Dtype sumsq_data() const; //返回data的第二范数 Dtype sumsq_diff() const; //返回diff的第二范数 //放缩data和diff void scale_data(Dtype scale_factor); void scale_diff(Dtype scale_factor);
// Check whether every dimension of this blob's shape matches the shape
// described by `other`.
bool ShapeEquals(const BlobProto& other);
相关文章推荐
- caffe源码分析--Blob类代码研究
- caffe源码分析--Blob类代码研究
- Caffe源码(十一):io.cpp 分析
- 从Caffe源码分析训练过程
- caffe中HingeLossLayer层原理以及源码分析
- Caffe源码(六): pooling_layer 分析
- caffe源码 layer分析
- caffe源码分析--poolinger_layer.cpp
- 从Caffe源码分析训练过程
- caffe源码分析--softmax_layer.cpp
- 从Caffe源码分析训练过程
- 从Caffe源码分析训练过程
- caffe源码分析--math_functions.cu代码研究
- 从Caffe源码分析训练过程
- 从Caffe源码分析训练过程
- 从Caffe源码分析训练过程
- caffe全连接层(INNER_PRODUCT)源码注释与分析
- caffe源码分析--data_layer.cpp
- caffe源码分析--SyncedMemory类代码研究
- 从Caffe源码分析训练过程