机器学习之决策树生成和裁剪

来源:互联网 发布:windows ce6.0模拟器 编辑:程序博客网 时间:2024/06/07 06:11

决策树学习比较典型的有三种算法:ID3 C4.5 CART。

决策树是一种分类预测算法,通过训练样本建立的决策树,能够对未来样本进行分类。

决策树算法包括:建立决策树和裁剪决策树。裁剪决策树是为了减少过拟合带来的错误率。建立决策树的过程,是一种递归分级参考属性的过程,这个过程中会使用参考属性对目标属性的依赖关系。如下面例子,参考属性包括:有房、婚姻、收入。 目标属性:拖欠贷款。

这里写图片描述

ID3和C4.5

ID3和C4.5算法基本流程一致,区别是参考属性分级时,选择的标准不一样。
算法引入了信息论的信息熵,用于定量目标属性的确定性。样本集D中包含m类目标属性样本,每类样本的概率记为p(Ci),则目标属性信息熵定义:

H(D)=i=1mpi(Ci)log2pi(Ci)

当使用参考属性进行分级时,会得到多个样本子集记为D0,D1,..Dk(iDi=D),子集Dj中关于目标属性的信息熵记为H(Dj),样本数为|Dj|,则分级后目标信息熵为:
H(D)=i=0k|Di||D|H(Di)

举例来说,上面的训练样本集中拖欠贷款的信息熵为:
310log310710log710

若使用有房参考属性进行分级,得到2个子集,有房子集样本数为3,没有房子集样本数为7。则分级后关于目标属性的信息熵:
H(D)=310H(D)+710H(D)

ID3算法中,定义G(D)=H(D)H(D)为信息熵增益,表示经过分级后,对目标属性的判断把握性。使用不同的参考属性进行分级,得到的信息增益不一样。

使用不同的分级属性,得到的信息熵增益不同,ID3判定准则是,选择最大的信息增益对应的分级属性进行分级。

不过ID3的这个判定准则有一个缺陷,它总是倾向于属性值种类多的属性。例如上面的样本集,年收入用数字表达式,总类会有很多(数字是连续的)。因此,分级属性总是会倾向于这个参考属性。

为了解决这个问题,C4.5提出了使用信息增益率作为判定准则,增益率定义为:

GainRatio(D)=G(D)SplitInfo(D)

SplitInfo(D)=j|Dj||D|log|Dj||D|

可以看出,分级后的子集数越多,SplitInfo越大,导致增益率越小。
若使用有房参考属性进行分级
SplitInfo(D)=(03log03+33log33)(47log47+37log37)

from math import log#计算信息熵def CalShannonEntropy(dataSet): #格式:参考属性1,参考属性2...,目标属性    sampleNum=len(dataSet)    samplecount={}    for data in dataSet:        currentfeature=data[-1]        if currentfeature not in samplecount.keys(): samplecount[currentfeature]=0        samplecount[currentfeature]+=1    entropy=0.0    for key in samplecount.keys():                p=float(samplecount.get(key))/sampleNum;        entropy-=p*log(p,2)    return entropy;#选择最佳分级特征属性def ChooseBestFeature(dataSet):    baseEntropy=CalShannonEntropy(dataSet) #第一次分级之前的信息熵    samplenum=len(dataSet)    labelnum=len(dataSet[0])-1    entropymax=0.0    bestLabelIndex=-1    for labelIndex in range(labelnum): #遍历所有的参考特征属性        values=set([example[labelIndex] for example in dataSet])        entsum=0.0        splitinfo=0.0        for value in values:             subdataSet=SplitDataSet(dataSet,labelIndex,value)   #分级后得到的子集D1,D2,...            p=float(len(subdataSet))/samplenum             entsum+=p*CalShannonEntropy(subdataSet) #计算分级后的信息熵            splitinfo-=p*log(p,2)                    #计算分级信息SplitInfo,ID3不用计算        infoGainRatio=(baseEntropy-entsum)/splitinfo #计算信息增益率,ID3计算增益就可以了        if infoGainRatio>entropymax:   #判断最大的增益率或增益            entropymax=infoGainRatio            bestLabelIndex=labelIndex    return bestLabelIndex  #返回最大增益或增益率的参考特征属性

CART

参考
除了ID3和C4.5,还有一种算法CART(classification and regression tree)。这是一种可以处理离散特征值和连续特征值的决策树,处理离散特征值使用分类决策树,处理连续特征值使用回归决策树。
CART的分级判定准则常用的是gini指数。gini指数和信息熵类似,gini指数越低,对目标属性判定越有把握。gini定义如下:

gini(D)=1i=1kp2i

经过分级后,得到子样本集D0,...Dk,gini指数定义为:
gini(D)=jk|Dj||D|gini(Dj)

若使用有房参考属性进行分级后,
gini(D)=310(1(03)2(33)2)+710(1(47)2(37)2)

def CalGini(subdataSet):#使用gini标准选择最佳分级特征属性def ChooseBestFeature(dataSet):    samplenum=len(dataSet)    labelnum=len(dataSet[0])-1    ginimax=0.0    bestLabelIndex=-1    for labelIndex in range(labelnum): #遍历所有的参考特征属性        values=set([example[labelIndex] for example in dataSet])        ginisum=0.0        splitinfo=0.0        for value in values:             subdataSet=SplitDataSet(dataSet,labelIndex,value)   #分级后得到的子集D1,D2,...            p=float(len(subdataSet))/samplenum    #子集中样本数占比            ginisum+=p*CalShannonEntropy(subdataSet) #计算分级后的Gini            splitinfo-=p*log(p,2)                    #计算分级信息SplitInfo,ID3不用计算        infoGainRatio=(baseEntropy-entsum)/splitinfo #计算信息增益率,ID3计算增益就可以了        if infoGainRatio>entropymax:   #判断最大的增益率或增益            entropymax=infoGainRatio            bestLabelIndex=labelIndex    return bestLabelIndex  #返回最大增益或增益率的参考特征属性

决策树裁剪

为了防止过拟合,需要对决策树进行裁剪。裁剪分为事前裁剪和事后裁剪。事前裁剪发生在建立决策树时,通过判定规则(例如节点总数>),来决定是否进行新的分级。 事后裁剪发生在建立决策树后,通过判定规则进行树的修剪。常用的事后裁剪方法一般为CCP:代价复杂性修剪法
CCP对决策树中的每个非叶子节点定义了一个表面误差增益值α

α=R(node)R(leaf)|Nleaf|1

R(t)=minlr(|t|)|t|l+|t|r|t|l+|t|r|T|

例如:
这里写图片描述
设样本集样本总数为100
R(T4)=7100,R(T8)=2100,R(T9)=0100,R(T7)=3100,R(T6)=3100

αT4=R(T4)(R(T7)+R(T8)+R(T9))31=2200

αT6=R(T6)(R(T8)+R(T9))31=1100

实验:
训练样本
这里写图片描述
二值图标记
这里写图片描述
建立决策树,完成二值图。
这里写图片描述

import numpy as npfrom skimage import iofrom skimage import colorfrom math import logimport operatorimport matplotlib.pyplot as pltdef calcShannonEnt(dataSet):    numEntries = len(dataSet)    labelCounts = {}    for featVec in dataSet: #the the number of unique elements and their occurance        currentLabel = featVec[-1]        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    shannonEnt = 0.0    for key in labelCounts:        prob = float(labelCounts[key])/numEntries        shannonEnt -= prob * log(prob,2) #log base 2    return shannonEntdef splitDataSet(dataSet, axis, value):    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting            reducedFeatVec.extend(featVec[axis+1:])            retDataSet.append(reducedFeatVec)    return retDataSetdef chooseBestFeatureToSplit(dataSet):    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels    baseEntropy = calcShannonEnt(dataSet)    bestInfoGain = 0.0; bestFeature = -1    for i in range(numFeatures):        #iterate over all the features        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature        uniqueVals = set(featList)       #get a set of unique values        newEntropy = 0.0        for value in uniqueVals:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet)/float(len(dataSet))            newEntropy += prob * calcShannonEnt(subDataSet)             infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy        if (infoGain > bestInfoGain):       #compare this to the best gain so far            bestInfoGain = infoGain         #if better than current best, set to best            bestFeature = i    return bestFeature                      #returns an integerdef majorityCnt(classList):    classCount={}    for vote in classList:        if vote not in classCount.keys(): classCount[vote] = 0        classCount[vote] += 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createTree(dataSet,labels):    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len(classList):         return classList[0]#stop splitting when all of the classes are equal    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat]    myTree = {bestFeatLabel:{}}    del(labels[bestFeat])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)    return myTree                            #Tree Format#{FeatureLabel1: {Value1_1:{FeatureLabel2:{Value2_1:{},Value2_2:{},Value2_3:{}}},Value1_2:{},Value1_3:{}}}##验证def classify(inputTree,featLabels,testVec):    firstStr = inputTree.keys()[0]  #best label     secondDict = inputTree[firstStr]  #dict    featIndex = featLabels.index(firstStr) #best feature index    key = testVec[featIndex]                #testValue    if key not in secondDict:        return -1    valueOfFeat = secondDict[key]           #    if isinstance(valueOfFeat, dict):         classLabel = classify(valueOfFeat, featLabels, testVec)    else:         classLabel = valueOfFeat    return classLabelplt.figure(figsize(20,10))test_img=io.imread('E:\BaiduYunDownload\ML\project\\flower_test.png')plt.subplot(221)plt.imshow(test_img)mask_img=io.imread('E:\BaiduYunDownload\ML\project\\flower_mask.png')plt.subplot(222)plt.imshow(mask_img,cmap=plt.cm.gray)img_h,img_w,dim=test_img.shapetest_data=floor(test_img.reshape(img_w*img_h,dim)/32)mask_data=mask_img.reshape(img_w*img_h)n_data=test_data[mask_data==0]p_data=test_data[mask_data==255]n_data=np.hstack((n_data,zeros((n_data.shape[0],1))))p_data=np.hstack((p_data,ones((p_data.shape[0],1))))dataSet=np.vstack((n_data,p_data))labels=['R','G','B']print "Create Tree"myTree=createTree(dataSet.tolist(),labels)print "testing"result=np.arange(img_h*img_w)test=test_data.reshape((img_h*img_w,3))index=0lastresult=0labels=['R','G','B']for t in test:    classLabel=classify(myTree,labels,t)    if classLabel==0:        result[index]=0        lastresult=0    elif classLabel==1:        result[index]=255        lastresult=255    else:        result[index]=lastresult    index+=1result=result.reshape((img_h,img_w))plt.subplot(223)plt.imshow(result,cmap=plt.cm.gray)plt.subplotsprint "finished"

若是使用CART算法,则结果变为
这里写图片描述

0 0
原创粉丝点击