您的位置:首页 > 编程语言

SSD MultiBoxLossLayer代码学习记录

2016-08-15 16:10 417 查看
// One-time configuration of the MultiBox loss layer.
//
// Reads every option from MultiBoxLossParameter into member fields, validates
// the combinations that are not supported, and builds the two internal loss
// layers (localization and confidence). Those internal layers are set up on
// placeholder ("fake") shapes here; the blobs are reshaped to their real
// sizes on every forward pass.
//
// bottom[0]: location predictions, bottom[1]: confidence predictions,
// bottom[2]: prior boxes, bottom[3]: ground truth.
template <typename Dtype>
void MultiBoxLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::LayerSetUp(bottom, top);

  // If the net definition gave no propagate_down flags, install one per
  // bottom: gradients flow to the two prediction blobs only.
  if (this->layer_param_.propagate_down_size() == 0) {
    // location prediction, confidence prediction, prior, ground truth
    const bool default_propagate_down[4] = {true, true, false, false};
    for (int b = 0; b < 4; ++b) {
      this->layer_param_.add_propagate_down(default_propagate_down[b]);
    }
  }
  const MultiBoxLossParameter& multibox_loss_param =
      this->layer_param_.multibox_loss_param();

  // num() of the location blob (the batch size, per the original notes).
  num_ = bottom[0]->num();
  // Four values per prior box (per the original notes: the top-left and
  // bottom-right corner coordinates).
  num_priors_ = bottom[2]->height() / 4;

  // The number of classes is mandatory and must be at least one.
  CHECK(multibox_loss_param.has_num_classes()) << "Must provide num_classes.";
  num_classes_ = multibox_loss_param.num_classes();
  CHECK_GE(num_classes_, 1) << "num_classes should not be less than 1.";

  // With a shared location there is one box prediction per prior for all
  // classes; otherwise every class gets its own box prediction.
  share_location_ = multibox_loss_param.share_location();
  loc_classes_ = share_location_ ? 1 : num_classes_;

  // Matching options.
  match_type_ = multibox_loss_param.match_type();
  overlap_threshold_ = multibox_loss_param.overlap_threshold();
  // Whether priors (rather than decoded predictions) are matched to gt.
  use_prior_for_matching_ = multibox_loss_param.use_prior_for_matching();
  background_label_id_ = multibox_loss_param.background_label_id();
  // Whether ground truth flagged as "difficult" is used.
  use_difficult_gt_ = multibox_loss_param.use_difficult_gt();

  // Negative mining options: on/off, negative-to-positive ratio and the
  // overlap below which an unmatched prior counts as a negative candidate.
  do_neg_mining_ = multibox_loss_param.do_neg_mining();
  neg_pos_ratio_ = multibox_loss_param.neg_pos_ratio();
  neg_overlap_ = multibox_loss_param.neg_overlap();

  // Box encoding options.
  code_type_ = multibox_loss_param.code_type();
  // Reportedly false by default (the proto definition is not visible here).
  encode_variance_in_target_ =
      multibox_loss_param.encode_variance_in_target();
  // Reportedly false by default (the proto definition is not visible here).
  map_object_to_agnostic_ = multibox_loss_param.map_object_to_agnostic();
  if (map_object_to_agnostic_) {
    // Class-agnostic mode: a single "object" class, plus background when a
    // background label is configured.
    if (background_label_id_ >= 0) {
      CHECK_EQ(num_classes_, 2);
    } else {
      CHECK_EQ(num_classes_, 1);
    }
  }

  // Normalization mode: honour the legacy boolean `normalize` only when the
  // newer `normalization` field is absent.
  const LossParameter& loss_param = this->layer_param_.loss_param();
  const bool only_legacy_normalize =
      !loss_param.has_normalization() && loss_param.has_normalize();
  if (only_legacy_normalize) {
    normalization_ = loss_param.normalize() ?
        LossParameter_NormalizationMode_VALID :
        LossParameter_NormalizationMode_BATCH_SIZE;
  } else {
    normalization_ = loss_param.normalization();
  }

  if (do_neg_mining_) {
    CHECK(share_location_)
        << "Currently only support negative mining if share_location is true.";
    CHECK_GT(neg_pos_ratio_, 0);
  }

  // Both internal losses produce a single scalar.
  vector<int> loss_shape(1, 1);

  // ---- Localization loss layer ----
  loc_weight_ = multibox_loss_param.loc_weight();
  loc_loss_type_ = multibox_loss_param.loc_loss_type();
  // Placeholder shape [1, 4]; the real shape is set during the forward pass.
  vector<int> loc_shape(2);
  loc_shape[0] = 1;
  loc_shape[1] = 4;
  loc_pred_.Reshape(loc_shape);
  loc_gt_.Reshape(loc_shape);
  // The bottom/top vectors hold pointers to the member blobs, so later
  // reshapes of those blobs are seen by the internal layer.
  loc_bottom_vec_.push_back(&loc_pred_);
  loc_bottom_vec_.push_back(&loc_gt_);
  loc_loss_.Reshape(loss_shape);
  loc_top_vec_.push_back(&loc_loss_);

  // Create the layer that actually computes the localization loss.
  LayerParameter loc_layer_param;
  switch (loc_loss_type_) {
    case MultiBoxLossParameter_LocLossType_L2:
      loc_layer_param.set_name(this->layer_param_.name() + "_l2_loc");
      loc_layer_param.set_type("EuclideanLoss");
      break;
    case MultiBoxLossParameter_LocLossType_SMOOTH_L1:
      loc_layer_param.set_name(this->layer_param_.name() + "_smooth_L1_loc");
      loc_layer_param.set_type("SmoothL1Loss");
      break;
    default:
      LOG(FATAL) << "Unknown localization loss type.";
  }
  loc_layer_param.add_loss_weight(loc_weight_);
  loc_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(loc_layer_param);
  loc_loss_layer_->SetUp(loc_bottom_vec_, loc_top_vec_);

  // ---- Confidence loss layer ----
  conf_loss_type_ = multibox_loss_param.conf_loss_type();
  conf_bottom_vec_.push_back(&conf_pred_);
  conf_bottom_vec_.push_back(&conf_gt_);
  conf_loss_.Reshape(loss_shape);
  conf_top_vec_.push_back(&conf_loss_);

  // Create the layer that actually computes the confidence loss.
  LayerParameter conf_layer_param;
  vector<int> conf_shape(1, 1);  // placeholder shape, grown per loss type
  switch (conf_loss_type_) {
    case MultiBoxLossParameter_ConfLossType_SOFTMAX:
      conf_layer_param.set_name(this->layer_param_.name() + "_softmax_conf");
      conf_layer_param.set_type("SoftmaxWithLoss");
      conf_layer_param.add_loss_weight(Dtype(1.));
      // The inner layer does not normalize; this layer divides by its own
      // normalizer when it assembles the final loss.
      conf_layer_param.mutable_loss_param()->set_normalization(
          LossParameter_NormalizationMode_NONE);
      conf_layer_param.mutable_softmax_param()->set_axis(1);
      // Placeholder shapes: labels [1], predictions [1, num_classes_].
      conf_gt_.Reshape(conf_shape);
      conf_shape.push_back(num_classes_);
      conf_pred_.Reshape(conf_shape);
      break;
    case MultiBoxLossParameter_ConfLossType_LOGISTIC:
      conf_layer_param.set_name(this->layer_param_.name() + "_logistic_conf");
      conf_layer_param.set_type("SigmoidCrossEntropyLoss");
      conf_layer_param.add_loss_weight(Dtype(1.));
      // Placeholder shape [1, num_classes_] for both labels and predictions.
      conf_shape.push_back(num_classes_);
      conf_gt_.Reshape(conf_shape);
      conf_pred_.Reshape(conf_shape);
      break;
    default:
      LOG(FATAL) << "Unknown confidence loss type.";
  }
  conf_loss_layer_ = LayerRegistry<Dtype>::CreateLayer(conf_layer_param);
  conf_loss_layer_->SetUp(conf_bottom_vec_, conf_top_vec_);
}

// Refreshes the size bookkeeping (num_, num_priors_, num_gt_) from the
// current bottom blobs and verifies that the prediction blobs agree with the
// number of priors.
template <typename Dtype>
void MultiBoxLossLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  LossLayer<Dtype>::Reshape(bottom, top);

  Blob<Dtype>* const loc_blob = bottom[0];
  Blob<Dtype>* const prior_blob = bottom[2];
  Blob<Dtype>* const gt_blob = bottom[3];

  num_ = loc_blob->num();
  // Four values are stored per prior box.
  num_priors_ = prior_blob->height() / 4;
  num_gt_ = gt_blob->height();

  // Location and confidence predictions must cover the same number of items.
  CHECK_EQ(bottom[0]->num(), bottom[1]->num());
  // loc_classes_ is 1 when the location is shared between classes and
  // num_classes_ otherwise; each location prediction has 4 values.
  CHECK_EQ(num_priors_ * loc_classes_ * 4, bottom[0]->channels())
      << "Number of priors must match number of location predictions.";
  // One confidence score per prior and class.
  CHECK_EQ(num_priors_ * num_classes_, bottom[1]->channels())
      << "Number of priors must match number of confidence predictions.";
}
// Location predictions: bottom[0] has dimensions [N*C*1*1]; confidence predictions: bottom[1] has dimensions [N*C*1*1].
// Priors: bottom[2] has dimensions [N*1*2*W]; ground truth: bottom[3] has dimensions [N*1*H*8].
// Computes the MultiBox loss on the CPU.
//
// bottom[0]: location predictions, bottom[1]: confidence predictions,
// bottom[2]: prior boxes, bottom[3]: ground truth.
// top[0]: scalar loss = loc_weight_ * localization loss / normalizer
//                       + confidence loss / normalizer,
// where each term is only added if the corresponding propagate_down flag is
// set.
//
// Steps:
//   1. Parse ground truth, priors and predictions into box structures.
//   2. For every image, match priors/predictions to ground truth and, when
//      negative mining is enabled, pick the highest scoring unmatched priors
//      as negatives. The results are kept in all_match_indices_ and
//      all_neg_indices_.
//   3. Pack the matched location predictions and their encoded ground truth
//      into loc_pred_ / loc_gt_ and run the internal localization loss layer.
//   4. Pack confidence predictions and labels into conf_pred_ / conf_gt_ and
//      run the internal confidence loss layer.
//   5. Combine both losses into top[0].
template <typename Dtype>
void MultiBoxLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  const Dtype* loc_data = bottom[0]->cpu_data();
  const Dtype* conf_data = bottom[1]->cpu_data();
  const Dtype* prior_data = bottom[2]->cpu_data();
  const Dtype* gt_data = bottom[3]->cpu_data();

  // Retrieve all ground truth boxes, keyed by image index.
  /*
  message NormalizedBBox {
    optional float xmin = 1;
    optional float ymin = 2;
    optional float xmax = 3;
    optional float ymax = 4;
    optional int32 label = 5;
    optional bool difficult = 6;
    optional float score = 7;
    optional float size = 8;
  }
  */
  map<int, vector<NormalizedBBox> > all_gt_bboxes;
  GetGroundTruth(gt_data, num_gt_, background_label_id_, use_difficult_gt_,
                 &all_gt_bboxes);

  // Retrieve all prior bboxes. It is same within a batch since we assume all
  // images in a batch are of same dimension.
  vector<NormalizedBBox> prior_bboxes;
  vector<vector<float> > prior_variances;
  GetPriorBBoxes(prior_data, num_priors_, &prior_bboxes, &prior_variances);

  // Retrieve all predictions.
  vector<LabelBBox> all_loc_preds;
  GetLocPredictions(loc_data, num_, num_priors_, loc_classes_, share_location_,
                    &all_loc_preds);

  // Retrieve max scores for each prior. Used in negative mining.
  vector<vector<float> > all_max_scores;
  if (do_neg_mining_) {
    GetMaxConfidenceScores(conf_data, num_, num_priors_, num_classes_,
                           background_label_id_, conf_loss_type_,
                           &all_max_scores);
  }

  // The per-image records below are appended to and later indexed by image
  // number, so they must start empty. Clearing here keeps this pass correct
  // even if an earlier pass left entries behind.
  all_match_indices_.clear();
  all_neg_indices_.clear();

  num_matches_ = 0;
  int num_negs = 0;
  for (int i = 0; i < num_; ++i) {
    map<int, vector<int> > match_indices;
    vector<int> neg_indices;
    // Check if there is ground truth for current image.
    if (all_gt_bboxes.find(i) == all_gt_bboxes.end()) {
      // There is no gt for current image. All predictions are negative.
      all_match_indices_.push_back(match_indices);
      all_neg_indices_.push_back(neg_indices);
      continue;
    }
    // Find match between predictions and ground truth.
    const vector<NormalizedBBox>& gt_bboxes = all_gt_bboxes.find(i)->second;
    map<int, vector<float> > match_overlaps;
    if (!use_prior_for_matching_) {
      // Match the decoded predictions (not the priors) against ground truth.
      for (int c = 0; c < loc_classes_; ++c) {
        int label = share_location_ ? -1 : c;
        if (!share_location_ && label == background_label_id_) {
          // Ignore background loc predictions.
          continue;
        }
        // Decode the prediction into bbox first.
        vector<NormalizedBBox> loc_bboxes;
        DecodeBBoxes(prior_bboxes, prior_variances,
                     code_type_, encode_variance_in_target_,
                     all_loc_preds[i][label], &loc_bboxes);
        MatchBBox(gt_bboxes, loc_bboxes, label, match_type_,
                  overlap_threshold_, &match_indices[label],
                  &match_overlaps[label]);
      }
    } else {
      // Use prior bboxes to match against all ground truth.
      // temp_match_indices[m] is the index of the ground truth matched to
      // the m-th prior, or -1 if it is unmatched; temp_match_overlaps[m] is
      // the corresponding overlap.
      vector<int> temp_match_indices;
      vector<float> temp_match_overlaps;
      const int label = -1;
      MatchBBox(gt_bboxes, prior_bboxes, label, match_type_, overlap_threshold_,
                &temp_match_indices, &temp_match_overlaps);
      if (share_location_) {
        // One set of matches (stored under label -1) serves every class.
        match_indices[label] = temp_match_indices;
        match_overlaps[label] = temp_match_overlaps;
      } else {
        // Get ground truth label for each ground truth bbox.
        vector<int> gt_labels;
        for (int g = 0; g < gt_bboxes.size(); ++g) {
          gt_labels.push_back(gt_bboxes[g].label());
        }
        // Distribute the matching results to different loc_class.
        for (int c = 0; c < loc_classes_; ++c) {
          if (c == background_label_id_) {
            // Ignore background loc predictions.
            continue;
          }
          match_indices[c].resize(temp_match_indices.size(), -1);
          match_overlaps[c] = temp_match_overlaps;
          for (int m = 0; m < temp_match_indices.size(); ++m) {
            if (temp_match_indices[m] != -1) {
              const int gt_idx = temp_match_indices[m];
              CHECK_LT(gt_idx, gt_labels.size());
              // Keep the match only for the class of the matched gt box.
              if (c == gt_labels[gt_idx]) {
                match_indices[c][m] = gt_idx;
              }
            }
          }
        }
      }
    }
    // Record matching statistics.
    for (map<int, vector<int> >::iterator it = match_indices.begin();
         it != match_indices.end(); ++it) {
      const int label = it->first;
      // Get positive indices.
      int num_pos = 0;
      for (int m = 0; m < match_indices[label].size(); ++m) {
        if (match_indices[label][m] != -1) {
          ++num_pos;
        }
      }
      num_matches_ += num_pos;
      if (do_neg_mining_) {
        // Get max scores for all the non-matched priors.
        vector<pair<float, int> > scores_indices;
        int num_neg = 0;
        for (int m = 0; m < match_indices[label].size(); ++m) {
          if (match_indices[label][m] == -1 &&
              match_overlaps[label][m] < neg_overlap_) {
            scores_indices.push_back(std::make_pair(all_max_scores[i][m], m));
            ++num_neg;
          }
        }
        // Pick top num_neg negatives: at most neg_pos_ratio_ negatives per
        // positive, taken in order of decreasing score.
        num_neg = std::min(static_cast<int>(num_pos * neg_pos_ratio_), num_neg);
        std::sort(scores_indices.begin(), scores_indices.end(),
                  SortScorePairDescend<int>);
        for (int n = 0; n < num_neg; ++n) {
          // .second is the prior index that was paired with the score above.
          neg_indices.push_back(scores_indices[n].second);
        }
        num_negs += num_neg;
      }
    }
    all_match_indices_.push_back(match_indices);
    all_neg_indices_.push_back(neg_indices);
  }

  if (num_matches_ >= 1) {
    // Form data to pass on to loc_loss_layer_.
    // loc_pred_ and loc_gt_ were registered in loc_bottom_vec_ by pointer, so
    // reshaping them here is visible to the internal layer.
    vector<int> loc_shape(2);
    loc_shape[0] = 1;
    loc_shape[1] = num_matches_ * 4;
    loc_pred_.Reshape(loc_shape);
    loc_gt_.Reshape(loc_shape);
    Dtype* loc_pred_data = loc_pred_.mutable_cpu_data();
    Dtype* loc_gt_data = loc_gt_.mutable_cpu_data();
    int count = 0;
    for (int i = 0; i < num_; ++i) {
      for (map<int, vector<int> >::iterator it = all_match_indices_[i].begin();
           it != all_match_indices_[i].end(); ++it) {
        const int label = it->first;
        const vector<int>& match_index = it->second;
        CHECK(all_loc_preds[i].find(label) != all_loc_preds[i].end());
        const vector<NormalizedBBox>& loc_pred = all_loc_preds[i][label];
        for (int j = 0; j < match_index.size(); ++j) {
          if (match_index[j] == -1) {
            continue;
          }
          // Store location prediction.
          CHECK_LT(j, loc_pred.size());
          loc_pred_data[count * 4] = loc_pred[j].xmin();
          loc_pred_data[count * 4 + 1] = loc_pred[j].ymin();
          loc_pred_data[count * 4 + 2] = loc_pred[j].xmax();
          loc_pred_data[count * 4 + 3] = loc_pred[j].ymax();
          // Store encoded ground truth.
          const int gt_idx = match_index[j];
          CHECK(all_gt_bboxes.find(i) != all_gt_bboxes.end());
          CHECK_LT(gt_idx, all_gt_bboxes[i].size());
          const NormalizedBBox& gt_bbox = all_gt_bboxes[i][gt_idx];
          NormalizedBBox gt_encode;
          CHECK_LT(j, prior_bboxes.size());
          EncodeBBox(prior_bboxes[j], prior_variances[j], code_type_,
                     encode_variance_in_target_, gt_bbox, &gt_encode);
          loc_gt_data[count * 4] = gt_encode.xmin();
          loc_gt_data[count * 4 + 1] = gt_encode.ymin();
          loc_gt_data[count * 4 + 2] = gt_encode.xmax();
          loc_gt_data[count * 4 + 3] = gt_encode.ymax();
          if (encode_variance_in_target_) {
            // Scale both prediction and target by the prior variance.
            for (int k = 0; k < 4; ++k) {
              CHECK_GT(prior_variances[j][k], 0);
              loc_pred_data[count * 4 + k] /= prior_variances[j][k];
              loc_gt_data[count * 4 + k] /= prior_variances[j][k];
            }
          }
          ++count;
        }
      }
    }
    loc_loss_layer_->Reshape(loc_bottom_vec_, loc_top_vec_);
    loc_loss_layer_->Forward(loc_bottom_vec_, loc_top_vec_);
  } else {
    // Nothing matched: the internal layer is not run, so make sure the value
    // read below is not left over from an earlier pass.
    loc_loss_.mutable_cpu_data()[0] = 0;
  }

  // Form data to pass on to conf_loss_layer_.
  if (do_neg_mining_) {
    // Only the positives and the mined negatives contribute.
    num_conf_ = num_matches_ + num_negs;
  } else {
    // Every prior of every image contributes.
    num_conf_ = num_ * num_priors_;
  }
  if (num_conf_ >= 1) {
    // Reshape the confidence data.
    vector<int> conf_shape;
    if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_SOFTMAX) {
      // Labels: [num_conf_]; predictions: [num_conf_, num_classes_].
      conf_shape.push_back(num_conf_);
      conf_gt_.Reshape(conf_shape);
      conf_shape.push_back(num_classes_);
      conf_pred_.Reshape(conf_shape);
    } else if (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_LOGISTIC) {
      // Labels and predictions: [1, num_conf_, num_classes_].
      conf_shape.push_back(1);
      conf_shape.push_back(num_conf_);
      conf_shape.push_back(num_classes_);
      conf_gt_.Reshape(conf_shape);
      conf_pred_.Reshape(conf_shape);
    } else {
      LOG(FATAL) << "Unknown confidence loss type.";
    }
    if (!do_neg_mining_) {
      // Consider all scores.
      // Share data and diff with bottom[1].
      CHECK_EQ(conf_pred_.count(), bottom[1]->count());
      conf_pred_.ShareData(*(bottom[1]));
    }
    Dtype* conf_pred_data = conf_pred_.mutable_cpu_data();
    Dtype* conf_gt_data = conf_gt_.mutable_cpu_data();
    caffe_set(conf_gt_.count(), Dtype(background_label_id_), conf_gt_data);
    // Number of ground-truth entries written per image when all priors are
    // used: one label per prior for softmax, one value per (prior, class)
    // for logistic (see the indexing in the switch below).
    const int gt_stride_per_image =
        (conf_loss_type_ == MultiBoxLossParameter_ConfLossType_LOGISTIC) ?
        num_priors_ * num_classes_ : num_priors_;
    int count = 0;
    for (int i = 0; i < num_; ++i) {
      if (all_gt_bboxes.find(i) != all_gt_bboxes.end()) {
        // Save matched (positive) bboxes scores and labels.
        const map<int, vector<int> >& match_indices = all_match_indices_[i];
        for (int j = 0; j < num_priors_; ++j) {
          for (map<int, vector<int> >::const_iterator it =
               match_indices.begin(); it != match_indices.end(); ++it) {
            const vector<int>& match_index = it->second;
            CHECK_EQ(match_index.size(), num_priors_);
            if (match_index[j] == -1) {
              continue;
            }
            // In class-agnostic mode every object gets the label right
            // after the background label.
            const int gt_label = map_object_to_agnostic_ ?
                background_label_id_ + 1 :
                all_gt_bboxes[i][match_index[j]].label();
            // With mining the samples are packed densely; otherwise the
            // sample keeps the position of its prior within the image.
            int idx = do_neg_mining_ ? count : j;
            switch (conf_loss_type_) {
              case MultiBoxLossParameter_ConfLossType_SOFTMAX:
                conf_gt_data[idx] = gt_label;
                break;
              case MultiBoxLossParameter_ConfLossType_LOGISTIC:
                conf_gt_data[idx * num_classes_ + gt_label] = 1;
                break;
              default:
                LOG(FATAL) << "Unknown conf loss type.";
            }
            if (do_neg_mining_) {
              // Copy scores for matched bboxes.
              caffe_copy<Dtype>(num_classes_, conf_data + j * num_classes_,
                                conf_pred_data + count * num_classes_);
              ++count;
            }
          }
        }
        if (do_neg_mining_) {
          // Save negative bboxes scores and labels.
          for (int n = 0; n < all_neg_indices_[i].size(); ++n) {
            // Prior index of the n-th mined negative of image i.
            int j = all_neg_indices_[i][n];
            CHECK_LT(j, num_priors_);
            caffe_copy<Dtype>(num_classes_, conf_data + j * num_classes_,
                              conf_pred_data + count * num_classes_);
            switch (conf_loss_type_) {
              case MultiBoxLossParameter_ConfLossType_SOFTMAX:
                conf_gt_data[count] = background_label_id_;
                break;
              case MultiBoxLossParameter_ConfLossType_LOGISTIC:
                conf_gt_data[count * num_classes_ + background_label_id_] = 1;
                break;
              default:
                LOG(FATAL) << "Unknown conf loss type.";
            }
            ++count;
          }
        }
      }
      // Go to next image.
      if (do_neg_mining_) {
        conf_data += bottom[1]->offset(1);
      } else {
        // Advance by exactly what one image occupies in conf_gt_; a fixed
        // stride of num_priors_ would make images overlap for logistic loss.
        conf_gt_data += gt_stride_per_image;
      }
    }
    conf_loss_layer_->Reshape(conf_bottom_vec_, conf_top_vec_);
    conf_loss_layer_->Forward(conf_bottom_vec_, conf_top_vec_);
  } else {
    // No confidence samples: the internal layer is not run, so make sure the
    // value read below is not left over from an earlier pass.
    conf_loss_.mutable_cpu_data()[0] = 0;
  }

  top[0]->mutable_cpu_data()[0] = 0;
  if (this->layer_param_.propagate_down(0)) {
    // TODO(weiliu89): Understand why it needs to divide 2.
    Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
        normalization_, num_, num_priors_, num_matches_);
    top[0]->mutable_cpu_data()[0] +=
        loc_weight_ * loc_loss_.cpu_data()[0] / normalizer;
  }
  if (this->layer_param_.propagate_down(1)) {
    // TODO(weiliu89): Understand why it needs to divide 2.
    Dtype normalizer = LossLayer<Dtype>::GetNormalizer(
        normalization_, num_, num_priors_, num_matches_);
    top[0]->mutable_cpu_data()[0] += conf_loss_.cpu_data()[0] / normalizer;
  }
}


void MatchBBox(const vector<NormalizedBBox>& gt_bboxes,
const vector<NormalizedBBox>& pred_bboxes, const int label,
const MatchType match_type, const float overlap_threshold,
vector<int>* match_indices, vector<float>* match_overlaps) {
int num_pred = pred_bboxes.size();
match_indices->clear();
match_indices->resize(num_pred, -1);//-1表示还没有进行匹配
match_overlaps->clear();
match_overlaps->resize(num_pred, 0.);

int num_gt = 0;//ground truth的个数。
vector<int> gt_indices;
//label是-1表示对比所有的ground truth,label不是-1表示只比较label类型的ground truth
if (label == -1) {
// label -1 means comparing against all ground truth.
num_gt = gt_bboxes.size();
for (int i = 0; i < num_gt; ++i) {
gt_indices.push_back(i);
}
} else {
// Count number of ground truth boxes which has the desired label.
for (int i = 0; i < gt_bboxes.size(); ++i) {
if (gt_bboxes[i].label() == label) {
num_gt++;
gt_indices.push_back(i);
}
}
}
if (num_gt == 0) {
return;
}

// Store the positive overlap between predictions and ground truth.
//match_overlaps存放的是第i个pred bounding box和grounding truth最大的overlap值。
map<int, map<int, float> > overlaps;//存放overlap值。
for (int i = 0; i < num_pred; ++i) {
for (int j = 0; j < num_gt; ++j) {
float overlap = JaccardOverlap(pred_bboxes[i], gt_bboxes[gt_indices[j]]);
if (overlap > 1e-6) {//如果为零就不保存
(*match_overlaps)[i] = std::max((*match_overlaps)[i], overlap);
overlaps[i][j] = overlap;
}
}
}

// Bipartite matching.双向匹配??
vector<int> gt_pool;//grounding turth 池
for (int i = 0; i < num_gt; ++i) {
gt_pool.push_back(i);
}
while (gt_pool.size() > 0) {
// Find the most overlapped gt and cooresponding predictions.
int max_idx = -1;
int max_gt_idx = -1;
float max_overlap = -1;
for (map<int, map<int, float> >::iterator it = overlaps.begin();
it != overlaps.end(); ++it) {
int i = it->first;//i表示第i个pred bounding box
if ((*match_indices)[i] != -1) {
// The prediction already have matched ground truth.
continue;
}//if
for (int p = 0; p < gt_pool.size(); ++p) {
int j = gt_pool[p];
if (it->second.find(j) == it->second.end()) {//第i个pred 与第j个grounding truth没有overlap
// No overlap between the i-th prediction and j-th ground truth.
continue;
}//if
// Find the maximum overlapped pair.
if (it->second[j] > max_overlap) {
// If the prediction has not been matched to any ground truth,
// and the overlap is larger than maximum overlap, update.
max_idx = i;
max_gt_idx = j;
max_overlap = it->second[j];
}//if
}//for int p = 0;
}//for map<int, map<int, float> >
if (max_idx == -1) {
// Cannot find good match.
break;
} else {
CHECK_EQ((*match_indices)[max_idx], -1);
(*match_indices)[max_idx] = gt_indices[max_gt_idx];
(*match_overlaps)[max_idx] = max_overlap;
// Erase the ground truth.
gt_pool.erase(std::find(gt_pool.begin(), gt_pool.end(), max_gt_idx));
}
}

switch (match_type) {
case MultiBoxLossParameter_MatchType_BIPARTITE:
// Already done.
break;
case MultiBoxLossParameter_MatchType_PER_PREDICTION:
// Get most overlaped for the rest prediction bboxes.
for (map<int, map<int, float> >::iterator it = overlaps.begin();
it != overlaps.end(); ++it) {
int i = it->first;
if ((*match_indices)[i] != -1) {
// The prediction already have matched ground truth.
continue;
}
int max_gt_idx = -1;
float max_overlap = -1;
for (int j = 0; j < num_gt; ++j) {
if (it->second.find(j) == it->second.end()) {
// No overlap between the i-th prediction and j-th ground truth.
continue;
}
// Find the maximum overlapped pair.
float overlap = it->second[j];
if (overlap >= overlap_threshold && overlap > max_overlap) {
// If the prediction has not been matched to any ground truth,
// and the overlap is larger than maximum overlap, update.
max_gt_idx = j;
max_overlap = overlap;
}
}
if (max_gt_idx != -1) {
// Found a matched ground truth.
CHECK_EQ((*match_indices)[i], -1);
(*match_indices)[i] = gt_indices[max_gt_idx];
(*match_overlaps)[i] = max_overlap;
}
}
break;
default:
LOG(FATAL) << "Unknown matching type.";
break;
}

return;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: