您的位置:首页 > 编程语言 > C++

SSD 算法 detection_evaluate_layer 解读

2017-06-08 15:27 701 查看
代码位置

caffe/include/caffe/layers/detection_evaluate_layer.hpp

#ifndef CAFFE_DETECTION_EVALUATE_LAYER_HPP_
#define CAFFE_DETECTION_EVALUATE_LAYER_HPP_

#include <utility>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
 * @brief Generate the detection evaluation based on DetectionOutputLayer and
 * ground truth bounding box labels.
 *
 * Intended for use with MultiBox detection method.
 *
 * NOTE: does not implement Backwards operation.
 */
template <typename Dtype>
class DetectionEvaluateLayer : public Layer<Dtype> {
 public:
  explicit DetectionEvaluateLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "DetectionEvaluate"; }
  // Caffe's Layer base class exposes ExactNumBottomBlobs() for blob-count
  // validation; this layer needs exactly two bottoms (detections, labels).
  virtual inline int ExactNumBottomBlobs() const { return 2; }
  /// @deprecated Misspelled legacy name that overrides nothing in Layer; kept
  /// only for source compatibility. Use ExactNumBottomBlobs().
  virtual inline int ExactBottomBlobs() const { return ExactNumBottomBlobs(); }
  virtual inline int ExactNumTopBlobs() const { return 1; }

 protected:
  /**
   * @brief Evaluate the detection output.
   *
   * @param bottom input Blob vector (exact 2)
   *   -# @f$ (1 \times 1 \times N \times 7) @f$
   *      N detection results, each row is
   *      [image_id, label, confidence, xmin, ymin, xmax, ymax].
   *   -# @f$ (1 \times 1 \times M \times 8) @f$
   *      M ground truth, each row is
   *      [item_id, label, instance_id, xmin, ymin, xmax, ymax, difficult].
   * @param top Blob vector (length 1)
   *   -# @f$ (1 \times 1 \times (C + D) \times 5) @f$
   *      The first C rows (one per non-background class c) are
   *      [-1, c, number_of_ground_truth_of_c, -1, -1].
   *      The remaining D rows (one per detection whose label is not -1) are
   *      [image_id, label, confidence, true_pos, false_pos].
   */
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  /// @brief Not implemented
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
    NOT_IMPLEMENTED;
  }

  // Class ids are 0 .. num_classes_-1; background_label_id_ is skipped.
  int num_classes_;
  // -1 means there is no background class.
  int background_label_id_;
  // Minimum Jaccard overlap for a detection to match a ground-truth box.
  float overlap_threshold_;
  // If false, difficult ground-truth boxes are neither counted nor matched.
  bool evaluate_difficult_gt_;
  // (height, width) for each line of name_size_file, in file order.
  vector<pair<int, int> > sizes_;
  // Index into sizes_ for the image currently being evaluated.
  int count_;
  // True when no name_size_file was given (sizes_ is empty).
  bool use_normalized_bbox_;
  // Optional resize parameter forwarded to OutputBBox.
  bool has_resize_;
  ResizeParameter resize_param_;
};

}  // namespace caffe

#endif  // CAFFE_DETECTION_EVALUATE_LAYER_HPP_


caffe/src/caffe/layers/detection_evaluate_layer.cpp

#include <algorithm>
#include <fstream>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "caffe/layers/detection_evaluate_layer.hpp"
#include "caffe/util/bbox_util.hpp"

namespace caffe {

/**
 * @brief Copies detection_evaluate_param into the layer's members.
 *
 * num_classes is required. If name_size_file is given, every
 * "name height width" line is appended to sizes_ as (height, width), and
 * evaluation then uses those image sizes instead of normalized coordinates.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const DetectionEvaluateParameter& detection_evaluate_param =
      this->layer_param_.detection_evaluate_param();
  CHECK(detection_evaluate_param.has_num_classes())
      << "Must provide num_classes.";
  num_classes_ = detection_evaluate_param.num_classes();
  background_label_id_ = detection_evaluate_param.background_label_id();
  overlap_threshold_ = detection_evaluate_param.overlap_threshold();
  // CHECK_GT is strict, so the threshold has to be positive (not just >= 0).
  CHECK_GT(overlap_threshold_, 0.) << "overlap_threshold must be positive.";
  evaluate_difficult_gt_ = detection_evaluate_param.evaluate_difficult_gt();

  // Start from an empty list so a repeated setup does not append duplicates.
  sizes_.clear();
  if (detection_evaluate_param.has_name_size_file()) {
    const string name_size_file = detection_evaluate_param.name_size_file();
    std::ifstream infile(name_size_file.c_str());
    CHECK(infile.good())
        << "Failed to open name size file: " << name_size_file;
    // The file is in the following format:
    //    name height width
    //    ...
    string name;
    int height, width;
    while (infile >> name >> height >> width) {
      sizes_.push_back(std::make_pair(height, width));
    }
    // Extraction stops at the first line that fails to parse; make that
    // visible instead of silently evaluating with a truncated size list.
    LOG_IF(WARNING, !infile.eof())
        << "Stopped reading " << name_size_file << " after "
        << sizes_.size() << " entries: malformed line.";
    infile.close();
  }
  count_ = 0;
  // If there is no name_size_file provided, use normalized bbox to evaluate.
  use_normalized_bbox_ = sizes_.empty();

  // Retrieve resize parameter if there is any provided.
  has_resize_ = detection_evaluate_param.has_resize_param();
  if (has_resize_) {
    resize_param_ = detection_evaluate_param.resize_param();
  }
}

/**
 * @brief Validates bottom shapes and sizes top[0].
 *
 * top[0] becomes (1, 1, num_pos_classes + num_valid_det, 5): one summary row
 * per non-background class plus one row per detection whose label is not -1.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  CHECK_LE(count_, static_cast<int>(sizes_.size()));
  CHECK_EQ(bottom[0]->num(), 1);
  CHECK_EQ(bottom[0]->channels(), 1);
  CHECK_EQ(bottom[0]->width(), 7);
  CHECK_EQ(bottom[1]->num(), 1);
  CHECK_EQ(bottom[1]->channels(), 1);
  CHECK_EQ(bottom[1]->width(), 8);

  // Count exactly the classes Forward_cpu writes a summary row for: every
  // c in [0, num_classes_) except background_label_id_. Using the same rule
  // keeps the reserved size correct even if background_label_id_ is not in
  // [0, num_classes_), where "num_classes_ - 1" would be one row short.
  int num_pos_classes = 0;
  for (int c = 0; c < num_classes_; ++c) {
    if (c != background_label_id_) {
      ++num_pos_classes;
    }
  }

  // Detection rows are [image_id, label, confidence, xmin, ymin, xmax, ymax];
  // rows whose label is -1 get no output row in Forward_cpu.
  int num_valid_det = 0;
  const Dtype* det_row = bottom[0]->cpu_data();
  for (int i = 0; i < bottom[0]->height(); ++i, det_row += 7) {
    if (det_row[1] != -1) {
      ++num_valid_det;
    }
  }

  // num() and channels() are 1.
  vector<int> top_shape(2, 1);
  top_shape.push_back(num_pos_classes + num_valid_det);
  // Each row is a 5 dimension vector, which stores
  // [image_id, label, confidence, true_pos, false_pos]
  top_shape.push_back(5);
  top[0]->Reshape(top_shape);
}

// Writes one 5-column evaluation row starting at `row`:
// [image_id, label, confidence, true_pos, false_pos].
template <typename Dtype>
static inline void WriteEvalRow(Dtype* row, Dtype image_id, Dtype label,
                                Dtype confidence, Dtype true_pos,
                                Dtype false_pos) {
  row[0] = image_id;
  row[1] = label;
  row[2] = confidence;
  row[3] = true_pos;
  row[4] = false_pos;
}

/**
 * @brief Emits per-class ground-truth counts and per-detection TP/FP flags.
 *
 * bottom[0] ("detection_out"): rows of
 *   [image_id, label, confidence, xmin, ymin, xmax, ymax].
 * bottom[1] ("label"): rows of
 *   [item_id, label, instance_id, xmin, ymin, xmax, ymax, difficult].
 *
 * top[0] is zero-filled, then receives rows of 5 values:
 *   - one per non-background class c: [-1, c, num_gt(c), -1, -1], where
 *     num_gt(c) counts ground-truth boxes of class c (difficult ones only
 *     when evaluate_difficult_gt_ is true);
 *   - one per detection with label != -1:
 *     [image_id, label, confidence, true_pos, false_pos].
 *
 * Within one image and label, detections are visited in descending score
 * order and each is compared with the ground-truth box of the same label that
 * has the largest Jaccard overlap:
 *   - overlap >= overlap_threshold_ and that box not yet matched -> (1, 0);
 *   - overlap >= overlap_threshold_ but box already matched     -> (0, 1);
 *   - overlap <  overlap_threshold_                             -> (0, 1);
 *   - best box is difficult and evaluate_difficult_gt_ is false -> (0, 0),
 *     i.e. the detection is neither a true nor a false positive.
 */
template <typename Dtype>
void DetectionEvaluateLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const Dtype* det_data = bottom[0]->cpu_data();  // bottom: "detection_out"
  const Dtype* gt_data = bottom[1]->cpu_data();   // bottom: "label"

  // Retrieve all detection results, grouped as
  // all_detections[image_id][label] -> vector<NormalizedBBox>.
  //
  // Reference: GetDetectionResults (src/caffe/util/bbox_util.cpp):
  //
  //   void GetDetectionResults(const Dtype* det_data, const int num_det,
  //         const int background_label_id,
  //         map<int, map<int, vector<NormalizedBBox> > >* all_detections) {
  //     all_detections->clear();
  //     for (int i = 0; i < num_det; ++i) {
  //       int start_idx = i * 7;
  //       // Each of the N rows (detections kept after NMS) is
  //       // [image_id, label, confidence, xmin, ymin, xmax, ymax],
  //       // stored back to back in one flat array.
  //       int item_id = det_data[start_idx];  // image id
  //       if (item_id == -1) {
  //         continue;
  //       }
  //       int label = det_data[start_idx + 1];  // label of this box
  //       // Fails if a detection carries the background label.
  //       CHECK_NE(background_label_id, label)
  //           << "Found background label in the detection results.";
  //       NormalizedBBox bbox;
  //       bbox.set_score(det_data[start_idx + 2]);
  //       bbox.set_xmin(det_data[start_idx + 3]);
  //       bbox.set_ymin(det_data[start_idx + 4]);
  //       bbox.set_xmax(det_data[start_idx + 5]);
  //       bbox.set_ymax(det_data[start_idx + 6]);
  //       // Width * height of the box, with boundary handling.
  //       float bbox_size = BBoxSize(bbox);
  //       bbox.set_size(bbox_size);
  //       (*all_detections)[item_id][label].push_back(bbox);
  //     }
  //   }
  map<int, LabelBBox> all_detections;
  GetDetectionResults(det_data, bottom[0]->height(), background_label_id_,
                      &all_detections);

  // Retrieve all ground truth (including difficult ones), grouped as
  // all_gt_bboxes[image_id][label] -> vector<NormalizedBBox>.
  //
  // Reference: the image-only overload of GetGroundTruth. The call below uses
  // the map<int, LabelBBox> overload, which additionally groups by label.
  //
  //   void GetGroundTruth(const Dtype* gt_data, const int num_gt,
  //         const int background_label_id, const bool use_difficult_gt,
  //         map<int, vector<NormalizedBBox> >* all_gt_bboxes) {
  //     all_gt_bboxes->clear();
  //     // AnnotatedData stores each label row as 8 values:
  //     // [item_id (image id), group_label (class id),
  //     //  instance_id (index within the class), xmin, ymin, xmax, ymax,
  //     //  difficult]
  //     for (int i = 0; i < num_gt; ++i) {
  //       int start_idx = i * 8;
  //       int item_id = gt_data[start_idx];
  //       if (item_id == -1) {
  //         continue;
  //       }
  //       int label = gt_data[start_idx + 1];
  //       CHECK_NE(background_label_id, label)
  //           << "Found background label in the dataset.";
  //       // The difficult flag comes from column 7 of the label row.
  //       bool difficult = static_cast<bool>(gt_data[start_idx + 7]);
  //       if (!use_difficult_gt && difficult) {
  //         // Skip reading difficult ground truth.
  //         continue;
  //       }
  //       NormalizedBBox bbox;
  //       bbox.set_label(label);
  //       bbox.set_xmin(gt_data[start_idx + 3]);
  //       bbox.set_ymin(gt_data[start_idx + 4]);
  //       bbox.set_xmax(gt_data[start_idx + 5]);
  //       bbox.set_ymax(gt_data[start_idx + 6]);
  //       bbox.set_difficult(difficult);
  //       float bbox_size = BBoxSize(bbox);  // area
  //       bbox.set_size(bbox_size);
  //       (*all_gt_bboxes)[item_id].push_back(bbox);
  //     }
  //   }
  map<int, LabelBBox> all_gt_bboxes;
  GetGroundTruth(gt_data, bottom[1]->height(), background_label_id_,
                 true, &all_gt_bboxes);

  Dtype* top_data = top[0]->mutable_cpu_data();
  // Zero-fill so any column not written below reads as 0.
  caffe_set(top[0]->count(), Dtype(0.), top_data);
  int num_det = 0;  // number of output rows written so far

  // Count ground-truth boxes per label over all images.
  map<int, int> num_pos;
  for (map<int, LabelBBox>::iterator it = all_gt_bboxes.begin();
       it != all_gt_bboxes.end(); ++it) {
    for (LabelBBox::iterator iit = it->second.begin();
         iit != it->second.end(); ++iit) {
      const vector<NormalizedBBox>& gts = iit->second;
      int count = 0;
      if (evaluate_difficult_gt_) {
        count = gts.size();
      } else {
        // Only non-difficult ground truth counts as a positive.
        for (int i = 0; i < gts.size(); ++i) {
          if (!gts[i].difficult()) {
            ++count;
          }
        }
      }
      // operator[] value-initializes a missing entry to 0 before adding.
      num_pos[iit->first] += count;
    }
  }

  // One summary row per non-background class: [-1, c, num_gt, -1, -1].
  for (int c = 0; c < num_classes_; ++c) {
    if (c == background_label_id_) {
      continue;
    }
    map<int, int>::const_iterator found = num_pos.find(c);
    const int num_gt = (found == num_pos.end()) ? 0 : found->second;
    WriteEvalRow<Dtype>(top_data + num_det * 5, -1, c, num_gt, -1, -1);
    ++num_det;
  }

  // One row per detection: [image_id, label, confidence, true_pos, false_pos].
  for (map<int, LabelBBox>::iterator it = all_detections.begin();
       it != all_detections.end(); ++it) {
    const int image_id = it->first;
    LabelBBox& detections = it->second;
    map<int, LabelBBox>::iterator gt_it = all_gt_bboxes.find(image_id);

    for (LabelBBox::iterator iit = detections.begin();
         iit != detections.end(); ++iit) {
      const int label = iit->first;
      if (label == -1) {
        // Not evaluated; Reshape reserves no output row for label -1.
        continue;
      }
      vector<NormalizedBBox>& bboxes = iit->second;

      // Ground truth of this label in this image, or NULL if there is none.
      vector<NormalizedBBox>* gt_bboxes = NULL;
      if (gt_it != all_gt_bboxes.end()) {
        LabelBBox::iterator gt_label_it = gt_it->second.find(label);
        if (gt_label_it != gt_it->second.end()) {
          gt_bboxes = &gt_label_it->second;
        }
      }

      if (gt_bboxes == NULL) {
        // No ground truth to match against: every detection is a false
        // positive.
        for (int i = 0; i < bboxes.size(); ++i) {
          WriteEvalRow<Dtype>(top_data + num_det * 5, image_id, label,
                              bboxes[i].score(), 0, 1);
          ++num_det;
        }
        continue;
      }

      // Scale ground truth to image coordinates if needed.
      if (!use_normalized_bbox_) {
        CHECK_LT(count_, static_cast<int>(sizes_.size()));
        for (int i = 0; i < gt_bboxes->size(); ++i) {
          OutputBBox((*gt_bboxes)[i], sizes_[count_], has_resize_,
                     resize_param_, &((*gt_bboxes)[i]));
        }
      }
      vector<bool> visited(gt_bboxes->size(), false);
      // Higher-scored detections claim ground-truth boxes first.
      std::sort(bboxes.begin(), bboxes.end(), SortBBoxDescend);
      for (int i = 0; i < bboxes.size(); ++i) {
        // Read the score before OutputBBox overwrites bboxes[i] in place.
        const Dtype score = bboxes[i].score();
        if (!use_normalized_bbox_) {
          OutputBBox(bboxes[i], sizes_[count_], has_resize_, resize_param_,
                     &(bboxes[i]));
        }
        // Find the ground-truth box with the largest Jaccard overlap
        // (computed in normalized coordinates only when no sizes are given).
        float overlap_max = -1;
        int jmax = -1;
        for (int j = 0; j < gt_bboxes->size(); ++j) {
          const float overlap = JaccardOverlap(bboxes[i], (*gt_bboxes)[j],
                                               use_normalized_bbox_);
          if (overlap > overlap_max) {
            overlap_max = overlap;
            jmax = j;
          }
        }
        // Stays (0, 0) when the best match is a difficult box that is
        // excluded from evaluation.
        Dtype true_pos = 0;
        Dtype false_pos = 0;
        if (overlap_max >= overlap_threshold_) {
          if (evaluate_difficult_gt_ || !(*gt_bboxes)[jmax].difficult()) {
            if (!visited[jmax]) {
              // First detection to claim this ground truth.
              true_pos = 1;
              visited[jmax] = true;
            } else {
              // Ground truth already claimed: duplicate detection.
              false_pos = 1;
            }
          }
        } else {
          // Not enough overlap with any ground-truth box.
          false_pos = 1;
        }
        WriteEvalRow<Dtype>(top_data + num_det * 5, image_id, label, score,
                            true_pos, false_pos);
        ++num_det;
      }
    }

    // count_ advances once per image processed here and wraps after a full
    // pass over sizes_.
    if (!sizes_.empty()) {
      ++count_;
      if (count_ == static_cast<int>(sizes_.size())) {
        // reset count after a full iterations through the DB.
        count_ = 0;
      }
    }
  }
}

// Caffe registration boilerplate: explicitly instantiate the layer template
// for Caffe's supported Dtypes and register it with the layer factory under
// the name "DetectionEvaluate" (the string returned by type()).
INSTANTIATE_CLASS(DetectionEvaluateLayer);
REGISTER_LAYER_CLASS(DetectionEvaluate);

}  // namespace caffe
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  代码 caffe layer