机器学习实战—决策树(二)

来源:互联网 发布:prisma 喜好知乎 滤镜 编辑:程序博客网 时间:2024/05/02 01:20
# -*- coding: utf-8 -*-
# Draw a decision tree, stored as nested dicts, with matplotlib annotations.
#
# Tree format: {feature: {branch_value: subtree_or_leaf_label, ...}} where a
# subtree is another dict of the same shape and anything else is a leaf.
import ch
# NOTE(review): `ch` is a project-local module; presumably set_ch() configures
# a CJK-capable font for matplotlib -- confirm. Call order kept as in the
# original (before pyplot is imported).
ch.set_ch()
import matplotlib.pyplot as plt

# Box styles for the two node kinds and the arrow style used for edges.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node label and the arrow connecting it to its parent.

    The drawing goes to ``createPlot.ax1``, which must have been set by
    ``createPlot`` or ``createPlotTemp`` beforehand.

    Args:
        nodeTxt: text shown inside the box.
        centerPt: (x, y) of the box, in axes-fraction coordinates.
        parentPt: (x, y) of the other arrow end, in axes-fraction coordinates.
        nodeType: bbox property dict (``decisionNode`` or ``leafNode``).
    """
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt, xycoords='axes fraction',       # annotated point
        xytext=centerPt, textcoords='axes fraction',  # where the text is placed
        va="center", ha="center",
        bbox=nodeType,         # box style of the label
        arrowprops=arrow_args)  # arrow style


def createPlotTemp():
    """Show a demo figure with one decision node and one leaf node."""
    # The first argument is the figure identifier (a number or a name).
    fig = plt.figure("xihuan", facecolor='white')
    fig.clf()  # clear the figure
    # One subplot without a frame. plotNode draws on createPlot.ax1, so the
    # axes must be published there; createPlotTemp.ax1 is still set for any
    # caller that reads it.
    createPlotTemp.ax1 = createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode(u'决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode(u'叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


def getNumLeafs(myTree):
    """Return the number of leaf nodes of a nested-dict decision tree.

    Only the first key of ``myTree`` is treated as the root feature.
    Raises IndexError if ``myTree`` is empty.
    """
    # list(d)[0] works on both Python 2 and 3 (d.keys()[0] is Python 2 only).
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth of the tree, counted in decision levels.

    A root whose branches are all leaves has depth 1. Only the first key of
    ``myTree`` is treated as the root feature. Raises IndexError if
    ``myTree`` is empty.
    """
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):
    """Return the i-th of two hard-coded sample decision trees.

    Raises IndexError if ``i`` is outside the list of samples.
    """
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


def plotMidText(centrPt, parentPt, txtString):
    """Write ``txtString`` at the midpoint between ``centrPt`` and ``parentPt``.

    The text is drawn on ``createPlot.ax1``.
    """
    xMid = (parentPt[0] - centrPt[0]) / 2.0 + centrPt[0]
    yMid = (parentPt[1] - centrPt[1]) / 2.0 + centrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def createPlot(inTree):
    """Plot the decision tree ``inTree`` and show the figure.

    Sets up the axes on ``createPlot.ax1`` and the layout state that
    ``plotTree`` keeps in its function attributes (``totalW``, ``totalD``,
    ``xoff``, ``yoff``), then draws the tree from the root.
    """
    fig = plt.figure("xihuan", facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])  # no tick marks
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth
    # Start half a leaf slot to the left of the axes so that each leaf
    # advances xoff by one slot (1/totalW) before it is drawn, and the root
    # is centred at x = 0.5, i.e. exactly on the parent point passed below.
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def plotTree(myTree, parentPt, nodeText):
    """Recursively draw ``myTree`` below ``parentPt``.

    Relies on ``plotTree.totalW``, ``plotTree.totalD``, ``plotTree.xoff`` and
    ``plotTree.yoff`` having been initialised by ``createPlot``.

    Args:
        myTree: nested-dict (sub)tree to draw.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeText: label written on the edge coming from the parent
            (empty for the root).
    """
    numleafs = getNumLeafs(myTree)
    firstStr = list(myTree)[0]  # feature tested at this node
    # Centre this node horizontally over the leaves of its subtree.
    centrPt = (plotTree.xoff + (1.0 + float(numleafs)) / 2.0 / plotTree.totalW,
               plotTree.yoff)
    plotMidText(centrPt, parentPt, nodeText)  # edge label; empty for the root
    plotNode(firstStr, centrPt, parentPt, decisionNode)  # feature box
    secondDict = myTree[firstStr]
    # Children are drawn one level lower.
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, centrPt, str(key))
        else:
            # Leaf: advance one slot to the right and draw it directly.
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xoff, plotTree.yoff), centrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), centrPt, str(key))
    # Returning to the caller's level: restore its y position.
    plotTree.yoff = plotTree.yoff + 1.0 / plotTree.totalD


mytree = retrieveTree(1)
createPlot(mytree)

0 0
原创粉丝点击