决策树算法学习笔记(三)-预测隐形眼镜类型

来源:互联网 发布:mcafree是什么软件 编辑:程序博客网 时间:2024/04/29 10:26
#coding=utf-8
# Plotting helpers for a decision tree stored as nested dicts, e.g.
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# A dict node maps its feature name to a dict of {feature value: subtree or
# leaf label}; any non-dict value is treated as a leaf.
import matplotlib.pyplot as plt

# Box and arrow styles passed to annotate(): decision (internal) nodes,
# leaf nodes, and the parent-child connector arrow.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def _rootFeature(myTree):
    """Return the first key of *myTree* (the feature tested at its root).

    ``list(d)[0]`` is used instead of ``d.keys()[0]`` because on Python 3
    ``keys()`` returns a view that cannot be indexed. Like the original
    expression, it raises IndexError for an empty dict.
    """
    return list(myTree)[0]


def getNumLeafs(myTree):
    """Return the number of leaves in *myTree*.

    Each non-dict value in the root's branch dict counts as one leaf; dict
    values are subtrees and are counted recursively.
    """
    numLeafs = 0
    branches = myTree[_rootFeature(myTree)]
    for child in branches.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels in *myTree*.

    A tree whose root branches all lead directly to leaves has depth 1;
    a root with no branches yields 0.
    """
    maxDepth = 0
    branches = myTree[_rootFeature(myTree)]
    for child in branches.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* halfway between *cntrPt* and *parentPt*.

    Used to label the edge to a child with the branch value leading to it.
    Draws on the shared axes ``createPlot.axl``.
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* beneath the point *parentPt*.

    Layout state is kept in function attributes that createPlot() sets
    before the first call: ``totalW`` (total leaf count), ``totalD`` (tree
    depth), and the running cursor ``xOff`` / ``yOff``. Each leaf advances
    ``xOff`` by ``1/totalW``; descending one level lowers ``yOff`` by
    ``1/totalD`` (restored on return).

    :param myTree: the (sub)tree to draw; must be a dict node.
    :param parentPt: (x, y) of the parent node, where the connector ends.
    :param nodeTxt: label for the edge from the parent (the branch value).
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = _rootFeature(myTree)
    # Centre this decision node horizontally over the span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    branches = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in branches.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            leafPt = (plotTree.xOff, plotTree.yOff)
            plotNode(child, leafPt, cntrPt, leafNode)
            plotMidText(leafPt, cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw *inTree* in matplotlib figure 1 and call ``plt.show()``.

    :param inTree: a dict-node decision tree (see retrieveTree() for samples).
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Shared axes used by plotNode/plotMidText. The attribute name 'axl' is
    # kept unchanged in case other code refers to createPlot.axl.
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf-width left of 0 so the first leaf lands at 0.5/totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw *nodeTxt* in a box at *centerPt*, with an arrow to *parentPt*.

    Both points are given in axes-fraction coordinates; *nodeType* is the
    bbox style dict (decisionNode or leafNode).
    """
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def retrieveTree(i):
    """Return sample tree number *i* (0 or 1) for trying out the plot code.

    Example::

        myTree = retrieveTree(0)
        getNumLeafs(myTree)   # -> 3
        getTreeDepth(myTree)  # -> 2
        createPlot(myTree)
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]
def fileReading(filename):
    """Load the contact-lens data set from a tab-separated text file.

    Each non-blank line becomes one row: the list of its tab-separated
    fields, with surrounding whitespace stripped. (Presumably four feature
    columns followed by the class label -- confirm against lenses.txt.)

    :param filename: path to the data file.
    :returns: (rows, feature labels) where the labels are
        ['age', 'prescript', 'astigmatic', 'tearRate'].
    """
    # 'with' guarantees the file is closed; blank lines (e.g. a trailing
    # empty line) are skipped instead of producing a bogus [''] row.
    with open(filename) as fr:
        lenses = [line.strip().split('\t') for line in fr if line.strip()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    return lenses, lensesLabels


# NOTE(review): createTree and treePlotter are not defined in this snippet;
# presumably they come from the decision-tree module (createTree) and an
# `import treePlotter` elsewhere in the file -- confirm.
dataSet, labels = fileReading('lenses.txt')
lensesTree = createTree(dataSet, labels)
print(lensesTree)
treePlotter.createPlot(lensesTree)


阅读全文
0 0
原创粉丝点击