Machine Learning in Action Study Notes (2): Decision Trees


trees.py

# -*- coding: utf-8 -*-
from math import log
import operator

def calcShannonEnt(dataSet):  # compute the Shannon entropy of the data set
    numEntries = len(dataSet)
    labelsCounts = {}  # counts for each class label
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the class label is the last element of each record
        if currentLabel not in labelsCounts.keys():  # add the key if it is not in the dict yet
            labelsCounts[currentLabel] = 0
        labelsCounts[currentLabel] += 1  # count occurrences of each label
    shannonEnt = 0.0
    for key in labelsCounts:
        prob = float(labelsCounts[key]) / numEntries  # frequency of this label
        shannonEnt -= prob * log(prob, 2)  # Shannon entropy formula
    return shannonEnt

def createDataSet():  # toy data set
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # features are encoded as 0/1
    return dataSet, labels

def splitDataSet(dataSet, axis, value):  # data set, feature index (axis), feature value
    retDataSet = []  # new list for the result
    for featVec in dataSet:  # iterate over the samples
        if featVec[axis] == value:  # if the sample matches, add it (minus this feature) to retDataSet
            reducedFeatVec = featVec[:axis]  # slice out the part before the feature
            reducedFeatVec.extend(featVec[axis + 1:])  # append the part after the feature
            retDataSet.append(reducedFeatVec)  # append the reduced record
    return retDataSet

def chooseBestFeatureToSplit(dataSet):  # choose the feature that best splits the data set
    numFeature = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1  # best gain seen so far and the corresponding feature
    for i in range(numFeature):  # iterate over the features
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # unique values of this feature
        newEntropy = 0.0
        for value in uniqueVals:  # iterate over all unique values of the current feature
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  # entropy after splitting on this feature
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:  # keep the feature with the highest information gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index of the best split

def majorityCnt(classList):  # majority vote over class labels
    classCount = {}  # stores how often each class label occurs
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)  # sort by count, descending
    return sortedClassCount[0][0]  # return the most frequent class label

def createTree(dataSet, labels):  # data set and list of feature labels
    classList = [example[-1] for example in dataSet]  # class labels of the data set
    if classList.count(classList[0]) == len(classList):  # stop condition 1: all class labels are identical
        return classList[0]  # return that label
    if len(dataSet[0]) == 1:  # stop condition 2: no features left to split on, return the majority label
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # build the tree: the best feature goes in bestFeat
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # nested dict that stores the whole tree
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:  # recurse on every value of the best feature
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):  # walk the tree to classify testVec
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map the feature name back to its column index
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':  # internal node: keep descending
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]  # leaf node: this is the prediction
    return classLabel

def storeTree(inputTree, filename):  # serialize the tree with pickle
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):  # load a pickled tree from disk
    import pickle
    fr = open(filename)
    return pickle.load(fr)
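
A quick interactive check of trees.py (a minimal sketch, run in the same Python 2.7 environment as the tests below; labels[:] is passed so createTree does not consume the original label list):

>>> import trees
>>> myDat, labels = trees.createDataSet()
>>> trees.calcShannonEnt(myDat)
0.9709505944546686
>>> myTree = trees.createTree(myDat, labels[:])
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree, labels, [1, 0])
'no'
>>> trees.classify(myTree, labels, [1, 1])
'yes'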

treePlotter.py

# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # text-box and arrow styles
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # draw an annotation with an arrow
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot():  # draw two example nodes (this version is overridden by createPlot(inTree) below)
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('yejuecedian', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('yejiedian', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def getNumLeafs(myTree):  # number of leaf nodes; a first run raised "type object 'str' has no attribute '_name_'", fixed by writing __name__
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict child means the node is a decision node
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # depth of the tree; an earlier run raised "local variable 'firstStr' referenced before assignment"
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):  # canned trees for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

def plotMidText(cntrPt, parentPt, txtString):  # label the edge between parent and child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):  # recursively plot the tree
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):  # plot a whole tree
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # width of the tree
    plotTree.totalD = float(getTreeDepth(inTree))  # depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
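
The two modules can also be combined to plot a tree learned from the toy data rather than one of the canned trees; a minimal sketch, assuming trees.py and treePlotter.py are both importable:

>>> import trees, treePlotter
>>> myDat, labels = trees.createDataSet()
>>> treePlotter.createPlot(trees.createTree(myDat, labels[:]))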


             
    

Command-line test

>>> import treePlotter
>>> reload(treePlotter)
<module 'treePlotter' from 'D:\python2.7\treePlotter.pyc'>
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)



>>> import treePlotter
>>> reload(treePlotter)
<module 'treePlotter' from 'D:\python2.7\treePlotter.pyc'>
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree['no surfacing'][3]='maybe'
>>> treePlotter.createPlot(myTree)
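
storeTree and grabTree are not exercised in the tests above; a minimal sketch of persisting a classifier with pickle (the filename 'classifierStorage.txt' is just an example):

>>> import trees, treePlotter
>>> myTree = treePlotter.retrieveTree(0)
>>> trees.storeTree(myTree, 'classifierStorage.txt')
>>> trees.grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}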




