您的位置:首页 > Web前端 > Node.js

通过源码学算法--AdaBoost (CART): RealAdaBoost.m + tree_node_w.m

2013-03-30 09:15 453 查看

tree_node_w.m

代表分类树的类结构。很简单,只有5个成员

如果是左树则只有right_constraint有值,如果是右树则只有left_constraint有值

实际上在这里是一个类多用了。

比如一个训练好的最大深度为3(max_split == 3)的CART 树有四个节点(node),每个节点就是该类tree_node_w的一个object,同时每个节点又是一个弱分类器 (weak classifier)

如果循环训练100次(Max_Iter == 100),就有100棵CART树,即400个弱分类器。每个分类器都有对应的权重

function tree_node = tree_node_w(max_split)

tree_node.left_constrain  = [];
tree_node.right_constrain = [];
tree_node.dim             = [];
tree_node.max_split       = max_split;
tree_node.parent         = [];

tree_node = class(tree_node, 'tree_node_w') ;


RealAdaBoost.m

权重分布初始化

Learners = {};
Weights = [];
distr = ones(1, length(Data)) / length(Data);
final_hyp = zeros(1, length(Data));


循环训练100次

每次得到一个4节点的CART树(也就是4个弱分类器)

第一个节点是root,没有parent。其他3个节点都有parent,最终指回root

用每个弱分类器对训练数据分类,根据公式计算/调整alpha值

根据该CART树的最终分类结果调整权重分布distr

归一化权重分布

for It = 1 : Max_Iter
nodes = train(WeakLrn, Data, Labels, distr);

for i = 1:length(nodes)
curr_tr = nodes{i};
step_out = calc_output(curr_tr, Data);

s1 = sum( (Labels ==  1) .* (step_out) .* distr);
s2 = sum( (Labels == -1) .* (step_out) .* distr);

if(s1 == 0 && s2 == 0)
continue;
end
Alpha = 0.5*log((s1 + eps) / (s2+eps));

Weights(end+1) = Alpha;

Learners{end+1} = curr_tr;

final_hyp = final_hyp + step_out .* Alpha;
end

distr = exp(- 1 * (Labels .* final_hyp));
Z = sum(distr);
distr = distr / Z;

end
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: