
AdaBoost with Decision Stumps: A MATLAB Demo

2017-09-12 22:18

AdaBoost is a very useful classification framework [1]. In essence, it builds a linear combination of many weak classifiers that can ultimately rival so-called strong classifiers such as the SVM. Its advantages are speed and relatively mild overfitting; its drawback is that each boosting round must solve a weighted 0/1-error minimization problem, which only a few kinds of weak classifiers can solve conveniently, and this limits its applicability.
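For reference, the quantities computed in the code below are the standard discrete AdaBoost updates. In round $t$, with sample weights $d_i$ and the chosen weak classifier $h_t$:

$$\epsilon_t = \sum_i d_i\,\mathbf{1}[h_t(x_i)\ne y_i],\qquad \alpha_t = \tfrac{1}{2}\ln\frac{1-\epsilon_t}{\epsilon_t},$$

$$d_i \leftarrow \frac{d_i\,e^{-\alpha_t y_i h_t(x_i)}}{Z_t},\qquad H(x) = \operatorname{sign}\Big(\sum_{t=1}^{T}\alpha_t h_t(x)\Big),$$

where $Z_t$ normalizes the weights so they sum to one.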

Here we demonstrate AdaBoost combined with the simplest possible decision tree: a decision stump, i.e. a tree with a single root node that thresholds one feature.

Training code:

% stump_train.m
% Train a decision stump (one-node decision tree) on 2-D data.
%   x: 2 x N samples, y: N x 1 labels in {-1,+1}, d: N x 1 sample weights.
function [stump,err] = stump_train(x,y,d)
% train a stump on each dimension; keep the one with smaller weighted error
[stump1,err1] = stump_train_1d(x(1,:),y,d);
[stump2,err2] = stump_train_1d(x(2,:),y,d);
if err1 < err2
    stump.dim = 1;
    stump.s = stump1.s;
    stump.t = stump1.t;
    err = err1;
else
    stump.dim = 2;
    stump.s = stump2.s;
    stump.t = stump2.t;
    err = err2;
end

function [stump,err] = stump_train_1d(data,label,weight)
% exhaustive threshold search along a single dimension
min_x = min(data);
max_x = max(data);
N = length(data);
% use the smallest pairwise gap between samples as the threshold step
min_distance = inf;
for i = 1:N
    for j = 1:i-1
        if min_distance > abs(data(i)-data(j))
            min_distance = abs(data(i)-data(j));
        end
    end
end
min_distance = max(min_distance,0.05);
min_err = inf;   % inf (rather than 1) guarantees final_stump is always assigned
for t = min_x+min_distance/2:min_distance/2:max_x
    % try both orientations: s = 1 predicts +1 to the right of t,
    % s = -1 flips the feature sign before thresholding
    stump1.s = 1;
    stump1.t = t;
    err1 = computeErr(stump1,data,label,weight);
    stump2.s = -1;
    stump2.t = t;
    err2 = computeErr(stump2,data,label,weight);
    if min(err1,err2) < min_err
        min_err = min(err1,err2);
        if err1 < err2
            final_stump.s = 1;
            final_stump.t = t;
        else
            final_stump.s = -1;
            final_stump.t = t;
        end
    end
end
stump = final_stump;
err = min_err;

function err = computeErr(stump,data,label,weight)
% weighted 0/1 error of a 1-D stump
err = 0;
for i = 1:length(data)
    if stump.s*data(i) < stump.t
        h = -1;
    else
        h = 1;
    end
    if h ~= label(i)
        err = err + weight(i);
    end
end
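As a quick sanity check, here is a minimal, hypothetical call to stump_train (the toy data below is mine, not from the original post). With uniform weights the stump should split the two classes on dimension 1:

% Hypothetical toy call to stump_train, not part of the original post.
x = [-1 -2 1 2;    % feature 1: separates the classes around 0
      0  1 0 1];   % feature 2: uninformative
y = [1; 1; -1; -1];   % labels in {-1,+1}
d = ones(4,1)/4;      % uniform sample weights
[stump,err] = stump_train(x,y,d);
fprintf('dim=%d, s=%d, t=%.2f, weighted err=%.2f\n',stump.dim,stump.s,stump.t,err);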

Prediction code for a single stump classifier:

% stump_predict.m
% Apply a trained stump to one sample x (a column vector).
function y = stump_predict(x,stump)
if stump.s*x(stump.dim) > stump.t
    y = 1;
else
    y = -1;
end
end

Given samples x and labels y, compute the AdaBoost error (as a count of misclassified samples):

% boost_error.m
% Count how many samples the boosted ensemble misclassifies.
function err = boost_error(boost,x,y)
N = length(y);
T = length(boost.alpha);
err = 0;
for i = 1:N
    % strong classifier: sign of the alpha-weighted sum of stump outputs
    s = 0;
    for t = 1:T
        s = s + boost.alpha(t)*stump_predict(x(:,i),boost.stump{t});
    end
    if s > 0
        h = 1;
    else
        h = -1;
    end
    if h ~= y(i)
        err = err + 1;
    end
end
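The inner loop above is just the ensemble decision rule H(x) = sign(Σ_t α_t h_t(x)). If you want that rule as a standalone predictor, it could be factored out as follows (a hypothetical helper; the original post inlines this logic):

% boost_predict.m -- hypothetical helper, not in the original post.
% Returns sign(sum_t alpha_t * h_t(x)) for a single sample x.
function h = boost_predict(boost,x)
s = 0;
for t = 1:length(boost.alpha)
    s = s + boost.alpha(t)*stump_predict(x,boost.stump{t});
end
if s > 0
    h = 1;
else
    h = -1;
end
end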

The main demo script, demo_adaboost.m:

%%
clc;
clear;
close all;

%% generate random 2-D training data: two Gaussian blobs
shift = 2.0;
n = 2;       % feature dimension
sigma = 1;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1); -ones(N/2,1)];

% show the data
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d training data');

%% training
d = ones(N,1)/N;   % initial sample weights: uniform
T = 30;            % max number of weak classifiers
for t = 1:T
    [stump,err] = stump_train(x,y,d);
    boost.stump{t} = stump;
    boost.alpha(t) = 0.5*log((1-err)/err);   % alpha_t = 0.5*ln((1-eps_t)/eps_t)
    % reweight: up-weight the samples the new stump misclassifies
    for i = 1:N
        h = stump_predict(x(:,i),stump);
        d(i) = d(i)*exp(-boost.alpha(t)*y(i)*h);
    end
    d = d/sum(d);   % normalize so the weights sum to one
    boost_err(t) = boost_error(boost,x,y)/N;
    if boost_err(t) < 1e-5
        fprintf('training error is small enough, err = %f, number of weak classifiers = %d, quit\n',boost_err(t),t);
        break;
    end
end

%% show the separating lines of the stumps over the training data
hold on;
min_x = min(x(1,:));
min_y = min(x(2,:));
max_x = max(x(1,:));
max_y = max(x(2,:));
for t = 1:length(boost.alpha)
    if boost.stump{t}.dim == 1
        % vertical line: threshold on dimension 1
        line([boost.stump{t}.t,boost.stump{t}.t],[min_y,max_y]);
        text(boost.stump{t}.t,(min_y+max_y)/2+randn(1)*3,[num2str(t) ':' num2str(boost.alpha(t))]);
    else
        % horizontal line: threshold on dimension 2
        line([min_x,max_x],[boost.stump{t}.t,boost.stump{t}.t]);
        text((min_x+max_x)/2+randn(1)*3,boost.stump{t}.t,[num2str(t) ':' num2str(boost.alpha(t))]);
    end
end



%% training error versus ensemble size
figure;
plot(boost_err,'r-s','LineWidth',2);
xlabel('Number of weak classifiers');
ylabel('Overall classification error');
title('error versus number of weak classifiers');



%% test on a new dataset
% note: sigma = 2 here, so the second blob is more spread out than in training
n = 2;
sigma = 2;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1); -ones(N/2,1)];
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d test data');
hold on;
% redraw the stump thresholds over the test data
min_x = min(x(1,:));
min_y = min(x(2,:));
max_x = max(x(1,:));
max_y = max(x(2,:));
for t = 1:length(boost.alpha)
    if boost.stump{t}.dim == 1
        line([boost.stump{t}.t,boost.stump{t}.t],[min_y,max_y]);
        text(boost.stump{t}.t,(min_y+max_y)/2+randn(1)*3,[num2str(t) ':' num2str(boost.alpha(t))]);
    else
        line([min_x,max_x],[boost.stump{t}.t,boost.stump{t}.t]);
        text((min_x+max_x)/2+randn(1)*3,boost.stump{t}.t,[num2str(t) ':' num2str(boost.alpha(t))]);
    end
end
boost_err_test = boost_error(boost,x,y)/N;
fprintf('boost error on test data set: %f\n',boost_err_test);



PS: all of the code above can be downloaded from http://download.csdn.net/detail/ranchlai/6038311

References:

[1]http://en.wikipedia.org/wiki/AdaBoost