AdaBoost Demo with Tree-Based Weak Classifiers (with MATLAB Code)


AdaBoost is a very useful classification framework [1]. In essence, it linearly combines many weak classifiers into a final classifier that can rival so-called strong classifiers such as SVMs. Its advantages are speed and relatively mild overfitting; its drawback is that each round requires solving a weighted discrete (0-1) error minimization problem, and only a few kinds of weak classifiers admit a convenient optimal solution to it, which limits AdaBoost's applicability.
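For reference, these are the standard AdaBoost update rules [1] that the code below implements. At round $t$, with normalized sample weights $D_t(i)$, pick the weak classifier $h_t$ with the smallest weighted error, compute its vote weight $\alpha_t$, and re-weight the samples:

$$\epsilon_t = \sum_i D_t(i)\,[h_t(x_i) \neq y_i], \qquad \alpha_t = \frac{1}{2}\ln\frac{1-\epsilon_t}{\epsilon_t}, \qquad D_{t+1}(i) = \frac{D_t(i)\,e^{-\alpha_t y_i h_t(x_i)}}{Z_t},$$

where $Z_t$ normalizes the weights to sum to one. The final strong classifier is $H(x) = \mathrm{sign}\big(\sum_t \alpha_t h_t(x)\big)$.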


Here we demonstrate AdaBoost combined with decision trees that have only a root node, i.e., decision stumps.
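Concretely, each stump picks one coordinate $d$, a sign $s \in \{+1,-1\}$, and a threshold $t$, and predicts $h(x) = +1$ if $s \cdot x_d \ge t$ and $-1$ otherwise. The training code below exhaustively searches both dimensions, both signs, and a grid of thresholds for the stump with the smallest weighted error.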

Training code:

%stump_train.m

function [stump,err] = stump_train(x,y,d)
% Train a decision stump on 2-D data x (2 x N), labels y in {+1,-1},
% and sample weights d: fit each dimension separately and keep the
% one with the lower weighted error.
[stump1,err1] = stump_train_1d(x(1,:),y,d);
[stump2,err2] = stump_train_1d(x(2,:),y,d);

if err1 < err2
    stump.dim = 1;
    stump.s = stump1.s;
    stump.t = stump1.t;
    err = err1;
else
    stump.dim = 2;
    stump.s = stump2.s;
    stump.t = stump2.t;
    err = err2;
end




function [stump,err] = stump_train_1d(data,label,weight)
% Exhaustive search along one dimension for the sign s and threshold t
% that minimize the weighted classification error.

min_x = min(data);
max_x = max(data);
N = length(data);

% Smallest pairwise distance between samples; it determines the step of
% the threshold grid (lower-bounded by 0.05 to keep the grid finite).
min_distance = inf;
for i=1:N
    for j=1:i-1
        if min_distance > abs(data(i)-data(j))
            min_distance = abs(data(i)-data(j));
        end
    end
end
min_distance = max(min_distance,0.05);

min_err = inf; % inf (not 1) so that final_stump is always assigned
for t = min_x+min_distance/2:min_distance/2:max_x
    stump1.s = 1;  % decision boundary at data = t
    stump1.t = t;
    err1 = computeErr(stump1,data,label,weight);
    stump2.s = -1; % note: with s = -1 the effective boundary is at data = -t
    stump2.t = t;
    err2 = computeErr(stump2,data,label,weight);

    if min(err1,err2) < min_err
        min_err = min(err1,err2);
        if err1 < err2
            final_stump.s = 1;
        else
            final_stump.s = -1;
        end
        final_stump.t = t;
    end
end
stump = final_stump;
err = min_err;




function err = computeErr(stump,data,label,weight)
% Weighted 0-1 error of a 1-D stump: predict -1 if s*x < t, else +1.
err = 0;
for i=1:length(data)
    if stump.s*data(i) < stump.t
        h = -1;
    else
        h = 1;
    end
    if h ~= label(i)
        err = err + weight(i);
    end
end
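
A minimal usage sketch for the stump trainer (the toy data and variable names here are illustrative, not part of the original post):

% quick sanity check for stump_train (hypothetical toy example)
x = [-1 -2 1 2; 0.5 -0.3 -0.1 0.4]; % 2 x 4 data, separable along dim 1
y = [-1; -1; 1; 1];                 % labels
d = ones(4,1)/4;                    % uniform sample weights
[stump,err] = stump_train(x,y,d);
fprintf('dim=%d s=%d t=%.2f err=%.2f\n',stump.dim,stump.s,stump.t,err);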

Prediction code for a single stump classifier:

function y = stump_predict(x,stump)
% Predict +1/-1 for a single sample x (column vector).
% Use >= so the boundary case agrees with computeErr (s*x < t => -1).
if stump.s*x(stump.dim) >= stump.t
    y = 1;
else
    y = -1;
end
end
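
For larger sample sets, a vectorized variant avoids the per-sample loop. This helper is my own addition, not part of the original download:

function y = stump_predict_all(x,stump)
% Vectorized stump prediction over all columns of x (one sample per
% column); equivalent to calling stump_predict on each column.
y = ones(1,size(x,2));
y(stump.s*x(stump.dim,:) < stump.t) = -1;
end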


Given a sample set x, y, compute the classification error of the boosted classifier:

function err = boost_error(boost,x,y)
% Count the samples misclassified by the weighted vote
% sign(sum_t alpha_t * h_t(x)).
N = length(y);
T = length(boost.alpha);
err = 0;
for i=1:N
    s = 0;
    for t=1:T
        s = s + boost.alpha(t)*stump_predict(x(:,i),boost.stump{t});
    end
    if s > 0
        h = 1;
    else
        h = -1;
    end
    if h ~= y(i)
        err = err + 1;
    end
end
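
boost_error forms the weighted vote internally; to classify a single new sample outside of error computation, a standalone predictor could look like the following (a hypothetical helper, not in the original post):

function h = boost_predict(boost,x)
% Sign of the weighted stump vote for one sample x (ties map to -1,
% matching boost_error above).
s = 0;
for t=1:length(boost.alpha)
    s = s + boost.alpha(t)*stump_predict(x,boost.stump{t});
end
if s > 0
    h = 1;
else
    h = -1;
end
end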


The main demo script, demo_adaboost.m:

%%
clc;
clear;
close all;

%% generate random data
shift = 2.0;
n = 2;     % 2-D data
sigma = 1;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];


%show the data
figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d training data');


% training
d = ones(N,1)/N; % initial uniform sample weights
T = 30;          % max number of weak classifiers

for t=1:T
    % 1) fit the best stump under the current sample weights d
    [stump,err] = stump_train(x,y,d);
    boost.stump{t} = stump;
    % 2) vote weight alpha_t = 0.5*log((1-e_t)/e_t);
    %    guard against err == 0, which would give an infinite alpha
    err = max(err,eps);
    boost.alpha(t) = 0.5*log((1-err)/err);
    % 3) re-weight the samples: misclassified ones gain weight
    for i=1:N
        h = stump_predict(x(:,i),stump);
        d(i) = d(i)*exp(-boost.alpha(t)*y(i)*h);
    end
    d = d/sum(d); % normalize
    boost_err(t) = boost_error(boost,x,y)/N;
    if boost_err(t) < 1e-5
        fprintf('training error is small enough, err = %f, number of weak classifiers = %d, quitting\n',boost_err(t),t);
        break;
    end
end


%% show the separation lines (each stump is an axis-aligned cut)
hold on;
min_x = min(x(1,:));
min_y = min(x(2,:));
max_x = max(x(1,:));
max_y = max(x(2,:));
for t=1:length(boost.alpha)
    % the decision boundary sits at x_dim = s*t, not t (for s = -1 the
    % rule s*x >= t flips the cut to -t); label each line "t:alpha_t"
    cut = boost.stump{t}.s*boost.stump{t}.t;
    if boost.stump{t}.dim == 1
        line([cut,cut],[min_y,max_y]); % vertical cut
        text(cut,(min_y+max_y)/2+randn(1)*3,[num2str(t) ':' num2str(boost.alpha(t))]);
    else
        line([min_x,max_x],[cut,cut]); % horizontal cut
        text((min_x+max_x)/2+randn(1)*3,cut,[num2str(t) ':' num2str(boost.alpha(t))]);
    end
end




%%

figure;
plot(boost_err,'r-s','LineWidth',2);
xlabel('Number of weak classifiers');
ylabel('Overall classification error');
title('error versus number of weak classifiers');





%% test on a new dataset from a similar (slightly broader, sigma = 2) distribution

n = 2;
sigma = 2;
N = 500;
x = [randn(n,N/2)-shift, randn(n,N/2)*sigma+shift];
y = [ones(N/2,1);-ones(N/2,1)];


figure;
plot(x(1,1:N/2),x(2,1:N/2),'rs');
hold on;
plot(x(1,1+N/2:N),x(2,1+N/2:N),'go');
title('2d test data');
hold on;
min_x = min(x(1,:));
min_y = min(x(2,:));
max_x = max(x(1,:));
max_y = max(x(2,:));
for t=1:length(boost.alpha)
    % same plotting as above: the cut sits at x_dim = s*t
    cut = boost.stump{t}.s*boost.stump{t}.t;
    if boost.stump{t}.dim == 1
        line([cut,cut],[min_y,max_y]);
        text(cut,(min_y+max_y)/2+randn(1)*3,[num2str(t) ':' num2str(boost.alpha(t))]);
    else
        line([min_x,max_x],[cut,cut]);
        text((min_x+max_x)/2+randn(1)*3,cut,[num2str(t) ':' num2str(boost.alpha(t))]);
    end
end


boost_err_test = boost_error(boost,x,y)/N;
fprintf('boost error on test data set: %f\n',boost_err_test);


PS: All of the code above can be downloaded from http://download.csdn.net/detail/ranchlai/6038311


References:

[1] http://en.wikipedia.org/wiki/AdaBoost