《MachineLearningInAction》: Plotting a Decision Tree
The code in 《MachineLearningInAction》 (Peter Harrington) has some small problems, so I rewrote all the code in the book and share it here.
Block Ⅰ
import matplotlib.pyplot as plt  # plotting functions
import matplotlib                # for rcParams, to style the plot window
Block Ⅱ
Global variables
decisionNode  # box style for a decision node
leafNode      # box style for a leaf node
arrow_args    # arrow style
Functions
retrieveTree  # hand-builds trees of several sizes, to simplify the problem and test the functions
getNumLeafs   # gets the number of leaves of a tree
getTreeDepth  # gets the depth of a tree, i.e. the number of decisionNode levels
plotNode      # draws a node; the nodeType parameter distinguishes decisionNode from leafNode
plotMidText   # annotates the key of each dict (the feature value) on the arrow
plotTree      # draws the decision tree recursively
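The two counting helpers recurse over the nested-dict tree representation used throughout the book (a decision node is a dict mapping a feature name to a dict of value → subtree-or-label). A minimal standalone sketch of that recursion, using the book's 'no surfacing' example tree:

```python
# Nested-dict tree: a decision node is {feature: {value: subtree_or_label}}.
tree = {'no surfacing': {0: 'no',
                         1: {'flippers': {0: 'no', 1: 'yes'}}}}

def getNumLeafs(dictTree):
    """Count leaves: every non-dict value is a class label, i.e. a leaf."""
    nNumLeaf = 0
    for value in dictTree.values():
        if isinstance(value, dict):
            nNumLeaf += getNumLeafs(value)   # recurse into a subtree
        else:
            nNumLeaf += 1                    # a class label: one leaf
    return nNumLeaf

def getTreeDepth(dictTree):
    """Depth = number of decisionNode levels along the longest path."""
    nMaxDepth = 0
    trunk = dictTree[list(dictTree.keys())[0]]  # skip the feature-name level
    for value in trunk.values():
        nCurDepth = 1 + getTreeDepth(value) if isinstance(value, dict) else 1
        nMaxDepth = max(nMaxDepth, nCurDepth)
    return nMaxDepth

print(getNumLeafs(tree))   # 3 leaves: 'no', 'no', 'yes'
print(getTreeDepth(tree))  # 2 decisionNode levels
```

These match the functions in the full listing below in behavior; `isinstance` is used here instead of the `type(...) == dict` comparison, which is the more idiomatic test.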
Block Ⅴ
Test code
Runs when __name__ == '__main__', i.e. when the module is executed as the main script.
(Figure: the rendered decision tree)
Key ideas:
1. The whole tree is generated recursively; when testing the code, start from a tree with a single decisionNode (obtained by calling retrieveTree(0)).
2. The plot axis is passed as a parameter so that everything is drawn on the same axes. Peter achieves the same effect by attaching an axis attribute to the plotTree function at call time, which is slightly more convoluted.
3. The whole figure is drawn inside the rectangle spanned by (0,0) and (1,1), with the first decisionNode centered at (0.5, 1). The vertical and horizontal spacings are fixed values, controlled by the number of decisionNodes and leafNodes respectively. I do not know whether this matches Peter's approach: his code was too dizzying for me to read, so I rewrote all of it.
4. plotTree first draws a decisionNode and the arrow between it and its parentNode; in particular, when the node's center coordinates equal the parent's coordinates, the library function draws no arrow.
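Under the layout described in point 3, the spacings and the child-node centers reduce to a few lines of arithmetic. A sketch of just that arithmetic, with no plotting (the example numbers assume a tree with 3 leaves and depth 2, like retrieveTree(1)):

```python
# Layout arithmetic: the canvas is the unit square, the root decisionNode
# sits at (0.5, 1), and the fixed spacings are the reciprocals of the
# leaf count and the tree depth.
numLeafs, depth = 3, 2            # assumed example tree: 3 leaves, depth 2
fBrchLen = 1 / numLeafs           # horizontal gap between sibling branches
fTrunkLen = 1 / depth             # vertical gap between decisionNode levels

def childCenters(tplPrntPt, nNumKey, fBrchLen, fTrunkLen):
    """Centers of a node's nNumKey children: spread symmetrically around
    the parent's x coordinate, one trunk length below it."""
    fMean = sum(range(nNumKey)) / nNumKey   # centers the fan of branches
    return [(tplPrntPt[0] + (i - fMean) * fBrchLen,
             tplPrntPt[1] - fTrunkLen)
            for i in range(nNumKey)]

# Two children of the root: x offsets of ±fBrchLen/2 around 0.5, at y = 0.5.
print(childCenters((0.5, 1), 2, fBrchLen, fTrunkLen))
```

This is the same (i - fMean) * fBrchLen offset that plotTree computes inside its loop, extracted here so it can be checked by hand.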
The treePlotter source code:
# -*- coding: utf-8 -*-
"""treePlotter.py
~~~~~~~~~~
A module with functions to plot a decision tree.
Created on Thu Mar 23 17:26:57 2017
Run on Python 3.6
@author: Luo Shaozhuo
refer to 'MachineLearninginAction'
"""
#==============================================================================
# import
#==============================================================================
import matplotlib.pyplot as plt
import matplotlib

#==============================================================================
# Global variables
#==============================================================================
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

#==============================================================================
# functions
#==============================================================================
def retrieveTree(i=0):
    """
    return a predefined tree
    ~~~~~~~~~~
    i: must be 0, 1 or 2; a larger i gives a taller tree
    ~~~~~~~~~~
    dictTree
    """
    listOfTrees = [{'no surfacing': {0: 'no', 1: 'yes'}},
                   {'no surfacing': {0: 'no',
                                     1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no',
                                     1: {'flippers': {0: {'head': {0: 'no',
                                                                   1: 'yes'}},
                                                      1: 'no'}}}}
                   ]
    return listOfTrees[i]

def getNumLeafs(dictTree):
    """
    return the number of leafs
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree
    ~~~~~~~~~~
    nNumLeaf: number of leafs
    """
    nNumLeaf = 0
    for key in dictTree.keys():
        if type(dictTree[key]) == dict:
            nNumLeaf += getNumLeafs(dictTree[key])
        else:
            nNumLeaf += 1
    return nNumLeaf

def getTreeDepth(dictTree):
    """
    return the tree depth
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree
    ~~~~~~~~~~
    nMaxDepth: tree depth
    """
    nMaxDepth = 0
    keys = list(dictTree.keys())[0]
    dictTrunk = dictTree[keys]
    for key in dictTrunk.keys():
        if type(dictTrunk[key]) == dict:
            nCurDepth = 1 + getTreeDepth(dictTrunk[key])
        else:
            nCurDepth = 1
        if nCurDepth > nMaxDepth:
            nMaxDepth = nCurDepth
    return nMaxDepth

def plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, nodeType):
    """
    plot a decision node or a leaf node depending on nodeType.
    ~~~~~~~~~~
    pltAxis: plot axis
    strNodeTxt: text in node box
    tplCntrPt: center coordinates of box
    tplPrntPt: starting coordinates of arrow
    nodeType: leafNode or decisionNode
    ~~~~~~~~~~
    N/A
    """
    pltAxis.annotate(strNodeTxt, xy=tplPrntPt, xycoords='axes fraction',
                     xytext=tplCntrPt, textcoords='axes fraction',
                     va="center", ha="center", bbox=nodeType,
                     arrowprops=arrow_args)

def plotMidText(pltAxis, cntrPt, parentPt, txtString):
    """
    add the feature value in the middle of the arrow
    ~~~~~~~~~~
    cntrPt: center coordinates of the child node
    parentPt: coordinates of the parent node
    txtString: feature value to annotate
    ~~~~~~~~~~
    N/A
    """
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    pltAxis.text(xMid, yMid, txtString)

def plotTree(dictTree, pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt,
             strNodeTxt):
    """
    plot tree recursively
    ~~~~~~~~~~
    dictTree: decision tree
    pltAxis: axis used for plotting
    fTrunkLen: difference of y coordinates between two decision nodes
    fBrchLen: difference of x coordinates between two leafs
    tplCntrPt: coordinates of the current node
    tplPrntPt: coordinates of the parent node
    strNodeTxt: text in node box
    ~~~~~~~~~~
    N/A
    """
    # plot root node
    plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, decisionNode)
    # plot branch nodes
    tplPrntPt = tplCntrPt
    nNumKey = len(dictTree.keys())
    fMean = sum([x for x in range(nNumKey)]) / nNumKey
    for i, key in enumerate(dictTree.keys()):
        tplCntrPt = (tplPrntPt[0] + (i - fMean) * fBrchLen,
                     tplPrntPt[1] - fTrunkLen)
        plotMidText(pltAxis, tplCntrPt, tplPrntPt, key)
        if type(dictTree[key]) == dict:
            strNodeTxt = list(dictTree[key].keys())[0]
            plotTree(dictTree[key][strNodeTxt], pltAxis, fTrunkLen, fBrchLen,
                     tplCntrPt, tplPrntPt, strNodeTxt)
        else:
            strNodeTxt = dictTree[key]
            plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, leafNode)

if __name__ == '__main__':
    dictTree = retrieveTree(2)
    matplotlib.rcParams['toolbar'] = 'none'
    pltAxis = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    fBrchLen = 1 / getNumLeafs(dictTree)
    fTrunkLen = 1 / getTreeDepth(dictTree)
    tplCntrPt = (0.5, 1)
    tplPrntPt = tplCntrPt
    strNodeTxt = list(dictTree.keys())[0]
    plotTree(dictTree[strNodeTxt], pltAxis, fTrunkLen, fBrchLen,
             tplCntrPt, tplPrntPt, strNodeTxt)
    plt.show()  # display the figure when run outside an interactive IDE