您的位置:首页 > 理论基础 > 计算机网络

调用opencv中BP神经网络来对图像进行分类

2011-10-28 16:36 519 查看
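/* Usage - classifying an image with the BP (back-propagation) classifier only takes:
img0=cvLoadImage("099.pbm",0);
BPClassifier bp_cly; // constructing the object loads the training samples and trains the network
number=bp_cly.classify(img0);*/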

#include <ml.h>
// The single OpenCV multi-layer perceptron (back-propagation network) used by
// BPClassifier: train() creates/trains/saves it and classify() predicts with it.
// Because it is a global, every BPClassifier instance shares this one network.
CvANN_MLP BP; // OpenCV's BP (back-propagation) neural network
// Constructor.
// Fixes the data-set dimensions (train_samples images per class, `classes`
// classes, images resampled to size x size), allocates the training matrices,
// describes a 3-layer network topology in `neuralLayers`, then loads the
// training images (getData) and trains the global network (train).
BPClassifier::BPClassifier()
{
    // Root folder of the training images (one sub-folder per class).
    // Relative alternative that was used before: "../OCR/"
    sprintf(file_path, "C:\\Users\\y450\\Desktop\\recognize\\OCR\\");

    train_samples = 80;
    classes = 10;
    size = 40;

    const int total_samples = train_samples * classes;

    trainData    = cvCreateMat(total_samples, size * size, CV_32FC1);
    trainClasses = cvCreateMat(total_samples, 10, CV_32FC1);
    neuralLayers = cvCreateMat(3, 1, CV_32SC1);
    sampleWts    = cvCreateMat(total_samples, 1, CV_32FC1);

    // Every training sample starts with the same weight of 1.
    cvSet(sampleWts, cvScalar(1));

    // Layer sizes: one input neuron per pixel, 5 hidden neurons, 10 outputs.
    const int layer_sizes[3] = { size * size, 5, 10 };
    for (int layer = 0; layer < 3; ++layer)
    {
        cvSet1D(neuralLayers, layer, cvScalar(layer_sizes[layer]));
    }

    // Load and preprocess the samples, then train the network.
    // To reuse an already trained model instead, remove these two calls and
    // call BP.load("bp.xml");
    getData();
    train();

    printf(" ---------------------------------------------------------------\n");
    printf("|\tClass\t|\tPrecision\t|\tAccuracy\t|\n");
    printf(" ---------------------------------------------------------------\n");
}

//取得样本数据
void BPClassifier::getData()
{
IplImage* src_image;
IplImage prs_image;
CvMat row,data;
char file[255];
int i,j;
for(i =0; i<classes; i++)
{
for( j = 0; j< train_samples; j++)
{

//Load file
if(j<10)
sprintf(file,"%s%d/%d0%d.pbm",file_path, i, i , j);
else
sprintf(file,"%s%d/%d%d.pbm",file_path, i, i , j);
src_image = cvLoadImage(file,0);
if(!src_image)
{
printf("Error: Cant load image %s\n", file);
//exit(-1);
}
//process file
prs_image = preprocessing(src_image, size, size);

//Set class label
cvGetRow(trainClasses, &row, i*train_samples + j);
cvSet(&row,cvScalarAll(0));
cvSet2D(&row,0,i, cvRealScalar(1));
//Set data
cvGetRow(trainData, &row, i*train_samples + j);

IplImage* img = cvCreateImage( cvSize( size, size ), IPL_DEPTH_32F, 1 );
//convert 8 bits image to 32 float image
cvConvertScale(&prs_image, img, 0.0039215, 0);

cvGetSubRect(img, &data, cvRect(0,0, size,size));

CvMat row_header, *row1;
//convert data matrix sizexsize to vecor
row1 = cvReshape( &data, &row_header, 0, 1 );
cvCopy(row1, &row, NULL);
}
}

//训练
void BPClassifier::train()
{
BP.create(neuralLayers);
printf(" 训练中\n");
BP.train(trainData,
trainClasses,
sampleWts,
0,
CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS,300,0.01),CvANN_MLP_TrainParams::BACKPROP,0.01)
);
printf(" 训练结束\n");
BP.save("bp.xml");

}

// Classification.
// Preprocesses `img` the same way as the training samples (preprocessing() to
// size x size, pixel values multiplied by 0.0039215 ~ 1/255, flattened to one
// row). It then runs the global network BP and returns the index of the
// output neuron with the largest response.
//
// img: image to classify, presumably a grayscale image such as the one
//      returned by cvLoadImage(file, 0) - TODO confirm with preprocessing().
// Returns the winning output index (0..9) as a float, or -1.0f if img is NULL.
//
// Unlike the previous version, this releases its temporary float image and
// keeps the network output in a stack buffer, so repeated calls do not leak.
float BPClassifier::classify(IplImage* img)
{
    if (!img)
    {
        return -1.0f;
    }

    IplImage prs_image = preprocessing(img, size, size);

    // Convert to 32-bit float and flatten into a 1 x (size*size) row.
    IplImage* img32 = cvCreateImage(cvSize(size, size), IPL_DEPTH_32F, 1);
    cvConvertScale(&prs_image, img32, 0.0039215, 0);

    CvMat data;
    CvMat row_header;
    cvGetSubRect(img32, &data, cvRect(0, 0, size, size));
    CvMat* sample = cvReshape(&data, &row_header, 0, 1);

    // One response per output neuron; the output layer has 10 neurons.
    float responses[10] = { 0 };
    CvMat response_mat = cvMat(1, 10, CV_32FC1, responses);
    BP.predict(sample, &response_mat);

    // Column of the maximum response = predicted class.
    CvPoint best_loc = { 0, 0 };
    cvMinMaxLoc(&response_mat, NULL, NULL, NULL, &best_loc, NULL);

    cvReleaseImage(&img32);

    return static_cast<float>(best_loc.x);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: