您的位置:首页 > 其它

SVM实现一个三类分类问题

2013-04-06 10:01 399 查看
任务要求
用SVM求解一个三类分类问题,实验数据为“鸢尾属植物数据集”,核函数为径向基核函数(RBF),误差评测标准为K折交叉确认误差。

二.实验方案
1. 用quadprog函数实现C-SVC来进行分类#此前在首页部分显示#
——quadprog是matlab中一个求解二次规划的函数,通过适当的参数设置,可以利用quadprog函数实现C-SVC
2. 用matlab自带的SVM工具包来实现分类
——matlab2006版本中集成了SVM工具包,可以通过调用工具包中的svmtrain和svmclassify函数来进行训练和分类
3. 三类问题的分类方法
——将三类问题转化为三个两类问题,分别求出相应的决策函数即可(优点:方法简单易行;缺点:容易形成死区)

三.实验程序
1. 用Quadprog实现


clear all


% Load the data and select features for classification


load fisheriris;


data = meas;


%Get the size of the data


N = size(data,1);


% Extract the Setosa class


groups_temp = ismember(species,'versicolor');%versicolor,virginica,setosa


%convert the group to 1 & -1


groups = 2*groups_temp - ones(N,1);




indices = crossvalind('Kfold', groups);




ErrorMin = 1;


for r=1:1:5


for C=1:1:5


ErrorNum = 0;


for i=1:5


%Use K-fold to get train data and test data


test = (indices == i); train = ~test;




traindata = data(train,:);


traingroup = groups(train,:);


trainlength = length(traingroup);




testdata = data(test,:);


testgroup = groups(test,:);


testlength = length(testgroup);




%Get matrix H of the problem


kfun = [];


for i=1:1:trainlength


for j=1:1:trainlength


%rbf kernel


kfun(i,j)=exp(-1/(r^2)*(traindata(i,:)-traindata(j,:))*(traindata(i,:)-traindata(j,:))');


end


end




%count parameters of quadprog function


H = (traingroup*traingroup').*kfun;


xstart = zeros(trainlength,1);


f = -ones(trainlength,1);


Aeq = traingroup';


beq = 0;


lb = zeros(trainlength,1);


ub = C*ones(trainlength,1);




[alpha,fval] = quadprog(H,f,[],[],Aeq,beq,lb,ub,xstart);




%Get one of the non-zero part of vector alpha to count b


j = 1;


for i=1:size(alpha)


if(alpha(i)>(1e-5))


SvmClass_temper(j,:) = traingroup(i);


SvmAlpha_temper(j,:) = alpha(i);


SvmVector_temper(j,:)= traindata(i,:);


j = j + 1;


tag = i;


end


end




b=traingroup(tag)-(alpha.*traingroup)'*kfun(:,tag);




%Use the function to test the test data


kk = [];


for i=1:testlength


for j=1:trainlength


kk(i,j)=exp(-1/(r^2)*(testdata(i,:)-traindata(j,:))*(testdata(i,:)-traindata(j,:))');


end


end




%then count the function


f=(alpha.*traingroup)'*kk' + b;


for i=1:length(f)


if(f(i)>(1e-5))


f(i)=1;


else


f(i)=-1;


end


end




for i=1:length(f)


if(testgroup(i)~=f(i))


ErrorNum = ErrorNum + 1;


end


end


end




ErrorRate = ErrorNum / N;




if(ErrorRate<ErrorMin)


SvmClass = SvmClass_temper;


SvmAlpha = SvmAlpha_temper;


SvmVector = SvmVector_temper;


ErrorMin = ErrorRate;

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐