写程序学ML:决策树算法原理及实现(三)

来源:互联网 发布:知乎 邮箱注册 编辑:程序博客网 时间:2024/05/20 11:35

[题外话]近期申请了一个微信公众号:平凡程式人生。有兴趣的朋友可以关注,那里将会涉及更多更新机器学习、OpenCL+OpenCV以及图像处理方面的文章。

2.2   决策树的绘制

为了更好地发挥决策树易于理解的优点,我们使用Matplotlib将创建的决策树绘制出来。此处调用函数createPlot()对决策树进行绘制。

实现过程如下:

创建模块DecisionTreePlotter及其存储文件DecisionTreePlotter.py;

 调用函数plt.figure()创建一个figure;

调用函数plt.subplot()在figure中创建一个子图;

调用函数getNumLeafs()获得决策树的叶子树,即树的宽度;

调用函数getTreeDepth()获得决策树的深度;

调用函数plotTree()绘制整棵决策树,最后显示出来。

具体代码如下:

#创建figure并绘制树inTreedef createPlot(inTree):    #Matplotlib 里的常用类的包含关系为 Figure -> Axes -> (Line2D, Text, etc.)    #一个Figure对象可以包含多个子图(Axes),在matplotlib中用Axes对象表示一个绘图区域,可以理解为子图。    fig = plt.figure(1, facecolor = 'white') #定义一个figure对象,背景色设置为全白    fig.clf() #清楚figure中的内容    axprops = dict(xticks = [], yticks = [])     createPlot.ax1 = plt.subplot(111, frameon = False, **axprops) #在图表fig中创建一个子图ax1    plotTree.totalW = float(getNumLeafs(inTree)) #获取样本树的叶子结点数目作为plotTree的宽度    plotTree.totalD = float(getTreeDepth(inTree)) #获取样本树的深度作为plotTree的深度    plotTree.xOff = -0.5 / plotTree.totalW;    plotTree.yOff = 1.0    plotTree(inTree, (0.5, 1.0), '') #依次绘制整棵决策树    plt.show()


函数getNumLeafs

该函数定义为:def getNumLeafs(myTree)

用来获取当前树中的叶子结点数目;

函数实现:

获取参数myTree的根结点;

获取根结点下的各个子树;

对各个子树依次循环,如果还是子树,则递归调用函数getNumLeafs()获取子树的叶子结点数目;否则叶子结点数目加1。

最后返回所有的叶子结点数目;

具体代码如下:
#获取当前树中的叶子结点数目def getNumLeafs(myTree):    numLeafs = 0    firstStr = myTree.keys()[0] #获取当前树myTree中第一个key,即该树的根节点    secondDict = myTree[firstStr] #获取第一个key对应的内容,即根节点下的子树    for key in secondDict.keys(): #根节点对应的各个分支,依次循环        #type()就是一个最实用又简单的查看数据类型的方法。        #type()是一个内建的函数,调用它就能够得到一个反回值,从而知道想要查询的对像类型信息。        if type(secondDict[key]).__name__ == 'dict': #如果该子树还是一棵树,递归调用函数getNumLeafs(),获取子树的叶子结点数            numLeafs += getNumLeafs(secondDict[key])        else: #如果是叶子结点,则叶子数加1            numLeafs += 1    return numLeafs #返回当前树中叶子结点的个数

函数getTreeDepth

该函数定义为:def getTreeDepth(myTree)

用来获取当前树的最大深度;

函数实现:

获取参数myTree的根结点;

获取根结点下的各个子树;

对各个子树依次循环,如果还是子树,则递归调用函数getTreeDepth ()获取子树的深度并加1;否则为叶子结点,返回1。

判断当前子树是否最深子树;如果是,则更新最大深度信息;

最后返回最大深度信息;

具体代码如下:

#获取当前树的最大深度def getTreeDepth(myTree):    maxDepth = 0    firstStr = myTree.keys()[0] #获取树的根节点    secondDict = myTree[firstStr] #获取树的子树    for key in secondDict.keys(): #根节点对应的各个分支,依次循环        #如果该子树还是一棵树,递归调用函数getTreeDepth(),获取子树的深度        if type(secondDict[key]).__name__ == 'dict':            thisDepth = 1 + getTreeDepth(secondDict[key])        else: #如果是叶子结点,则返回1            thisDepth = 1        if thisDepth > maxDepth: #更新最大深度变量值            maxDepth = thisDepth    return maxDepth #返回最大深度

函数plotTree

该函数定义为:def plotTree(myTree, parentPt, nodeTxt)

用来绘制决策树myTree;

函数实现:

调用函数getNumLeafs()获取叶子结点数目;

调用函数getTreeDepth()获取决策树最大深度;

获取根结点;

调用函数plotMidText()绘制文本信息nodeTxt;

调用函数plotNode()绘制结点;

获取各个子树,依次循环:如果还是子树,则递归调用函数plotTree()继续绘制子树;否则,为叶子结点,调用函数plotNode()绘制叶子结点;

具体代码如下:

def plotTree(myTree, parentPt, nodeTxt):    numLeafs = getNumLeafs(myTree) #获取样本树的叶子结点数目    depth = getTreeDepth(myTree) #获取样本树的深度    firstStr = myTree.keys()[0] #获取样本树的根结点    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, \              plotTree.yOff)    plotMidText(cntrPt, parentPt, nodeTxt) #绘制文本信息nodeTxt    plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制根结点    secondDict = myTree[firstStr] #获取各个子树    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    for key in secondDict.keys(): #循环遍历各个子树        if type(secondDict[key]).__name__ == 'dict': #如果包含的是子树,递归调用plotTree绘制结点                plotTree(secondDict[key], cntrPt, str(key))        else:                plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW                plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) #绘制叶子结点                plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) #绘制文字注释    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD #调整Y轴的坐标值

函数plotMidText

该函数定义为:def plotMidText(cntrPt, parentPt, txtString)

用来显示文本信息txtString;

函数实现:

计算好x,y坐标后,调用函数createPlot.ax1.text()完成文本的显示。

具体代码如下:

#显示文本,在坐标点cntrPt和parentPt连接线上的中点,显示文本txtStringdef plotMidText(cntrPt, parentPt, txtString):    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] #计算x坐标    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] #计算y坐标    createPlot.ax1.text(xMid, yMid, txtString) #在(xMid, yMid)处显示txtString


函数plotNode

该函数定义为:def plotNode(nodeTxt, centerPt, parentPt, nodeType)

绘制一个结点,nodeTxt为结点显示文本,centerPt为文本起始位置,parentPt为箭头的起始位置,nodeType为结点框的样式;

函数实现:调用函数createPlot.ax1.annotate()直接绘制结点,需要设置好相关参数。

具体代码如下:

#绘制一个结点,nodeTxt为结点显示文本,centerPt为文本起始位置,parentPt为箭头的起始位置,nodeType为结点框的样式def plotNode(nodeTxt, centerPt, parentPt, nodeType):    #使用annotate()方法可以很方便地添加文字注释    # 第一个参数是注释的内容      # xy设置箭头尖的坐标      # xytext设置注释内容显示的起始位置      # arrowprops 用来设置箭头样式    # bbox用来设置节点框的样式    # xycoords and textcoords 是坐标xy与xytext的说明,若textcoords=None,则默认textNone与xycoords相同,若都未设置,默认为data    # va/ha设置节点框中文字的位置,va取值为(u'top', u'bottom', u'center', u'baseline'),ha取值为(u'center', u'right', u'left')    createPlot.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',                            xytext = centerPt, textcoords = 'axes fraction',                            va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)
(未完待续)



阅读全文
0 0