
Using SVM to Solve a Three-Class Classification Problem

2007-04-27 15:50
 
I. Task Requirements
Use an SVM to solve a three-class classification problem. The experimental data is the Fisher iris data set, the kernel is the radial basis function (RBF) kernel, and the error measure is the K-fold cross-validation error.
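For reference, these are the forms actually used in the code that follows, with kernel width parameter r, N samples in total, and K = 5 folds:

\[
K(\mathbf{x},\mathbf{z}) = \exp\!\left(-\frac{\lVert\mathbf{x}-\mathbf{z}\rVert^{2}}{r^{2}}\right),
\qquad
E_{\mathrm{CV}} = \frac{1}{N}\sum_{k=1}^{K}\#\{\text{test samples misclassified in fold } k\}.
\]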
 
II. Experiment Plan
1. Implement C-SVC with the quadprog function
   quadprog is MATLAB's quadratic programming solver; with suitable parameter settings it can be used to implement C-SVC.
2. Classify with the SVM toolbox that ships with MATLAB
   MATLAB 2006 integrates an SVM toolbox; training and classification can be done by calling its svmtrain and svmclassify functions.
3. Handling the three-class problem
   Convert the three-class problem into three two-class (one-vs-rest) problems and find the corresponding decision function for each (advantage: simple and easy to implement; drawback: it can leave undecidable "dead zones"); see the sketch after this list.
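A minimal sketch of the one-vs-rest combination follows. The handles fVersicolor, fVirginica and fSetosa are placeholder names for the three trained binary decision functions (they are not defined anywhere in this article); picking the class with the largest decision value is one common way to avoid the dead zones mentioned above.

% One-vs-rest combination of three binary SVM decision values (sketch).
% fVersicolor, fVirginica and fSetosa are assumed to be function handles
% that return the real-valued decision function of each trained classifier
% for a sample x (placeholder names, not code from this article).
scores = [fVersicolor(x), fVirginica(x), fSetosa(x)];
names  = {'versicolor','virginica','setosa'};
% Assign x to the class whose classifier gives the largest decision value.
% Thresholding each output at zero separately can leave samples that no
% classifier (or more than one) claims, i.e. the dead zones noted above.
[maxScore, idx] = max(scores);
label = names{idx};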
 
III. Experiment Code
1. quadprog implementation


clear all

% Load the Fisher iris data and use all four features
load fisheriris;
data = meas;

% Number of samples
N = size(data,1);

% Extract one class against the rest (change the class name to
% 'virginica' or 'setosa' for the other two binary problems)
groups_temp = ismember(species,'versicolor'); %versicolor,virginica,setosa

% Convert the group labels to +1 / -1
groups = 2*groups_temp - ones(N,1);

% 5-fold cross-validation indices
indices = crossvalind('Kfold', groups, 5);

ErrorMin = 1;

for r = 1:1:5          % RBF kernel width parameter
    for C = 1:1:5      % box constraint
        ErrorNum = 0;
        for k = 1:5
            % Use fold k as the test set and the rest as the training set
            test = (indices == k); train = ~test;

            traindata   = data(train,:);
            traingroup  = groups(train,:);
            trainlength = length(traingroup);

            testdata   = data(test,:);
            testgroup  = groups(test,:);
            testlength = length(testgroup);

            % Build the kernel matrix with the RBF kernel
            kfun = zeros(trainlength);
            for i = 1:trainlength
                for j = 1:trainlength
                    kfun(i,j) = exp(-1/(r^2)*(traindata(i,:)-traindata(j,:))*(traindata(i,:)-traindata(j,:))');
                end
            end

            % Set up the quadprog arguments for the dual C-SVC problem
            H      = (traingroup*traingroup').*kfun;
            xstart = zeros(trainlength,1);
            f      = -ones(trainlength,1);
            Aeq    = traingroup';
            beq    = 0;
            lb     = zeros(trainlength,1);
            ub     = C*ones(trainlength,1);

            [alpha,fval] = quadprog(H,f,[],[],Aeq,beq,lb,ub,xstart);

            % Collect the support vectors (alpha > 0) and keep the index
            % of one of them to compute the bias b
            j = 1;
            for i = 1:length(alpha)
                if alpha(i) > 1e-5
                    SvmClass_temper(j,:)  = traingroup(i);
                    SvmAlpha_temper(j,:)  = alpha(i);
                    SvmVector_temper(j,:) = traindata(i,:);
                    j = j + 1;
                    tag = i;
                end
            end
            b = traingroup(tag) - (alpha.*traingroup)'*kfun(:,tag);

            % Kernel values between test and training samples
            kk = zeros(testlength,trainlength);
            for i = 1:testlength
                for j = 1:trainlength
                    kk(i,j) = exp(-1/(r^2)*(testdata(i,:)-traindata(j,:))*(testdata(i,:)-traindata(j,:))');
                end
            end

            % Evaluate the decision function on the test data
            f = (alpha.*traingroup)'*kk' + b;
            for i = 1:length(f)
                if f(i) > 1e-5
                    f(i) = 1;
                else
                    f(i) = -1;
                end
            end

            % Count the misclassified test samples of this fold
            for i = 1:length(f)
                if testgroup(i) ~= f(i)
                    ErrorNum = ErrorNum + 1;
                end
            end
        end

        % 5-fold cross-validation error over all N samples
        ErrorRate = ErrorNum / N;

        % Keep the best parameters and support vectors found so far
        if ErrorRate < ErrorMin
            SvmClass  = SvmClass_temper;
            SvmAlpha  = SvmAlpha_temper;
            SvmVector = SvmVector_temper;
            ErrorMin  = ErrorRate;
            CorrectRate = 1 - ErrorRate;
            Coptimal = C;
            Roptimal = r;
        end
    end
end
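For reference, the quadprog call above solves the standard dual form of the C-SVC problem. With labels y_i in {+1, -1} and kernel matrix K (kfun in the code),

\[
\min_{\alpha}\ \tfrac{1}{2}\,\alpha^{\top} H \alpha - \mathbf{1}^{\top}\alpha,
\qquad H_{ij} = y_i y_j K(\mathbf{x}_i,\mathbf{x}_j),
\qquad \text{s.t.}\ \ y^{\top}\alpha = 0,\ \ 0 \le \alpha_i \le C,
\]

which maps directly onto quadprog(H,f,[],[],Aeq,beq,lb,ub): f is the vector of -1's, Aeq = y', beq = 0, lb = 0 and ub = C. The bias is b = y_t - \sum_j \alpha_j y_j K(\mathbf{x}_j,\mathbf{x}_t) for a support vector x_t, and the decision function is f(\mathbf{x}) = \sum_j \alpha_j y_j K(\mathbf{x}_j,\mathbf{x}) + b, exactly as computed in the code. (Strictly, the support vector used for b should satisfy 0 < \alpha_t < C; the code simply takes the last \alpha_t greater than 1e-5.)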

 
2. SVM toolbox implementation


clear all

% Load the Fisher iris data and use all four features
load fisheriris
% data = [meas(:,3),meas(:,4)];
data = meas;

% Extract one class against the rest (change the class name to
% 'virginica' or 'setosa' for the other two binary problems)
groups = ismember(species,'versicolor'); %versicolor,virginica,setosa

% 5-fold cross-validation indices
index = crossvalind('Kfold',groups,5);

fr = 0;
fc = 0;
fcorrect = 0;
correct5 = 0;

for r = 1:1:10        % RBF kernel width parameter
    for c = 1:1:100   % box constraint
        % Reset the classifier performance object so that CorrectRate
        % only accumulates the 5 folds of the current (r,c) pair
        cp = classperf(groups);
        for i = 1:5
            test = (index == i); train = ~test;
            % Train an RBF support vector machine classifier.
            % Note: how the RBF width is passed (and whether 'kfunargs'
            % must be a cell array) depends on the toolbox release; check
            % the svmtrain documentation of the installed version.
            % svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',5,'boxconstraint',1000,'showplot',true);
            % classes = svmclassify(svmStruct,data(test,:),'showplot',true);
            svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',1/(r^2),'boxconstraint',c);
            classes = svmclassify(svmStruct,data(test,:));
            % See how well the classifier performed on this fold
            classperf(cp,classes,test);
            % cp.CorrectRate
            correct5 = correct5 + cp.CorrectRate/5;
        end
        % Display progress and the averaged correct rate
        r
        c
        correct5
        % Keep the best parameters found so far
        if fcorrect < correct5
            fcorrect = correct5
            fr = r
            fc = c
        end
        correct5 = 0;
    end
end

 
IV. Experiment Results
1. quadprog implementation
(1) Class: versicolor   Parameters: r (1-10), C (1-100)
    Result: CorrectRate = 0.9696, Roptimal = 1, Coptimal = 2
(2) Class: virginica    Parameters: r (1-10), C (1-100)
    Result: CorrectRate = 0.9430, Roptimal = 1, Coptimal = 2
(3) Class: setosa       Parameters: r (1-10), C (1-100)
    Result: CorrectRate = 1, Roptimal = 1, Coptimal = 1
2. SVM toolbox implementation
(1) Class: versicolor   Parameters: r (1-5), C (1-50)
    Result: CorrectRate = 1, Roptimal = 2, Coptimal = 22
(2) Class: virginica    Parameters: r (1-5), C (1-50)
    Result: CorrectRate = 0.9867, Roptimal = 10, Coptimal = 44
(3) Class: setosa       Parameters: r (1-10), C (1-100)
    Result: CorrectRate = 1, Roptimal = 1, Coptimal = 1
 
 