您的位置:首页 > 运维架构

OpenCV中随机森林的实现与字符识别例子

2014-05-13 15:56 573 查看
之前一篇文章简单介绍了随机森林,并给出了一些随机森林的相关资源:http://blog.csdn.net/holybin/article/details/25653597

在OpenCV中,随机森林由CvRTrees类实现。2.0及以下版本中它定义于\OpenCV2.0\include\opencv\ml.h,2.0以上版本定义于\OpenCV2.4.0\modules\ml\include\opencv2\ml\ml.hpp,实现位于\OpenCV2.4.0\modules\ml\src\rtrees.cpp。不同版本的实现略有不同。

以opencv2.4.0为例子,其定义如下:

// Random Trees (random forest) model of the OpenCV 2.4 ml module, quoted from
// ml.hpp. It derives from CvStatModel, the common base class of the ml models.
// NOTE(review): this is a library declaration; do not reorder members here,
// the virtual-function order must match the compiled OpenCV library.
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
CV_WRAP CvRTrees();
virtual ~CvRTrees();
// C-API training entry point. tflag gives the sample layout of trainData
// (e.g. CV_ROW_SAMPLE); the optional masks/index matrices may be 0.
virtual bool train( const CvMat* trainData, int tflag,
const CvMat* responses, const CvMat* varIdx=0,
const CvMat* sampleIdx=0, const CvMat* varType=0,
const CvMat* missingDataMask=0,
CvRTParams params=CvRTParams() );

// Training from a CvMLData container.
virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
// Prediction for a single sample (optionally with a missing-value mask).
virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
// NOTE(review): presumably a class probability rather than a label — confirm
// in the OpenCV docs before relying on it.
virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

// cv::Mat counterparts of the C-API methods above (excluded from SWIG builds).
#ifndef SWIG
CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
const cv::Mat& missingDataMask=cv::Mat(),
CvRTParams params=CvRTParams() );
CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
CV_WRAP virtual cv::Mat getVarImportance();
#endif

CV_WRAP virtual void clear();

// Per-variable importance as a CvMat; callers should be prepared for a NULL
// result (presumably when importance was not computed — TODO confirm).
virtual const CvMat* get_var_importance();
// Proximity of two samples as measured by the forest.
virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

virtual float get_train_error();

// Serialization through OpenCV's CvFileStorage (XML/YAML persistence).
virtual void read( CvFileStorage* fs, CvFileNode* node );
virtual void write( CvFileStorage* fs, const char* name ) const;

CvMat* get_active_var_mask();
CvRNG* get_rng();

// Number of trees in the forest and access to the i-th tree.
int get_tree_count() const;
CvForestTree* get_tree(int i) const;

protected:
virtual std::string getName() const;

virtual bool grow_forest( const CvTermCriteria term_crit );

// array of the trees of the forest
// (presumably ntrees entries — TODO confirm against rtrees.cpp)
CvForestTree** trees;
CvDTreeTrainData* data;
int ntrees;
int nclasses;
double oob_error;
CvMat* var_importance;
int nsamples;

cv::RNG* rng;
CvMat* active_var_mask;
};


这里的CvStatModel是OpenCV机器学习模块(The Machine Learning Library,MLL)中统计模型的基类,KNN、Bayes、SVM等诸多实现都派生自该类,参考OpenCV文档:http://docs.opencv.org/modules/ml/doc/ml.html 。另附一个SVM的使用例子:http://docs.opencv.org/doc/tutorials/ml/non_linear_svms/non_linear_svms.html#nonlinearsvms 。

以下使用CvRTrees类对字符数据进行分类。该例子即OpenCV附带的例子“\OpenCV2.4.0\samples\cpp\letter_recog.cpp”,字符数据文件“\OpenCV2.4.0\samples\cpp\letter-recognition.data”来源于UCI机器学习数据库(该数据另有一个csv格式的版本;UCI网站上还有很多很好的机器学习数据集)。
在本例子中,字符数据“letter-recognition.data”共有20000个字母样本。每行第1列是类别标签(字母本身),其后16列是该字母的16维特征:
 1. lettr   capital letter (26 values from A to Z) —— 类别标签,不属于特征
 2. x-box   horizontal position of box (integer)
 3. y-box   vertical position of box (integer)
 4. width   width of box (integer)
 5. high    height of box (integer)
 6. onpix   total # on pixels (integer)
 7. x-bar   mean x of on pixels in box (integer)
 8. y-bar   mean y of on pixels in box (integer)
 9. x2bar   mean x variance (integer)
10. y2bar   mean y variance (integer)
11. xybar   mean x y correlation (integer)
12. x2ybr   mean of x * x * y (integer)
13. xy2br   mean of x * y * y (integer)
14. x-ege   mean edge count left to right (integer)
15. xegvy   correlation of x-ege with y (integer)
16. y-ege   mean edge count bottom to top (integer)
17. yegvx   correlation of y-ege with x (integer)
程序中使用前16000个进行训练,后4000个进行测试。

#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>

#include "opencv2/core/core_c.h"
#include "opencv2/ml/ml.hpp"
/*
Modified from F:\Program Files\OpenCV2.4.0\samples\cpp\letter_recog.cpp
Only RF method reserved.
*/

// Prints a description of the sample and its command-line usage to stdout.
//
// This trimmed-down version of OpenCV's letter_recog.cpp only keeps the
// Random Trees classifier, so the text no longer advertises the boosting,
// MLP, kNN, naive Bayes and SVM options that were removed.
void help()
{
    printf("\nThe sample demonstrates how to train Random Trees classifier\n"
        "using the provided dataset.\n"
        "\n"
        "We use the sample database letter-recognition.data\n"
        "from UCI Repository, here is the link:\n"
        "\n"
        "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
        "UCI Repository of machine learning databases\n"
        "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
        "Irvine, CA: University of California, Department of Information and Computer Science.\n"
        "\n"
        "The dataset consists of 20000 feature vectors along with the\n"
        "responses - capital latin letters A..Z.\n"
        "The first 16000 samples are used for training\n"
        "and the remaining 4000 - to test the classifier.\n"
        "======================================================\n");
    // The last usage line must not end with a "\" continuation: nothing follows it.
    printf("\nThis is letter recognition sample.\n"
        "The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
        "  [-save <output XML file for the classifier>] \\\n"
        "  [-load <XML file with the pre-trained classifier>]\n");
}

// This function reads data and responses from the file <filename>
// Reads a letter-recognition style data file <filename> into two matrices.
//
// Every line is expected to hold a one-character response in column 0,
// a separator, and then `var_count` comma-separated numbers, e.g.
//   T,2,8,3,5,1,8,13,0,6,6,10,8,0,8,0,8
// Reading stops at the first line that has no comma or does not contain
// `var_count` numbers; that line is not stored.
//
// On success returns 1 and sets
//   *data      - new (nsamples x var_count) CV_32F matrix of features,
//   *responses - new (nsamples x 1) CV_32F matrix holding the character code
//                of each sample's response character.
// The caller releases both with cvReleaseMat().
// Returns 0, without touching *data / *responses, when var_count is not
// positive, the file cannot be opened, or no valid sample was read.
static int read_num_class_data( const char* filename, int var_count, CvMat** data, CvMat** responses )
{
    const int M = 1024;
    char buf[M+2];

    if( var_count <= 0 )
        return 0;

    FILE* f = fopen( filename, "rt" );
    if( !f )
        return 0;

    // One record = [response, feature_1 .. feature_var_count].
    std::vector<float> record( var_count + 1 );
    CvMemStorage* storage = cvCreateMemStorage();
    CvSeq* seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );

    while( fgets( buf, M, f ) && strchr( buf, ',' ) )
    {
        record[0] = buf[0];

        // Features start after "<letter>,"; never step past the terminator.
        const char* ptr = buf + 1;
        if( *ptr != '\0' )
            ++ptr;

        int i = 1;
        for( ; i <= var_count; i++ )
        {
            int n = 0;
            // Checking the conversion count stops at a short or malformed
            // line instead of advancing past its end and storing garbage.
            if( sscanf( ptr, "%f%n", &record[i], &n ) != 1 )
                break;
            ptr += n;
            if( *ptr != '\0' )
                ++ptr; // skip the separator that follows the number
        }
        if( i <= var_count )
            break;
        cvSeqPush( seq, &record[0] );
    }
    fclose( f );

    const int nsamples = seq->total;
    // cvCreateMat rejects zero rows, so an empty file is reported as a failure.
    if( nsamples > 0 )
    {
        *data = cvCreateMat( nsamples, var_count, CV_32F );
        *responses = cvCreateMat( nsamples, 1, CV_32F );

        CvSeqReader reader;
        cvStartReadSeq( seq, &reader );
        for( int i = 0; i < nsamples; i++ )
        {
            const float* rec = reinterpret_cast<const float*>( reader.ptr );
            float* row = (*data)->data.fl + var_count*i;
            for( int j = 0; j < var_count; j++ )
                row[j] = rec[j+1];
            (*responses)->data.fl[i] = rec[0];
            CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
        }
    }

    cvReleaseMemStorage( &storage );
    return nsamples > 0 ? 1 : 0;
}

static int build_rtrees_classifier( char* data_filename,	char* filename_to_save, char* filename_to_load )
{
CvMat* data = 0;
CvMat* responses = 0;
CvMat* var_type = 0;
CvMat* sample_idx = 0;

int ok = read_num_class_data( data_filename, 16, &data, &responses );
int nsamples_all = 0, ntrain_samples = 0;
int i = 0;
double train_hr = 0, test_hr = 0;
CvRTrees forest;
CvMat* var_importance = 0;

if( !ok )
{
printf( "Could not read the database %s\n", data_filename );
return -1;
}

printf( "The database %s is loaded.\n", data_filename );
nsamples_all = data->rows;
ntrain_samples = (int)(nsamples_all*0.8);

// Create or load Random Trees classifier
if( filename_to_load )
{
// load classifier from the specified file
forest.load( filename_to_load );
ntrain_samples = 0;
if( forest.get_tree_count() == 0 )
{
printf( "Could not read the classifier %s\n", filename_to_load );
return -1;
}
printf( "The classifier %s is loaded.\n", data_filename );
}
else
{
// create classifier by using <data> and <responses>
printf( "Training the classifier ...\n");

// 1. create type mask
var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );
cvSetReal1D( var_type, data->cols, CV_VAR_CATEGORICAL );

// 2. create sample_idx
sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );
{
CvMat mat;
cvGetCols( sample_idx, &mat, 0, ntrain_samples );
cvSet( &mat, cvRealScalar(1) );

cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );
cvSetZero( &mat );
}

// 3. train classifier
forest.train( data, CV_ROW_SAMPLE, responses, 0, sample_idx, var_type, 0,
CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
printf( "\n");
}

// compute prediction error on train and test data
for( i = 0; i < nsamples_all; i++ )
{
double r;
CvMat sample;
cvGetRow( data, &sample, i );

r = forest.predict( &sample );
r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;

if( i < ntrain_samples )
train_hr += r;
else
test_hr += r;
}

test_hr /= (double)(nsamples_all-ntrain_samples);
train_hr /= (double)ntrain_samples;
printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
train_hr*100., test_hr*100. );

printf( "Number of trees: %d\n", forest.get_tree_count() );

// Print variable importance
var_importance = (CvMat*)forest.get_var_importance();
if( var_importance )
{
double rt_imp_sum = cvSum( var_importance ).val[0];
printf("var#\timportance (in %%):\n");
for( i = 0; i < var_importance->cols; i++ )
printf( "%-2d\t%-4.1f\n", i,
100.f*var_importance->data.fl[i]/rt_imp_sum);
}

//Print some proximitites
printf( "Proximities between some samples corresponding to the letter 'T':\n" );
{
CvMat sample1, sample2;
const int pairs[][2] = {{0,103}, {0,106}, {106,103}, {-1,-1}};

for( i = 0; pairs[i][0] >= 0; i++ )
{
cvGetRow( data, &sample1, pairs[i][0] );
cvGetRow( data, &sample2, pairs[i][1] );
printf( "proximity(%d,%d) = %.1f%%\n", pairs[i][0], pairs[i][1],
forest.get_proximity( &sample1, &sample2 )*100. );
}
}

// Save Random Trees classifier to file if needed
if( filename_to_save )
forest.save( filename_to_save );

cvReleaseMat( &sample_idx );
cvReleaseMat( &var_type );
cvReleaseMat( &data );
cvReleaseMat( &responses );

return 0;
}

// Command-line entry point. Recognised options, each taking one argument:
//   -data <file>  path to letter-recognition.data (default path below)
//   -save <file>  write the trained classifier to <file>
//   -load <file>  load a pre-trained classifier instead of training one
// An unknown option, an option missing its argument, or a failure while
// building the classifier prints the usage text and exits with status 1.
// The boosting/MLP/kNN/naive-Bayes/SVM branches of the original OpenCV
// sample were removed; only Random Trees is supported.
int main( int argc, char *argv[] )
{
    char* filename_to_save = 0;
    char* filename_to_load = 0;
    char default_data_filename[] = "F:\\Program Files\\OpenCV2.4.0\\samples\\cpp\\letter-recognition.data";
    char* data_filename = default_data_filename;

    int i;
    for( i = 1; i < argc; i++ )
    {
        char** target = 0;
        if( strcmp(argv[i],"-data") == 0 )        // flag "-data letter-recognition.data"
            target = &data_filename;
        else if( strcmp(argv[i],"-save") == 0 )   // flag "-save filename.xml"
            target = &filename_to_save;
        else if( strcmp(argv[i],"-load") == 0 )   // flag "-load filename.xml"
            target = &filename_to_load;
        else
            break;                                // unknown option

        // A trailing flag has no value; reading argv[i+1] would yield NULL.
        if( i + 1 >= argc )
            break;
        *target = argv[++i];
    }

    if( i < argc ||
        build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) < 0 )
    {
        help();
        return 1;
    }
    return 0;
}


运行结果:



另参考:使用CvRTrees类对手写体数据作分类
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息