您的位置:首页 > 其它

使用SVM对多类多维数据进行分类

2014-10-05 10:20 741 查看
最近,本人要做个小东西,使用SVM对8类三维数据进行分类,搜索网上,发现大伙讨论的都是二维数据的二分类问题,遂决定自己研究一番。本人首先参考了opencv的tutorial,这也是二维数据的二分类问题。然后通过学习研究,发现别有洞天,遂实现之前的目标。在这里将代码贴出来,这里实现了对三维数据进行三类划分,以供大家相互学习。

#include "stdafx.h"
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;
using namespace std;

int main()
{

//--------------------- 1. Set up training data randomly ---------------------------------------
Mat trainData(100, 3, CV_32FC1);
Mat labels   (100, 1, CV_32FC1);

RNG rng(100); // Random value generation class

// Generate random points for the class 1
Mat trainClass = trainData.rowRange(0, 40);
// The x coordinate of the points is in [0, 0.4)
Mat c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));
// The y coordinate of the points is in [0, 0.4)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));
// The z coordinate of the points is in [0, 0.4)
c = trainClass.colRange(2, 3);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * 100));

// Generate random points for the class 2
trainClass = trainData.rowRange(60, 100);
// The x coordinate of the points is in [0.6, 1]
c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));
// The y coordinate of the points is in [0.6, 1)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));
// The z coordinate of the points is in [0.6, 1]
c = trainClass.colRange(2, 3);
rng.fill(c, RNG::UNIFORM, Scalar(0.6*100), Scalar(100));

// Generate random points for the classes 3
trainClass = trainData.rowRange(  40, 60);
// The x coordinate of the points is in [0.4, 0.6)
c = trainClass.colRange(0,1);
rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));
// The y coordinate of the points is in [0.4, 0.6)
c = trainClass.colRange(1,2);
rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));
// The z coordinate of the points is in [0.4, 0.6)
c = trainClass.colRange(2,3);
rng.fill(c, RNG::UNIFORM, Scalar(0.4*100), Scalar(0.6*100));

//------------------------- Set up the labels for the classes ---------------------------------
labels.rowRange( 0,  40).setTo(1);  // Class 1
labels.rowRange(60, 100).setTo(2);  // Class 2
labels.rowRange(40, 60).setTo(3);  // Class 3

//------------------------ 2. Set up the support vector machines parameters --------------------
CvSVMParams params;
params.svm_type    = SVM::C_SVC;
params.C           = 0.1;
params.kernel_type = SVM::LINEAR;
params.term_crit   = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);

//------------------------ 3. Train the svm ----------------------------------------------------
cout << "Starting training process" << endl;
CvSVM svm;
svm.train(trainData, labels, Mat(), Mat(), params);
cout << "Finished training process" << endl;

Mat sampleMat = (Mat_<float>(1,3) << 50, 50,10);
float response = svm.predict(sampleMat);
cout<<response<<endl;

sampleMat = (Mat_<float>(1,3) << 50, 50,100);
response = svm.predict(sampleMat);
cout<<response<<endl;

sampleMat = (Mat_<float>(1,3) << 50, 50,60);
response = svm.predict(sampleMat);
cout<<response<<endl;

waitKey(0);
}

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: