train_cascade Source Code Reading: Cascade Training

2015-01-23 09:26
In the main function, the most eye-catching statement is this call:

classifier.train( cascadeDirName,
                  vecName,
                  bgName,
                  numPos, numNeg,
                  precalcValBufSize, precalcIdxBufSize,
                  numStages,
                  cascadeParams,
                  *featureParams[cascadeParams.featureType],
                  stageParams,
                  baseFormatSave );


Its implementation is as follows:

bool CvCascadeClassifier::train( const string _cascadeDirName,
                                 const string _posFilename,
                                 const string _negFilename,
                                 int _numPos, int _numNeg,
                                 int _precalcValBufSize, int _precalcIdxBufSize,
                                 int _numStages,
                                 const CvCascadeParams& _cascadeParams,
                                 const CvFeatureParams& _featureParams,
                                 const CvCascadeBoostParams& _stageParams,
                                 bool baseFormatSave )
{
    // Start recording clock ticks for training time output
    const clock_t begin_time = clock();

    // Validate the input data; skipped here, and likewise below only the key statements are kept
    ……

    // Determine how many stages were loaded from a previous run, and report it
    int startNumStages = (int)stageClassifiers.size();
    if ( startNumStages > 1 )
        cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
    else if ( startNumStages == 1)
        cout << endl << "Stage 0 is loaded" << endl;

    // Compute the required false-alarm rate at the leaf nodes
    double requiredLeafFARate = pow( (double)stageParams->maxFalseAlarm, (double)numStages ) /
                                (double)stageParams->max_depth; // maximum tree depth, 1 by default
    double tempLeafFARate;

    for( int i = startNumStages; i < numStages; i++ )
    {
        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
        cout << "<BEGIN" << endl;

        // The training set for this stage cannot be filled: stop
        if ( !updateTrainingSet( requiredLeafFARate, tempLeafFARate ) )
        {
            cout << "Train dataset for temp stage can not be filled. "
                    "Branch training terminated." << endl;
            break;
        }
        // The required leaf false-alarm rate has already been reached: stop
        if( tempLeafFARate <= requiredLeafFARate )
        {
            cout << "Required leaf false alarm rate achieved. "
                    "Branch training terminated." << endl;
            break;
        }
        // Train the current stage
        CvCascadeBoost* tempStage = new CvCascadeBoost;
        bool isStageTrained = tempStage->train( (CvFeatureEvaluator*)featureEvaluator,
                                                curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
                                                *((CvCascadeBoostParams*)stageParams) );
        cout << "END>" << endl;
        // Training this stage failed: stop
        if(!isStageTrained)
            break;
        // Success: append this stage
        stageClassifiers.push_back( tempStage );

        // save params
        ……

        // Output training time up till now
        ……
    }
    // Reached when the for loop above exits via break without training any stage
    if(stageClassifiers.size() == 0)
    {
        cout << "Cascade classifier can't be trained. Check the used training parameters." << endl;
        return false;
    }

    // Save the cascade classifier in XML format
    save( dirName + CC_CASCADE_FILENAME, baseFormatSave );
    return true;
}
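To get a feel for the leaf target: with the usual opencv_traincascade defaults (maxFalseAlarm = 0.5, numStages = 20, max_depth = 1, assumed here rather than read from any particular run), the required leaf false-alarm rate works out to 0.5^20, roughly 9.5e-7. A minimal sketch of the calculation:

#include <cmath>
#include <cstdio>

int main()
{
    // Assumed defaults, not values read from a specific training run
    double maxFalseAlarm = 0.5;
    int    numStages     = 20;
    int    maxDepth      = 1;
    double requiredLeafFARate = std::pow( maxFalseAlarm, (double)numStages ) / maxDepth;
    std::printf( "required leaf FA rate: %g\n", requiredLeafFARate ); // ~9.53674e-07
    return 0;
}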
Next, look at updateTrainingSet: before each stage is trained, the sample set is refreshed.

bool CvCascadeClassifier::updateTrainingSet( double minimumAcceptanceRatio, double& acceptanceRatio )
{
    int64 posConsumed = 0, negConsumed = 0;
    imgReader.restart();
    // Collect positive samples
    int posCount = fillPassedSamples( 0, numPos, true, 0, posConsumed );
    if( !posCount )
        return false;
    cout << "POS count : consumed   " << posCount << " : " << (int)posConsumed << endl;

    // Number of negatives needed: numNeg scaled by the fraction of positives actually obtained,
    // which keeps the positive/negative proportion of the training set constant
    int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos );
    // apply only a fraction of negative samples. double is required since overflow is possible
    // Collect negative samples
    int negCount = fillPassedSamples( posCount, proNumNeg, false,
                                      minimumAcceptanceRatio, negConsumed );
    if ( !negCount )
        return false;

    curNumSamples = posCount + negCount;
    // Compute acceptanceRatio, i.e. FP/(FP+TN)
    acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );
    cout << "NEG count : acceptanceRatio    " << negCount << " : " << acceptanceRatio << endl;
    return true;
}
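As a rough numeric illustration of this proportional scaling (assumed numbers, not from a real run): if numPos = 1000 and numNeg = 600 but only 980 positives survive the previous stages, the negative quota becomes round(600 * 980 / 1000) = 588, so the positive/negative ratio of the stage training set is preserved.

#include <cmath>
#include <cstdio>

int main()
{
    // Assumed counts, for illustration only
    int numPos   = 1000, numNeg = 600;
    int posCount = 980;   // positives that actually passed the previous stages
    int proNumNeg = (int)std::lround( (double)numNeg * posCount / numPos );
    std::printf( "proNumNeg = %d\n", proNumNeg ); // 588
    return 0;
}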
The most important piece of the code above is the fillPassedSamples function.

int CvCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive,
                                            double minimumAcceptanceRatio, int64& consumed )
{
    int getcount = 0;
    Mat img(cascadeParams.winSize, CV_8UC1);
    for( int i = first; i < first + count; i++ )
    {
        for( ; ; )
        {
            if( consumed != 0 &&
                ((double)getcount+1)/(double)(int64)consumed <= minimumAcceptanceRatio )
                return getcount;
            // Fetch an image of the requested class
            bool isGetImg = isPositive ? imgReader.getPos( img ) : imgReader.getNeg( img );
            if( !isGetImg )
                return getcount;
            consumed++;
            // Store the image and its class label in the data matrix
            featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
            // Keep the sample only if the current cascade predicts it as positive; when filling
            // negatives, this means only samples misclassified as positive are kept.
            if( predict( i ) == 1.0F )
            {
                getcount++;
                printf("%s current samples: %d\r", isPositive ? "POS":"NEG", getcount);
                break;
            }
        }
    }
    return getcount;
}
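For negatives, the predict( i ) == 1.0F test turns this into hard-negative mining: only background windows that the cascade trained so far still misclassifies as positive are kept. If that cascade currently passes background windows with probability p, about proNumNeg / p windows have to be consumed on average, and the acceptanceRatio reported by updateTrainingSet comes out close to p. A back-of-the-envelope sketch with assumed numbers:

#include <cstdio>

int main()
{
    // Assumed values, for illustration only: three earlier stages,
    // each letting through about half of the background windows
    double p         = 0.5 * 0.5 * 0.5;   // ~ false-alarm rate of the cascade so far
    int    proNumNeg = 1000;              // negatives wanted for the next stage
    double expectedConsumed = proNumNeg / p;
    std::printf( "expected windows consumed : %.0f\n", expectedConsumed ); // ~8000
    std::printf( "expected acceptanceRatio  : %.4f\n", p );                // ~0.125
    return 0;
}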


The actual prediction: the returned value is the sum of the predictions of all the weak classifiers.

float CvCascadeBoost::predict( int sampleIdx, bool returnSum ) const
{
    CV_Assert( weak );
    double sum = 0;
    CvSeqReader reader;
    cvStartReadSeq( weak, &reader );
    cvSetSeqReaderPos( &reader, 0 );
    for( int i = 0; i < weak->total; i++ )
    {
        CvBoostTree* wtree;
        CV_READ_SEQ_ELEM( wtree, reader );
        sum += ((CvCascadeBoostTree*)wtree)->predict(sampleIdx)->value;
    }
    if( !returnSum )
        sum = sum < threshold - CV_THRESHOLD_EPS ? 0.0 : 1.0;
    return (float)sum;
}
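When returnSum is false the stage therefore acts as a hard gate: a window survives the stage only if its boosted sum clears the stage threshold. A minimal sketch of that decision with made-up numbers (the real threshold is chosen during training, see isErrDesired below):

#include <cstdio>

int main()
{
    // Made-up values, for illustration only
    double sum       = 2.31;   // sum of the weak-tree outputs for one sample
    double threshold = 1.87;   // stage threshold set during training
    double eps       = 1e-5;   // small tolerance playing the role of CV_THRESHOLD_EPS
    bool passesStage = sum >= threshold - eps;
    std::printf( "stage prediction: %.1f\n", passesStage ? 1.0 : 0.0 );
    return 0;
}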
Now for the training part of a stage:

bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,
                            int _numSamples,
                            int _precalcValBufSize, int _precalcIdxBufSize,
                            const CvCascadeBoostParams& _params )
{
    bool isTrained = false;
    CV_Assert( !data );
    clear();
    data = new CvCascadeBoostTrainData( _featureEvaluator, _numSamples,
                                        _precalcValBufSize, _precalcIdxBufSize, _params );
    CvMemStorage *storage = cvCreateMemStorage();
    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
    storage = 0;

    set_params( _params );
    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();
    // Initialize the sample weights
    update_weights( 0 );

    cout << "+----+---------+---------+" << endl;
    cout << "|  N |    HR   |    FA   |" << endl;
    cout << "+----+---------+---------+" << endl;

    do
    {
        // Train one weak tree
        CvCascadeBoostTree* tree = new CvCascadeBoostTree;
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }
        cvSeqPush( weak, &tree );
        // Update the sample weights
        update_weights( tree );
        trim_weights();
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }
    while( !isErrDesired() && (weak->total < params.weak_count) );
    // The loop terminates once the false-alarm requirement is met or the maximum number of weak classifiers is reached

    if(weak->total > 0)
    {
        data->is_classifier = true;
        data->free_train_data();
        isTrained = true;
    }
    else
        clear();

    return isTrained;
}
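The interesting part of the stopping condition is isErrDesired(). Roughly speaking (a paraphrase of my reading of the OpenCV code, not the code itself), it picks the stage threshold so that at least minHitRate of the positive training samples still pass, measures the false-alarm rate of the negatives against that threshold, prints the HR/FA row of the table above, and returns true once the false alarm drops to maxFalseAlarm. A simplified sketch of that test, with names and layout of my own:

#include <algorithm>
#include <cstdio>
#include <vector>

// Simplified sketch of the stopping test; not the OpenCV implementation.
bool errDesiredSketch( std::vector<float> posSums, const std::vector<float>& negSums,
                       float minHitRate, float maxFalseAlarm, float& threshold )
{
    // Choose the threshold at the (1 - minHitRate) quantile of the positive sums,
    // so that at least minHitRate of the positives are still accepted.
    std::sort( posSums.begin(), posSums.end() );
    int thresholdIdx = (int)( (1.0f - minHitRate) * posSums.size() );
    threshold = posSums[thresholdIdx];

    // False-alarm rate: fraction of negatives whose sum still clears the threshold.
    int falsePositives = 0;
    for( float s : negSums )
        if( s >= threshold )
            falsePositives++;
    float falseAlarm = (float)falsePositives / (float)negSums.size();
    std::printf( "HR >= %.3f  FA = %.3f\n", minHitRate, falseAlarm );
    return falseAlarm <= maxFalseAlarm;
}

int main()
{
    // Made-up boosted sums, for illustration only
    std::vector<float> pos = { 0.9f, 1.2f, 1.5f, 2.0f, 2.4f };
    std::vector<float> neg = { -1.0f, 0.3f, 1.1f, 1.6f, 2.1f };
    float thr = 0.f;
    bool ok = errDesiredSketch( pos, neg, 0.8f, 0.5f, thr );
    std::printf( "threshold = %.2f, desired error reached: %s\n", thr, ok ? "yes" : "no" );
    return 0;
}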


With that, one stage of the cascade has been trained.