matPlotLib绘制决策树

来源:互联网 发布:java开发小游戏 编辑:程序博客网 时间:2024/05/29 07:00

上篇中,实现了创建决策树但并不直观,这里学习绘制决策树,便于直观理解。

Matplotlib提供了名为pylab的模块,其中包括了许多numpy和pyplot中常用的函数,方便用户快速进行计算和绘图,

可以用于IPython中的快速交互式使用。

Matplotlib中的快速绘图的函数库可以通过如下语句载入:

[python] view plain copy
print?
  1. import matplotlib.pyplot as plt  
import matplotlib.pyplot as plt
绘制树形图,我们需要定义树和叶的形态,还必须要知道有多少个叶节点和判断节点,还有树的层数,这样才能确定树的大小,绘制绘图区

首先注解绘制的树节点和叶节点以及箭头

[python] view plain copy
print?
  1. #定义文本框和箭头格式  
  2. decisionNode = dict(boxstyle=”sawtooth”, fc=“0.8”#定义判断节点形态  
  3. leafNode = dict(boxstyle=”round4”, fc=“0.8”#定义叶节点形态  
  4. arrow_args = dict(arrowstyle=”<-“#定义箭头  
  5.   
  6. #绘制带箭头的注解  
  7. #nodeTxt:节点的文字标注, centerPt:节点中心位置,  
  8. #parentPt:箭头起点位置(上一节点位置), nodeType:节点属性  
  9. def plotNode(nodeTxt, centerPt, parentPt, nodeType):  
  10.     createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords=’axes fraction’,  
  11.              xytext=centerPt, textcoords=’axes fraction’,  
  12.              va=”center”, ha=“center”, bbox=nodeType, arrowprops=arrow_args )  
#定义文本框和箭头格式 
decisionNode = dict(boxstyle=”sawtooth”, fc=”0.8”) #定义判断节点形态
leafNode = dict(boxstyle=”round4”, fc=”0.8”) #定义叶节点形态
arrow_args = dict(arrowstyle=”<-“) #定义箭头

#绘制带箭头的注解#nodeTxt:节点的文字标注, centerPt:节点中心位置,#parentPt:箭头起点位置(上一节点位置), nodeType:节点属性def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

然后得到叶节点的数目和树的层数

[python] view plain copy
print?
  1. #计算叶节点数  
  2. def getNumLeafs(myTree):  
  3.     numLeafs = 0  
  4.     firstStr = myTree.keys()[0]   
  5.     secondDict = myTree[firstStr]   
  6.     for key in secondDict.keys():  
  7.         if type(secondDict[key]).__name__==‘dict’:#是否是字典  
  8.             numLeafs += getNumLeafs(secondDict[key]) #递归调用getNumLeafs  
  9.         else:   numLeafs +=1 #如果是叶节点,则叶节点+1  
  10.     return numLeafs  
  11.   
  12. #计算数的层数  
  13. def getTreeDepth(myTree):  
  14.     maxDepth = 0  
  15.     firstStr = myTree.keys()[0]  
  16.     secondDict = myTree[firstStr]  
  17.     for key in secondDict.keys():  
  18.         if type(secondDict[key]).__name__==‘dict’:#是否是字典  
  19.             thisDepth = 1 + getTreeDepth(secondDict[key]) #如果是字典,则层数加1,再递归调用getTreeDepth  
  20.         else:   thisDepth = 1  
  21.         #得到最大层数  
  22.         if thisDepth > maxDepth:  
  23.             maxDepth = thisDepth  
  24.     return maxDepth  
#计算叶节点数def getNumLeafs(myTree):    numLeafs = 0    firstStr = myTree.keys()[0]     secondDict = myTree[firstStr]     for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':#是否是字典            numLeafs += getNumLeafs(secondDict[key]) #递归调用getNumLeafs        else:   numLeafs +=1 #如果是叶节点,则叶节点+1    return numLeafs
#计算数的层数def getTreeDepth(myTree): maxDepth = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#是否是字典 thisDepth = 1 + getTreeDepth(secondDict[key]) #如果是字典,则层数加1,再递归调用getTreeDepth else: thisDepth = 1 #得到最大层数 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth有了注解和计算树形图的位置的参数,就可以绘制树形图了

为了清晰简明,在父子节点之间加入文本标签信息

[python] view plain copy
print?
  1. #在父子节点间填充文本信息  
  2. #cntrPt:子节点位置, parentPt:父节点位置, txtString:标注内容  
  3. def plotMidText(cntrPt, parentPt, txtString):  
  4.     xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]  
  5.     yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]  
  6.     createPlot.ax1.text(xMid, yMid, txtString, va=”center”, ha=“center”, rotation=30)  
#在父子节点间填充文本信息
#cntrPt:子节点位置, parentPt:父节点位置, txtString:标注内容def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)然后绘制树形图;

[python] view plain copy
print?
  1. #绘制树形图  
  2. #myTree:树的字典, parentPt:父节点, nodeTxt:节点的文字标注  
  3. def plotTree(myTree, parentPt, nodeTxt):  
  4.     numLeafs = getNumLeafs(myTree)  #树叶节点数  
  5.     depth = getTreeDepth(myTree)    #树的层数  
  6.     firstStr = myTree.keys()[0]     #节点标签  
  7.     #计算当前节点的位置  
  8.     cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)  
  9.     plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息  
  10.     plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注解  
  11.     secondDict = myTree[firstStr]  
  12.     plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  
  13.     for key in secondDict.keys():  
  14.         if type(secondDict[key]).__name__==‘dict’:#判断是不是字典,  
  15.             plotTree(secondDict[key],cntrPt,str(key))        #递归绘制树形图  
  16.         else:   #如果是叶节点  
  17.             plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  
  18.             plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  
  19.             plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))  
  20.     plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD  
  21.   
  22. #创建绘图区  
  23. def createPlot(inTree):  
  24.     fig = plt.figure(1, facecolor=‘white’)  
  25.     fig.clf()  
  26.     axprops = dict(xticks=[], yticks=[])  
  27.     createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)      
  28.     plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度  
  29.     plotTree.totalD = float(getTreeDepth(inTree)) #树的深度  
  30.     plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;  
  31.     plotTree(inTree, (0.5,1.0), )  
  32.     plt.show()  
#绘制树形图
#myTree:树的字典, parentPt:父节点, nodeTxt:节点的文字标注def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) #树叶节点数 depth = getTreeDepth(myTree) #树的层数 firstStr = myTree.keys()[0] #节点标签 #计算当前节点的位置 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息 plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注解 secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#判断是不是字典, plotTree(secondDict[key],cntrPt,str(key)) #递归绘制树形图 else: #如果是叶节点 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#创建绘图区def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度 plotTree.totalD = float(getTreeDepth(inTree)) #树的深度 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), '') plt.show()其中createPlot()是主函数,创建绘图区,计算树形图的尺寸大小,它调用plotTree()等函数,plotTree()递归画出整个树形图。

加载之前创建了tree模块和这个treeplot模块,在命令提示符下输入

[python] view plain copy
print?
  1. >>> import treeplot  
  2. >>> import tree  
  3. >>> myDat,labels = tree.createDataSet()  
  4. >>> myTree = tree.createTree(myDat,labels)  
  5. >>> treeplot.createPlot(myTree)  
>>> import treeplot>>> import tree>>> myDat,labels = tree.createDataSet()>>> myTree = tree.createTree(myDat,labels)>>> treeplot.createPlot(myTree)
得到正确的树形图


用创建的tree模块和treeplot模块,使用决策树预测隐形眼镜类型;

在命令提示符下输入

[python] view plain copy
print?
  1. >>> import tree  
  2. >>> import treeplot  
  3. >>> fr = open(’lenses.txt’)  
  4. >>> lenses = [inst.strip().split(’\t’for inst in fr.readlines()]  
  5. >>> lensesLabels = [’age’,‘prescript’,‘astigmatic’,‘tearRate’]  
  6. >>> lensesTree = tree.createTree(lenses,lensesLabels)  
  7. >>> treeplot.createPlot(lensesTree)  
>>> import tree>>> import treeplot>>> fr = open('lenses.txt')>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]>>> lensesLabels = ['age','prescript','astigmatic','tearRate']>>> lensesTree = tree.createTree(lenses,lensesLabels)>>> treeplot.createPlot(lensesTree)
得到如下所示


如有不足,请指出,谢谢~~






原创粉丝点击