《MachineLearningInAction》之绘制决策树

来源:互联网 发布:python 自动登录签到 编辑:程序博客网 时间:2024/06/07 11:36

《MachineLearningInAction》(Peter Harrington)中的代码有点小问题,我重写了全书所有代码,分享于此。


Block Ⅰ

import matplotlib.pyplot as plt #用于调用绘图import matplotlib #用于调用rcParams属性,设置绘图窗口风格

Block Ⅱ

定义全局变量
decisionNode #决策节边框样式
leafNode #叶子节点边框样式
arrow_args #箭头样式


Block Ⅲ

定义函数

retrieveTree #为简化问题及做函数测试,手动生成大小不一的Tree
getNumLeafs #获取Tree的叶子数
getTreeDepth #获取Tree的深度,即decisionNode的个数
plotNode #绘制节点,通过nodeType参数区分decisionNode及leafNode
plotMidText #annotate每一个dict的key
plotTree #迭代绘制决策树

Block Ⅴ

测试代码

__name__ == '__main__'时,即作为主模块调用时执行


效果图



关键思路:

1、迭代生成整棵树,代码测试时从只有一个decisionNode开始(通过调用retrieveTree(0)获得)。

2、通过参数传递plot axis来在同一个轴上绘图。Peter通过在实时调用时给plotTree函数增加axis属性达到同样效果,稍显复杂。

3、整张图绘制在(0,0),(1,1)围成的矩形区域内绘制,第一个decisionNode中心位于(0.5,1),通过decisionNode及LeafNode的数目控制纵向及横向间距,此二者皆为定值。这一点不知道是否与Peter的思路一致,因为他的代码太令我眼花缭乱,没看,我全部重写的。

4、plotTree先绘制decisionNode及其于parentNode之间的箭头,特别地当中心坐标与父节点坐标相等时,系统函数不绘制箭头。


treePlotter源码如下:

# -*- coding: utf-8 -*-"""treePlotter.py~~~~~~~~~~A module with functions to plot decision tree.Created on Thu Mar 23 17:26:57 2017Run on Python 3.6@author: Luo Shaozhuorefer to 'MachineLearninginAction'"""#==============================================================================# import#==============================================================================import matplotlib.pyplot as pltimport matplotlib#==============================================================================# Global variables#==============================================================================decisionNode = dict(boxstyle="sawtooth", fc="0.8")leafNode = dict(boxstyle="round4", fc="0.8")arrow_args = dict(arrowstyle="<-") #==============================================================================# functions#==============================================================================def retrieveTree(i=0):    """    return a predefined tree    ~~~~~~~~~~    i: must be 0 or 1. 1 for a taller tree    ~~~~~~~~~~    dictTree    """    listOfTrees =[{'no surfacing': {0: 'no', 1: 'yes'}},                  {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}                  ]    return listOfTrees[i]def getNumLeafs(dictTree):    """    return the number of leafs    ~~~~~~~~~~    dictTree: a dictonary dipicting a decidion tree    ~~~~~~~~~~    nNumLeaf: number of leafs    """    nNumLeaf = 0    for key in dictTree.keys():        if type(dictTree[key]) == dict:            nNumLeaf += getNumLeafs(dictTree[key])        else:   nNumLeaf +=1    return nNumLeafdef getTreeDepth(dictTree):    """    return the tree depth    ~~~~~~~~~~    dictTree: a dictonary dipicting a decidion tree    ~~~~~~~~~~    nMaxDepth: tree depth    """    nMaxDepth = 0    keys = list(dictTree.keys())[0]    dictTrunk = dictTree[keys]    for key in dictTrunk.keys():        if type(dictTrunk[key]) == dict:            nCurDepth = 1 + getTreeDepth(dictTrunk[key])        else:            nCurDepth = 1        if nCurDepth > nMaxDepth:            nMaxDepth = nCurDepth    return nMaxDepthdef plotNode(pltAxis,strNodeTxt, tplCntrPt, tplPrntPt, nodeType):    """    plot a decision node or a leaf node depend on nodeType.    ~~~~~~~~~~    pltAxis: plot axis    strNodeTxt: text in node box    tplCntrPt: center coordinates of box    tplPrntPt: starting coordinates of arrow    nodeType: leafNode or decisionNode    ~~~~~~~~~~    N/A    """    pltAxis.annotate(strNodeTxt, xy=tplPrntPt, xycoords='axes fraction',    xytext=tplCntrPt, textcoords='axes fraction',    va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def plotMidText(pltAxis, cntrPt, parentPt, txtString):    """    add feature value in the middle of arrow    ~~~~~~~~~~    cntrPt:    parentPt:    txtString:    ~~~~~~~~~~    N/A    """    xMid = (parentPt[0]+cntrPt[0])/2.0    yMid = (parentPt[1]+cntrPt[1])/2.0    pltAxis.text(xMid, yMid, txtString)def plotTree(dictTree, pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt, strNodeTxt):    """    plot tree recursivly    ~~~~~~~~~~    dictTree: decision tree    pltAxis: axis used for plotting    fTrunkLen: difference of y coordinates between two decision nodes    fBrchLen: difference of y coordinates between two leafs    tplCntrPt: coordinates of parent node      strNodeTxt: text in node box    ~~~~~~~~~~    N/A    """    #plot root node    plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, decisionNode)    #plot branch node    tplPrntPt = tplCntrPt    nNumKey = len(dictTree.keys())    fMean = sum([x for x in range(nNumKey)])/nNumKey    for i,key in enumerate(dictTree.keys()):        tplCntrPt = (tplPrntPt[0]+(i-fMean)*fBrchLen, tplPrntPt[1]-fTrunkLen)        plotMidText(pltAxis, tplCntrPt, tplPrntPt, key)        if type(dictTree[key]) == dict:            strNodeTxt = list(dictTree[key].keys())[0]            plotTree(dictTree[key][strNodeTxt], pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt,strNodeTxt)        else:            strNodeTxt = dictTree[key]            plotNode(pltAxis,strNodeTxt, tplCntrPt, tplPrntPt, leafNode)if __name__ == '__main__':    dictTree = retrieveTree(2)    matplotlib.rcParams['toolbar'] = 'none'    pltAxis = plt.subplot(111, frameon=False,xticks=[], yticks=[])    fBrchLen = 1/getNumLeafs(dictTree)    fTrunkLen= 1/getTreeDepth(dictTree)    tplCntrPt = (0.5,1)    tplPrntPt = tplCntrPt    strNodeTxt = list(dictTree.keys())[0]    plotTree(dictTree[strNodeTxt], pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt,strNodeTxt)


0 0
原创粉丝点击