【机器学习】决策树

来源:互联网 发布:mac配套手绘板 编辑:程序博客网 时间:2024/06/04 20:39

参考资料:机器学习实战

一、决策树是什么呢?
举个例子:
小明来找你出去玩耍,你想不想去呢?
1)不想–结果:不去
2)想–作业没做完–结果:不去
3)想–作业做完了–结果:去
以图形化的方式呈现,就如下图所示
是否出去玩耍的决策树
这就是一个决策树:
每个节点根据一个条件划分,每个分叉代表着该条件下的值,每个叶子代表一个结论。
更为一般的:
每个节点根据一个属性划分,每个分叉代表一个值,每个叶子代表一个分类。

二、决策树的构造
根据不同的属性划分,可以产生不同的决策树(分叉不同,结果不同)。那么,怎样构造最合理的决策树呢?这里我们引进一个概念:信息增益。
为了了解信息增益的概念,我们首先要知道信息的定义:

信息:l(x)=log2p(xi)
熵:H=i=1np(xi)log2p(xi)
信息增益:G=i=1nl(x)H

决策树的构造就是选择信息增益最大的属性作为分裂属性。

from numpy import *import operatorfrom math import log#计算香农熵def calcShannonEnt(dataSet):    numEntries=len(dataSet)    labelCounts={}    for featVec in dataSet:       currentLabel=featVec[-1]       if currentLabel not in labelCounts.keys():          labelCounts[currentLabel]=0       labelCounts[currentLabel]+=1    shannonEnt=0.0    for key in labelCounts:        prob=float(labelCounts[key])/numEntries        shannonEnt -= prob*log(prob,2)    return shannonEnt#生成数据集def createDataSet():    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]    labels=['no surfacing','flippers']    return dataSet,labels#分割数据集#三个参数分别是数据集,按第几个属性继续划分,该属性为value则为这一类,例如dataSet,1,0#按照上面的例子返回的是第[1]个参数为0的,则为[1,0,'no'],数据集中的第[2]个元素,如果是dataSet,1,1,则返回的是其余四个构成的listdef splitDataSet(dataSet,axis,value):    retDataSet=[]    for featVec in dataSet:        if featVec[axis]==value:           reduceFeatVec=featVec[:axis]            reduceFeatVec.extend(featVec[axis+1:])           retDataSet.append(reduceFeatVec)    return retDataSet#选择最好的属性来分割def chooseBestFeatureToSplit(dataSet):    numFeatures=len(dataSet[0])-1#所有属性的个数    baseEntropy=calcShannonEnt(dataSet)#计算原数据集的熵    bestInfoGain=0.0    bestFeature=-1    for i in range(numFeatures):        featlist=[example[i] for example in dataSet]#获取第i个属性的所有的值        uniqueVals=set(featlist)#转化为集合(实质是为了去重)        newEntropy=0.0        for value in uniqueVals:#按照第i属性的不同的值来划分集合            subDataSet=splitDataSet(dataSet,i,value)            prob=len(subDataSet)/float(len(dataSet))            newEntropy +=prob*calcShannonEnt(subDataSet)#计算按照第i个属性分割的熵        infoGain=baseEntropy-newEntropy#计算信息增益        if(infoGain>bestInfoGain):#求最大的信息增益,并记录下对应的属性的位置           bestInfoGain=infoGain           bestFeature=i    return bestFeature#计算“票数”,很多情况下分割出来的分类并不唯一,这时候需要看那种分类的概率更大,“票数”更多    def majorityCnt(classList):    classCount={}    for vote in classList:        if vote not in classCount.keys():            classCount[vote]=0        classCount[vote] +=1    sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)    return sortedClassCount[0][0]   #建树def createTree(dataSet,labels):    classList=[example[-1] for example in dataSet]    if classList.count(classList[0])==len(classList):       return classList[0]    if len(dataSet[0])==1:       return majorityCnt(classList)    bestFeat=chooseBestFeatureToSplit(dataSet)#选出最好的分割属性    bestFeatLabel=labels[bestFeat]    myTree={bestFeatLabel:{}}    del(labels[bestFeat])    featValues=[example[bestFeat] for example in dataSet]    uniqueVals=set(featValues)    for value in uniqueVals:        subLabels=labels[:]        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)    return myTreeimport matplotlib.pyplot as plt#画树decisionNode=dict(boxstyle="sawtooth",fc="0.8")leafNode=dict(boxstyle="round4",fc="0.8")arrow_args=dict(arrowstyle="<-")def plotNode(nodeTxt,centerPt,parentPt,nodeType):    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',    xytext=centerPt,textcoords='axes fraction',    va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)#def createPlot():#    fig=plt.figure(1,facecolor='white')#    fig.clf()#    createPlot.ax1=plt.subplot(111,frameon=False)#    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)#    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)#    plt.show()#计算叶子数   def  getNumLeafs(myTree):     numLeafs=0     firstStr=myTree.keys()[0]     secondDict=myTree[firstStr]     for key in secondDict.keys():         if type(secondDict[key]).__name__=='dict':            numLeafs +=getNumLeafs(secondDict[key])         else:            numLeafs +=1     return numLeafs#计算树高    def getTreeDepth(myTree):    maxDepth=0    firstStr=myTree.keys()[0]    secondDict=myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            thisDepth =1+getTreeDepth(secondDict[key])        else:            thisDepth =1        if thisDepth>maxDepth:           maxDepth=thisDepth    return maxDepth#检索树,就是之前createTree建出的树    def retrieveTree(i):    listOftrees=[{'no surfacing':{0: 'no',1:{'flippers':{0:'no',1:'yes'}}}},    {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]    return listOftrees[i]def plotMidtext(cntrPt,parentPt,txtString):    xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]    yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]    createPlot.ax1.text(xMid,yMid,txtString)def plotTree(myTree,parentPt,nodeTxt):    numLeafs=getNumLeafs(myTree)    depth=getTreeDepth(myTree)    firstStr=myTree.keys()[0]    cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)    plotMidtext(cntrPt,parentPt,nodeTxt)    plotNode(firstStr,cntrPt,parentPt,decisionNode)    secondDict=myTree[firstStr]    plotTree.yOff=plotTree.yOff-1.0/plotTree.totalD    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':           plotTree(secondDict[key],cntrPt,str(key))        else:           plotTree.xOff=plotTree.xOff+1.0/plotTree.totalW                   plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)           plotMidtext((plotTree.xOff,plotTree.yOff),cntrPt,str(key))    plotTree.yOff=plotTree.yOff+1.0/plotTree.totalDdef createPlot(inTree):    fig=plt.figure(1,facecolor='white')    fig.clf()    axprops=dict(xticks=[],yticks=[])    createPlot.ax1=plt.subplot(111,frameon=False,**axprops)    plotTree.totalW=float(getNumLeafs(inTree))    plotTree.totalD=float(getTreeDepth(inTree))    plotTree.xOff= -0.5/plotTree.totalW    plotTree.yOff= 1.0    plotTree(inTree,(0.5,1.0),'')    plt.show()#分类    def classify(inputTree,featLabels,testVec):    firstStr=inputTree.keys()[0]    secondDict=inputTree[firstStr]    featIndex=featLabels.index(firstStr)    for key in secondDict.keys():        if testVec[featIndex]==key:           if type(secondDict[key]).__name__=='dict':               classLabel=classify(secondDict[key],featLabels,testVec)           else:               classLabel=secondDict[key]                  return classLabel    #保存树    def storeTree(inputTree,filename):    import pickle    fw=open(filename,'w')    pickle.dump(inputTree,fw)    fw.close()#获取树    def grabTree(filename):    import pickle    fr=open(filename)    return pickle.load(fr)

另一种分割方式是利用基尼不纯度,简单来说就是从一个数据集中随机选取子项,度量其被错误分类到其他分组里的概率。
集合T包含N个类别的记录,那么其Gini指标就是类pi类别i出现的频率:
gini(T)=1i=1np2i
如果集合T分成m部分N1,N2,...Nm,那么这个分割的Gini就是:
ginisplit(T)=N1Ngini(T1)+...+NmNgini(Tm)

决策数的构造就是选择具有最小Ginisplit的属性为分裂属性(对于每个属性都要遍历所有可能的分隔方法)
这里就不介绍了。


0 0
原创粉丝点击