matPlotLib绘制决策树
来源:互联网 发布:java开发小游戏 编辑:程序博客网 时间:2024/05/29 07:00
上篇中,实现了创建决策树但并不直观,这里学习绘制决策树,便于直观理解。
Matplotlib提供了名为pylab的模块,其中包括了许多numpy和pyplot中常用的函数,方便用户快速进行计算和绘图,
可以用于IPython中的快速交互式使用。
Matplotlib中的快速绘图的函数库可以通过如下语句载入:
- import matplotlib.pyplot as plt
import matplotlib.pyplot as plt绘制树形图,我们需要定义树和叶的形态,还必须要知道有多少个叶节点和判断节点,还有树的层数,这样才能确定树的大小,绘制绘图区
首先注解绘制的树节点和叶节点以及箭头
- #定义文本框和箭头格式
- decisionNode = dict(boxstyle=”sawtooth”, fc=“0.8”) #定义判断节点形态
- leafNode = dict(boxstyle=”round4”, fc=“0.8”) #定义叶节点形态
- arrow_args = dict(arrowstyle=”<-“) #定义箭头
- #绘制带箭头的注解
- #nodeTxt:节点的文字标注, centerPt:节点中心位置,
- #parentPt:箭头起点位置(上一节点位置), nodeType:节点属性
- def plotNode(nodeTxt, centerPt, parentPt, nodeType):
- createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=’axes fraction’,
- xytext=centerPt, textcoords=’axes fraction’,
- va=”center”, ha=“center”, bbox=nodeType, arrowprops=arrow_args )
#定义文本框和箭头格式#绘制带箭头的注解#nodeTxt:节点的文字标注, centerPt:节点中心位置,#parentPt:箭头起点位置(上一节点位置), nodeType:节点属性def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )
decisionNode = dict(boxstyle=”sawtooth”, fc=”0.8”) #定义判断节点形态
leafNode = dict(boxstyle=”round4”, fc=”0.8”) #定义叶节点形态
arrow_args = dict(arrowstyle=”<-“) #定义箭头
然后得到叶节点的数目和树的层数
- #计算叶节点数
- def getNumLeafs(myTree):
- numLeafs = 0
- firstStr = myTree.keys()[0]
- secondDict = myTree[firstStr]
- for key in secondDict.keys():
- if type(secondDict[key]).__name__==‘dict’:#是否是字典
- numLeafs += getNumLeafs(secondDict[key]) #递归调用getNumLeafs
- else: numLeafs +=1 #如果是叶节点,则叶节点+1
- return numLeafs
- #计算数的层数
- def getTreeDepth(myTree):
- maxDepth = 0
- firstStr = myTree.keys()[0]
- secondDict = myTree[firstStr]
- for key in secondDict.keys():
- if type(secondDict[key]).__name__==‘dict’:#是否是字典
- thisDepth = 1 + getTreeDepth(secondDict[key]) #如果是字典,则层数加1,再递归调用getTreeDepth
- else: thisDepth = 1
- #得到最大层数
- if thisDepth > maxDepth:
- maxDepth = thisDepth
- return maxDepth
#计算叶节点数def getNumLeafs(myTree): numLeafs = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#是否是字典 numLeafs += getNumLeafs(secondDict[key]) #递归调用getNumLeafs else: numLeafs +=1 #如果是叶节点,则叶节点+1 return numLeafs#计算数的层数def getTreeDepth(myTree): maxDepth = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#是否是字典 thisDepth = 1 + getTreeDepth(secondDict[key]) #如果是字典,则层数加1,再递归调用getTreeDepth else: thisDepth = 1 #得到最大层数 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth有了注解和计算树形图的位置的参数,就可以绘制树形图了
为了清晰简明,在父子节点之间加入文本标签信息
- #在父子节点间填充文本信息
- #cntrPt:子节点位置, parentPt:父节点位置, txtString:标注内容
- def plotMidText(cntrPt, parentPt, txtString):
- xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
- yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
- createPlot.ax1.text(xMid, yMid, txtString, va=”center”, ha=“center”, rotation=30)
#在父子节点间填充文本信息#cntrPt:子节点位置, parentPt:父节点位置, txtString:标注内容def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)然后绘制树形图;
- #绘制树形图
- #myTree:树的字典, parentPt:父节点, nodeTxt:节点的文字标注
- def plotTree(myTree, parentPt, nodeTxt):
- numLeafs = getNumLeafs(myTree) #树叶节点数
- depth = getTreeDepth(myTree) #树的层数
- firstStr = myTree.keys()[0] #节点标签
- #计算当前节点的位置
- cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
- plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息
- plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注解
- secondDict = myTree[firstStr]
- plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
- for key in secondDict.keys():
- if type(secondDict[key]).__name__==‘dict’:#判断是不是字典,
- plotTree(secondDict[key],cntrPt,str(key)) #递归绘制树形图
- else: #如果是叶节点
- plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
- plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
- plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
- plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
- #创建绘图区
- def createPlot(inTree):
- fig = plt.figure(1, facecolor=‘white’)
- fig.clf()
- axprops = dict(xticks=[], yticks=[])
- createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
- plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度
- plotTree.totalD = float(getTreeDepth(inTree)) #树的深度
- plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
- plotTree(inTree, (0.5,1.0), ”)
- plt.show()
#绘制树形图#myTree:树的字典, parentPt:父节点, nodeTxt:节点的文字标注def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) #树叶节点数 depth = getTreeDepth(myTree) #树的层数 firstStr = myTree.keys()[0] #节点标签 #计算当前节点的位置 cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) #在父子节点间填充文本信息 plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制带箭头的注解 secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#判断是不是字典, plotTree(secondDict[key],cntrPt,str(key)) #递归绘制树形图 else: #如果是叶节点 plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#创建绘图区def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度 plotTree.totalD = float(getTreeDepth(inTree)) #树的深度 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), '') plt.show()其中createPlot()是主函数,创建绘图区,计算树形图的尺寸大小,它调用plotTree()等函数,plotTree()递归画出整个树形图。
加载之前创建了tree模块和这个treeplot模块,在命令提示符下输入
- >>> import treeplot
- >>> import tree
- >>> myDat,labels = tree.createDataSet()
- >>> myTree = tree.createTree(myDat,labels)
- >>> treeplot.createPlot(myTree)
>>> import treeplot>>> import tree>>> myDat,labels = tree.createDataSet()>>> myTree = tree.createTree(myDat,labels)>>> treeplot.createPlot(myTree)得到正确的树形图
用创建的tree模块和treeplot模块,使用决策树预测隐形眼镜类型;
在命令提示符下输入
- >>> import tree
- >>> import treeplot
- >>> fr = open(’lenses.txt’)
- >>> lenses = [inst.strip().split(’\t’) for inst in fr.readlines()]
- >>> lensesLabels = [’age’,‘prescript’,‘astigmatic’,‘tearRate’]
- >>> lensesTree = tree.createTree(lenses,lensesLabels)
- >>> treeplot.createPlot(lensesTree)
>>> import tree>>> import treeplot>>> fr = open('lenses.txt')>>> lenses = [inst.strip().split('\t') for inst in fr.readlines()]>>> lensesLabels = ['age','prescript','astigmatic','tearRate']>>> lensesTree = tree.createTree(lenses,lensesLabels)>>> treeplot.createPlot(lensesTree)得到如下所示
如有不足,请指出,谢谢~~
阅读全文
0 0
- matPlotLib绘制决策树
- 机器学习-Matplotlib绘制决策树
- python —— Matplotlib模块(绘制决策树)
- 数据可视化matplotlib(03) 绘制决策树
- 第三章 决策树 3.2 使用 Matplotlib 注解绘制树形图
- 机器学习实战python版决策树以及Matplotlib注解绘制决策树
- 机器学习实战python版决策树以及Matplotlib注解绘制决策树
- 绘制决策树
- 决策树03——使用matplotlib绘制树形图并测试算法
- matplotlib-绘制精美图表
- matplotlib绘制简单图
- matplotlib 散点图绘制
- 使用matplotlib绘制图表
- matplotlib 绘制三维图
- matplotlib绘制决策边界
- matplotlib绘制基本图形
- matplotlib绘制K线图
- 使用matplotlib绘制散点图
- 灰度比特平面
- python数字字符串左侧补零
- c语言算法
- CCF CSP 201612-1 中间数
- 机器学习-朴素贝叶斯分类代码详解
- matPlotLib绘制决策树
- qt连接sql数据库
- 通信原理作业问题
- 【Scikit-Learn 中文文档】集成方法
- Python内置函数-reduce()函数
- Java.math.BigDecimal.stripTrailingZeros()方法实例
- bzoj1051 受欢迎的牛 结论+tarjan
- 函数 toPlainString() 和 toString()
- Python学习笔记----高阶函数