Random Forests in OpenCV: Implementation and a Letter Recognition Example
2014-05-13 15:56
An earlier post briefly introduced random forests and listed some related resources: http://blog.csdn.net/holybin/article/details/25653597
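In OpenCV, random forests are implemented by the CvRTrees class. In version 2.0 and earlier it is declared in \OpenCV2.0\include\opencv\ml.h. In later versions it is declared in \OpenCV2.4.0\modules\ml\include\opencv2\ml\ml.hpp and implemented in \OpenCV2.4.0\modules\ml\src\rtrees.cpp. The implementation differs slightly between versions.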
Using OpenCV 2.4.0 as an example, the class is declared as follows:
```cpp
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                                const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                                const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                                const cv::Mat& missingDataMask=cv::Mat(),
                                CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
                                 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:
    virtual std::string getName() const;

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error;
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};
```
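CvStatModel is the base class of OpenCV's machine learning module (the Machine Learning Library, MLL). Many of its classifiers, including KNN, Bayes and SVM, derive from it. See the OpenCV documentation: docs.opencv.org/modules/ml/doc/ml.html. For an SVM usage example, see: docs.opencv.org/doc/tutorials/ml/non_linear_svms/non_linear_svms.html#nonlinearsvms.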
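In the declaration above, `predict()` returns one label per sample. For classification, CvRTrees in OpenCV 2.4 combines its trees by majority vote: each tree predicts a class, and the class with the most votes is returned. The sketch below isolates only that voting step. `majority_vote` is a hypothetical helper, not part of OpenCV, and it does not reproduce OpenCV's internal code. Its tie-breaking rule (smallest label wins) is a choice made for this sketch.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical helper (not an OpenCV function): given the class label
// predicted by each tree, return the label that received the most votes.
// Ties go to the smallest label because std::map iterates keys in ascending order.
int majority_vote(const std::vector<int>& tree_predictions)
{
    std::map<int, int> votes;              // label -> number of trees voting for it
    for (int label : tree_predictions)
        ++votes[label];

    int best_label = -1;                   // returned unchanged if no tree voted
    int best_count = 0;
    for (const auto& kv : votes)
    {
        if (kv.second > best_count)        // strict '>' keeps the smaller label on ties
        {
            best_label = kv.first;
            best_count = kv.second;
        }
    }
    return best_label;
}
```

For example, if three trees vote 'A', 'B', 'A', the forest's answer is 'A'.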
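The following example uses the CvRTrees class to classify letter data. It is the sample that ships with OpenCV, "\OpenCV2.4.0\samples\cpp\letter_recog.cpp". The data file "\OpenCV2.4.0\samples\cpp\letter-recognition.data" comes from the UCI Machine Learning Repository, which also provides the data in CSV format and hosts many other good machine learning datasets.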
In this example, "letter-recognition.data" contains 20000 letter samples. Each line gives the class label followed by 16 features, for 17 attributes in total:
1. lettr: capital letter (26 values from A to Z)
2. x-box: horizontal position of box (integer)
3. y-box: vertical position of box (integer)
4. width: width of box (integer)
5. high: height of box (integer)
6. onpix: total # on pixels (integer)
7. x-bar: mean x of on pixels in box (integer)
8. y-bar: mean y of on pixels in box (integer)
9. x2bar: mean x variance (integer)
10. y2bar: mean y variance (integer)
11. xybar: mean x y correlation (integer)
12. x2ybr: mean of x * x * y (integer)
13. xy2br: mean of x * y * y (integer)
14. x-ege: mean edge count left to right (integer)
15. xegvy: correlation of x-ege with y (integer)
16. y-ege: mean edge count bottom to top (integer)
17. yegvx: correlation of y-ege with x (integer)
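Each line of the file stores these 17 attributes as comma-separated values, with the letter first. The sample program parses lines with `fgets`/`sscanf`. The sketch below is an alternative using `std::string` and `std::istringstream`. The `LetterSample` struct and the `parse_letter_line` function are hypothetical names and do not appear in the sample.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical record type for one row of letter-recognition.data.
struct LetterSample
{
    char label = 0;                 // attribute 1: capital letter 'A'..'Z'
    std::vector<float> features;    // attributes 2..17: 16 integer-valued features
};

// Parse one comma-separated line such as "T,2,8,3,...".
// Returns false unless the line holds one capital letter followed by exactly 16 numbers.
bool parse_letter_line(const std::string& line, LetterSample& out)
{
    std::istringstream in(line);
    std::string field;

    // First field: the class label, a single capital letter.
    if (!std::getline(in, field, ',') || field.size() != 1 ||
        field[0] < 'A' || field[0] > 'Z')
        return false;
    out.label = field[0];

    // Remaining fields: the numeric features.
    out.features.clear();
    while (std::getline(in, field, ','))
    {
        try { out.features.push_back(std::stof(field)); }
        catch (...) { return false; }   // field with no leading number
    }
    return out.features.size() == 16;
}
```

For example, parsing the line "T,2,8,3,5,1,8,13,0,6,6,10,8,0,8,0,8" (values chosen for illustration) yields the label 'T' and 16 features. Note that the sample program itself stores the label as the letter's character code in a float response.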
The program uses the first 16000 samples (80%) for training and the remaining 4000 for testing.
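The program's evaluation loop counts a prediction as a hit when it equals the true response (within `FLT_EPSILON`). It then divides the hits by the size of the training part and the test part separately. The sketch below reproduces that bookkeeping on plain vectors instead of `CvMat`. `recognition_rates` and `HitRates` are hypothetical names. Returning 0 for an empty part is a choice made for the sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cmath>
#include <cstddef>
#include <vector>

// Fraction of correct predictions on each part of the data.
struct HitRates { double train = 0, test = 0; };

// Samples [0, ntrain) are training data and the rest are test data, as in the sample.
// Responses are floats (e.g. the letter's character code), so equality is
// checked with the same FLT_EPSILON tolerance the sample uses.
HitRates recognition_rates(const std::vector<float>& predicted,
                           const std::vector<float>& truth,
                           std::size_t ntrain)
{
    HitRates r;
    const std::size_t n = std::min(predicted.size(), truth.size());
    for (std::size_t i = 0; i < n; ++i)
    {
        double hit = std::fabs((double)predicted[i] - truth[i]) <= FLT_EPSILON ? 1 : 0;
        if (i < ntrain) r.train += hit; else r.test += hit;
    }
    if (ntrain > 0) r.train /= (double)std::min(ntrain, n);  // avoid dividing by zero
    if (n > ntrain) r.test /= (double)(n - ntrain);
    return r;
}
```

With the 20000 samples of this dataset, `ntrain = (int)(20000*0.8)` is 16000, which gives the 16000/4000 split described above.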
```cpp
#include "opencv2/core/core_c.h"
#include "opencv2/ml/ml.hpp"
#include <cfloat>    // FLT_EPSILON
#include <cmath>     // fabs
#include <cstdio>
#include <cstring>   // strchr, strcmp
#include <vector>

/*
Modified from F:\Program Files\OpenCV2.4.0\samples\cpp\letter_recog.cpp
Only the Random Trees (RF) method is kept.
*/

void help()
{
    printf("\nThe sample demonstrates how to train Random Trees classifier\n"
    "(or Boosting classifier, or MLP, or Knearest, or Nbayes, or Support Vector Machines - see main()) using the provided dataset.\n"
    "\n"
    "We use the sample database letter-recognition.data\n"
    "from UCI Repository, here is the link:\n"
    "\n"
    "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
    "UCI Repository of machine learning databases\n"
    "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
    "Irvine, CA: University of California, Department of Information and Computer Science.\n"
    "\n"
    "The dataset consists of 20000 feature vectors along with the\n"
    "responses - capital latin letters A..Z.\n"
    "The first 16000 (10000 for boosting)) samples are used for training\n"
    "and the remaining 4000 (10000 for boosting) - to test the classifier.\n"
    "======================================================\n");
    printf("\nThis is letter recognition sample.\n"
    "The usage: letter_recog [-data <path to letter-recognition.data>] \\\n"
    "  [-save <output XML file for the classifier>] \\\n"
    "  [-load <XML file with the pre-trained classifier>] \\\n"
    //"  [-boost|-mlp|-knearest|-nbayes|-svm] # to use boost/mlp/knearest/SVM classifier instead of default Random Trees\n"
    );
}

// This function reads data and responses from the file <filename>
static int read_num_class_data( const char* filename, int var_count,
                                CvMat** data, CvMat** responses )
{
    const int M = 1024;
    FILE* f = fopen( filename, "rt" );
    CvMemStorage* storage;
    CvSeq* seq;
    char buf[M+2];
    float* el_ptr;
    CvSeqReader reader;
    int i, j;

    if( !f )
        return 0;

    el_ptr = new float[var_count+1];
    storage = cvCreateMemStorage();
    seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );

    for(;;)
    {
        char* ptr;
        if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
            break;
        el_ptr[0] = buf[0];              // response: character code of the letter
        ptr = buf+2;                     // skip "X," to reach the first feature
        for( i = 1; i <= var_count; i++ )
        {
            int n = 0;
            sscanf( ptr, "%f%n", el_ptr + i, &n );
            ptr += n + 1;                // step over the number and the following comma
        }
        if( i <= var_count )
            break;
        cvSeqPush( seq, el_ptr );
    }
    fclose(f);

    *data = cvCreateMat( seq->total, var_count, CV_32F );
    *responses = cvCreateMat( seq->total, 1, CV_32F );

    cvStartReadSeq( seq, &reader );

    // copy the sequence into the data matrix (features) and the response vector (labels)
    for( i = 0; i < seq->total; i++ )
    {
        const float* sdata = (float*)reader.ptr + 1;
        float* ddata = data[0]->data.fl + var_count*i;
        float* dr = responses[0]->data.fl + i;

        for( j = 0; j < var_count; j++ )
            ddata[j] = sdata[j];
        *dr = sdata[-1];
        CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    }

    cvReleaseMemStorage( &storage );
    delete[] el_ptr;
    return 1;
}

static int build_rtrees_classifier( char* data_filename,
    char* filename_to_save, char* filename_to_load )
{
    CvMat* data = 0;
    CvMat* responses = 0;
    CvMat* var_type = 0;
    CvMat* sample_idx = 0;

    int ok = read_num_class_data( data_filename, 16, &data, &responses );
    int nsamples_all = 0, ntrain_samples = 0;
    int i = 0;
    double train_hr = 0, test_hr = 0;
    CvRTrees forest;
    CvMat* var_importance = 0;

    if( !ok )
    {
        printf( "Could not read the database %s\n", data_filename );
        return -1;
    }

    printf( "The database %s is loaded.\n", data_filename );
    nsamples_all = data->rows;
    ntrain_samples = (int)(nsamples_all*0.8);

    // Create or load Random Trees classifier
    if( filename_to_load )
    {
        // load classifier from the specified file; all samples are then used as test data
        forest.load( filename_to_load );
        ntrain_samples = 0;
        if( forest.get_tree_count() == 0 )
        {
            printf( "Could not read the classifier %s\n", filename_to_load );
            return -1;
        }
        printf( "The classifier %s is loaded.\n", filename_to_load );
    }
    else
    {
        // create classifier by using <data> and <responses>
        printf( "Training the classifier ...\n");

        // 1. create type mask: 16 ordered (numerical) inputs, categorical response
        var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
        cvSet( var_type, cvScalarAll(CV_VAR_ORDERED) );
        cvSetReal1D( var_type, data->cols, CV_VAR_CATEGORICAL );

        // 2. create sample_idx: 1 = used for training, 0 = held out for testing
        sample_idx = cvCreateMat( 1, nsamples_all, CV_8UC1 );
        {
            CvMat mat;
            cvGetCols( sample_idx, &mat, 0, ntrain_samples );
            cvSet( &mat, cvRealScalar(1) );

            cvGetCols( sample_idx, &mat, ntrain_samples, nsamples_all );
            cvSetZero( &mat );
        }

        // 3. train classifier
        // CvRTParams: max_depth=10, min_sample_count=10, regression_accuracy=0,
        // use_surrogates=false, max_categories=15, priors=0, calc_var_importance=true,
        // nactive_vars=4, max_num_of_trees_in_the_forest=100, forest_accuracy=0.01,
        // termcrit_type=CV_TERMCRIT_ITER (stop after the maximum number of trees)
        forest.train( data, CV_ROW_SAMPLE, responses, 0, sample_idx, var_type, 0,
            CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
        printf( "\n");
    }

    // compute prediction error on train and test data
    for( i = 0; i < nsamples_all; i++ )
    {
        double r;
        CvMat sample;
        cvGetRow( data, &sample, i );

        r = forest.predict( &sample );
        r = fabs((double)r - responses->data.fl[i]) <= FLT_EPSILON ? 1 : 0;

        if( i < ntrain_samples )
            train_hr += r;
        else
            test_hr += r;
    }

    // guard against division by zero (e.g. ntrain_samples == 0 after loading a model)
    if( nsamples_all > ntrain_samples )
        test_hr /= (double)(nsamples_all-ntrain_samples);
    if( ntrain_samples > 0 )
        train_hr /= (double)ntrain_samples;

    printf( "Recognition rate: train = %.1f%%, test = %.1f%%\n",
            train_hr*100., test_hr*100. );

    printf( "Number of trees: %d\n", forest.get_tree_count() );

    // Print variable importance
    var_importance = (CvMat*)forest.get_var_importance();
    if( var_importance )
    {
        double rt_imp_sum = cvSum( var_importance ).val[0];
        printf("var#\timportance (in %%):\n");
        for( i = 0; i < var_importance->cols; i++ )
            printf( "%-2d\t%-4.1f\n", i,
                    100.f*var_importance->data.fl[i]/rt_imp_sum);
    }

    // Print some proximities
    printf( "Proximities between some samples corresponding to the letter 'T':\n" );
    {
        CvMat sample1, sample2;
        const int pairs[][2] = {{0,103}, {0,106}, {106,103}, {-1,-1}};

        for( i = 0; pairs[i][0] >= 0; i++ )
        {
            cvGetRow( data, &sample1, pairs[i][0] );
            cvGetRow( data, &sample2, pairs[i][1] );
            printf( "proximity(%d,%d) = %.1f%%\n", pairs[i][0], pairs[i][1],
                    forest.get_proximity( &sample1, &sample2 )*100. );
        }
    }

    // Save Random Trees classifier to file if needed
    if( filename_to_save )
        forest.save( filename_to_save );

    cvReleaseMat( &sample_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &data );
    cvReleaseMat( &responses );

    return 0;
}

int main( int argc, char *argv[] )
{
    char* filename_to_save = 0;
    char* filename_to_load = 0;
    char default_data_filename[] = "F:\\Program Files\\OpenCV2.4.0\\samples\\cpp\\letter-recognition.data";
    char* data_filename = default_data_filename;
    int method = 0;

    int i;
    for( i = 1; i < argc; i++ )
    {
        if( strcmp(argv[i],"-data") == 0 ) // flag "-data letter-recognition.data"
        {
            i++;
            data_filename = argv[i];
        }
        else if( strcmp(argv[i],"-save") == 0 ) // flag "-save filename.xml"
        {
            i++;
            filename_to_save = argv[i];
        }
        else if( strcmp(argv[i],"-load") == 0) // flag "-load filename.xml"
        {
            i++;
            filename_to_load = argv[i];
        }
        //else if( strcmp(argv[i],"-boost") == 0)
        //{
        //    method = 1;
        //}
        //else if( strcmp(argv[i],"-mlp") == 0 )
        //{
        //    method = 2;
        //}
        //else if ( strcmp(argv[i], "-knearest") == 0)
        //{
        //    method = 3;
        //}
        //else if ( strcmp(argv[i], "-nbayes") == 0)
        //{
        //    method = 4;
        //}
        //else if ( strcmp(argv[i], "-svm") == 0)
        //{
        //    method = 5;
        //}
        else
            break;
    }

    if( i < argc ||
        (method == 0 ?
        build_rtrees_classifier( data_filename, filename_to_save, filename_to_load ) :
        //method == 1 ?
        //build_boost_classifier( data_filename, filename_to_save, filename_to_load ) :
        //method == 2 ?
        //build_mlp_classifier( data_filename, filename_to_save, filename_to_load ) :
        //method == 3 ?
        //build_knearest_classifier( data_filename, 10 ) :
        //method == 4 ?
        //build_nbayes_classifier( data_filename) :
        //method == 5 ?
        //build_svm_classifier( data_filename ):
        -1) < 0)
    {
        help();
    }
    return 0;
}
```
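After training, the program prints two further diagnostics: each variable's importance as a percentage of the total, and the proximity of a few sample pairs. The sketch below reproduces both computations on plain vectors. `importance_percent` divides each raw score by the sum, as the sample's printing loop does. `proximity` follows the usual random-forest definition, the fraction of trees in which both samples reach the same leaf. As far as I can tell, this is also what `CvRTrees::get_proximity` computes in OpenCV 2.4. The leaf ids passed in are hypothetical inputs, and the sketch does not traverse real trees.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Convert raw variable-importance scores into percentages of their sum
// (100 * imp[i] / sum). All zeros are returned if the sum is not positive.
std::vector<double> importance_percent(const std::vector<double>& raw)
{
    const double sum = std::accumulate(raw.begin(), raw.end(), 0.0);
    std::vector<double> pct(raw.size(), 0.0);
    if (sum > 0)
        for (std::size_t i = 0; i < raw.size(); ++i)
            pct[i] = 100.0 * raw[i] / sum;
    return pct;
}

// Proximity of two samples: fraction of trees in which both land in the same leaf.
// leaf_a[t] and leaf_b[t] are (hypothetical) ids of the leaves reached in tree t.
double proximity(const std::vector<int>& leaf_a, const std::vector<int>& leaf_b)
{
    if (leaf_a.empty() || leaf_a.size() != leaf_b.size())
        return 0.0;
    std::size_t same = 0;
    for (std::size_t t = 0; t < leaf_a.size(); ++t)
        if (leaf_a[t] == leaf_b[t])
            ++same;
    return (double)same / (double)leaf_a.size();
}
```

A proximity close to 100%, as printed for the 'T' sample pairs, means the two samples usually follow the same path through the trees.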
Results: (the screenshot of the program output was not preserved)
See also: Classifying handwriting data with the CvRTrees class