决策树(ID3,C4.5,CART)

来源:互联网 发布:复旦投毒案的思考知乎 编辑:程序博客网 时间:2024/06/06 04:22

将训练样本的信息以一颗树的形式表达,大多数算法采用自顶而下递归的方法构建这棵树,其关键在于如何确定分裂准则。

一 信息选择度量

1 ID3(迭代的二分器,Iterative Dichotomiser)—信息增益

这里写图片描述

(1)计算结果中分类的期望—info(D),也称为熵(entropy),它是D中元祖的类标号所需要的平均信息量。比如判断明天是否下雨(n=2),由训练集(D)可知,下雨概率(p1)是0.3,不下雨的概率(p2)是0.7,因此

这里写图片描述

(2)计算按某个属性划分对D的元祖分类所需要的期望信息。

这里写图片描述

j是属性A分类个数,0.3是以A来划分D,每一类的概率,0.4代表在属性A的某个分类中,下雨概率为0.4,则不下雨为0.6

这里写图片描述

(3)两者相减为信息增益(原来的信息需求与新的信息需求之差),也就是说通过A的划分,我们得到了多少信息,因此选择最高的信息增益的属性A分裂,能够使得完成元祖分类还需要的信息量最少。

2 C4.5—增益率
用信息增益来确定分裂,存在一定的偏向性,它更倾向于多输出,具有大量值的属性。为了克服这种偏向性,C4.5使用了‘分裂信息’值将信息增益规范化,同样我们选取最大值

这里写图片描述

3 CART (Classification and Regression Trees)—基尼指数

这里写图片描述

基尼指数独立了数据集的不纯度,当数据类别越多,其混合程度也就越高,基尼指数也就越大(当n=1时,Gini=0)。当以属性A来划分数据集时,分裂后的基尼指数应该最小,这样剩余的数据集混合程度才会最低,也就是差值应最大。

二 剪枝

由于决策树的构建是基于样本数据集,因此该模型的预测效果会受到数据中的噪声与离群点的影响,从而出现过度拟合。通常我们需要减掉不可靠的分支,使树模型更简洁,效果更好。

(1)先剪枝(prepruning):在树的构建过程中,通过一定条件阈值(信息增益,基尼指数等),限制树的生长,在完全拟合树生成之前停止。但阈值确定是很困难的的,高而不足,低而过度
(2)后剪枝(postpruning):是更为常用的一种方法。它是在决策树构建完成之后,自底而上修剪。

CART使用后剪枝,评判标准是代价复杂度(树叶节点个数与树错误率的函数),我们比较子树剪枝前后的代价复杂度,如果剪枝使得代价复杂度降低则减去该子树,否则保留(因为我们需要的是最小化复杂度的最小决策树)。

三优缺

计算简单,解释性强,可以用图形方式呈现,更为直观;但容易出现过度拟合,随后出现了随机森林,能够减小过度拟合

四 举例(MATLAB)

load fisheriris;t = treefit(meas,species);treedisp(t,'names',{'SL' 'SW' 'PL' 'PW'});load fisheriris;t = classregtree(meas,species,...                 'names',{'SL' 'SW' 'PL' 'PW'})t = Decision tree for classification1  if PL<2.45 then node 2 elseif PL>=2.45 then node 3 else setosa2  class = setosa3  if PW<1.75 then node 4 elseif PW>=1.75 then node 5 else versicolor4  if PL<4.95 then node 6 elseif PL>=4.95 then node 7 else versicolor5  class = virginica6  if PW<1.65 then node 8 elseif PW>=1.65 then node 9 else versicolor7  class = virginica8  class = versicolor9  class = virginicaload fisheriris;t = treefit(meas,species);  sfit = treeval(t,meas);     sfit = t.classname(sfit);   mean(strcmp(sfit,species))  ans =   0.9800sfit = eval(t,meas);pct = mean(strcmp(sfit,species))pct =    0.9800

1 构建决策树

(1)treefit

t = treefit(X,y)t = treefit(X,y,param1,val1,param2,val2,...)

X是n×m 的矩阵,y是目标变量,根据y的数据类型建立分类树或回归数,体现在param1,val1(’method’,’regression’)中。

(2)classregtree

t = classregtree(X,y)t = classregtree(X,y,'Name',value)

2 画出决策树

treedisp,t是已经构建好的模型。

treedisp(t)treedisp(t,param1,val1,param2,val2,...)

3 剪枝

t2 = prune(t1,'level',level)t2 = prune(t1,'nodes',nodes)t2 = prune(t1)

减掉t2中后level层,level=0不剪枝,level=1最底2层,level=2最深2层。
减掉第nodes后所有枝

4 预测

yfit = treeval(t,X)yfit = treeval(t,X,subtrees)[yfit,node] = treeval(...)[yfit,node,cname] = treeval(...)yfit = eval(t,X)yfit = eval(t,X,s)[yfit,nodes] = eval(...)[yfit,nodes,cnums] = eval(...)
原创粉丝点击