决策树算法的可视化表达

来源:互联网 发布:java vr开发入门 编辑:程序博客网 时间:2024/05/20 04:50

这一篇接着上一篇博客,由于字典这种数据结构的不清晰性,失去了决策树算法本身的优点,所以我们需要将结果通过树形图来表示出来,采用的是Python中matplotlib库。
首先我们简单测试一下使用matplotlib库来画标注的效果。

import matplotlib.pyplot as pltdecisionNode = dict(boxstyle = 'sawtooth',fc ="0.8" )    #定义树节点的一些特征leafNode = dict(boxstyle='round4',fc="0.8")              #定义叶节点的一些特征arrow_args = dict(arrowstyle='<-')                       #定义箭头的特征def plotNode(nodeTxt,centerPt,parentPt,nodeType):    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',                            xytext=centerPt,textcoords='axes fraction',va="center",                            ha="center",bbox=nodeType,arrowprops=arrow_args)                               #annotate参数说明:nodeTxt是标注内容,xy是标注终点的位置坐标,xytext是标注起点的位置坐标,arrowprops标注箭头属性信息def createPlot():    fig = plt.figure(1,facecolor='white')   生成一个图形,1是名字,facecolor是底色    fig.clf()               #清除图像内容    createPlot.ax1=plt.subplot(111,frameon=False)  #111代表生成几行几列第几个图的意思,例如223,就是生成一个两行两列的子图,你画的是其中的第三个,frameon表示子图是否显示坐标轴线,默认True显示,False不显示。    plotNode('a decision node',(0.5,0.1),(0.1,0.5),decisionNode)#画树节点    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)        #画叶节点    plt.show()

运行结果如图所示。
这里写图片描述

下面就开始正式进入决策树可视化算法的部分,首先对于一棵树,我们需要知道他的深度和宽度,深度可以由树的层数来决定,因为决策树是一个完全n叉树,所以可以由总叶节点的个数来确定树的宽度。下面两个函数就分别计算了决策树的深度与宽度。

def getNumLeafs(myTree):      #计算叶节点数目(树的宽度),采用了递归调用的方法    numLeafs = 0    firstStr = myTree.keys()[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            numLeafs+=getNumLeafs(secondDict[key])        else: numLeafs+=1    return numLeafs
def getTreeDepth(myTree):     #计算树的深度,同上采用了递归调用的方法    maxDepth = 0    firstStr = myTree.keys()[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes            thisDepth = 1 + getTreeDepth(secondDict[key])            print(thisDepth)        else:   thisDepth = 1        if thisDepth > maxDepth: maxDepth = thisDepth    return maxDepth

该函数的作用是在父节点与子节点之间绘制信息。

def plotMidText(cntrPt, parentPt, txtString):#在父子节点之间绘制文本信息    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

该函数就是本算法中的主要部分,绘制树。主要思想在下文中有提及。

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on    numLeafs = getNumLeafs(myTree)                    #叶节点个数,决定了x轴上的宽度    depth = getTreeDepth(myTree)                      #树的深度,决定了y轴上的宽度    firstStr = myTree.keys()[0]                       #根节点的标签    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,        plotTree.yOff)                                #见注释    plotMidText(cntrPt, parentPt, nodeTxt)            #绘制父子节点之间的文本信息    plotNode(firstStr, cntrPt, parentPt, decisionNode)#绘制树节点,firstStr是该点的标签值,cntrPt是子节点的位置坐标,parentPt是父节点位置坐标    secondDict = myTree[firstStr]    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD#将跟踪点的y坐标向下移动一格    for key in secondDict.keys():                      #遍历secondDict的取值        if type(secondDict[key]).__name__=='dict':     #检查此处是不是dict,如果是则此处是树节点,若不是则此处是叶节点            plotTree(secondDict[key],cntrPt,str(key))  #递归调用plotTree        else:                                          #此处是叶节点,则画出叶节点            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW  #将x坐标向右移动一格            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)                                      #画出此处的节点            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD#画出父子节点间的文本信息def createPlot(inTree):    fig = plt.figure(1, facecolor='white')    fig.clf()    axprops = dict(xticks=[], yticks=[])    #控制坐标的显示    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #不显示坐标    #createPlot.ax1 = plt.subplot(111, frameon=False)           #显示坐标    plotTree.totalW = float(getNumLeafs(inTree))       #计算叶节点数目并赋给totalW    plotTree.totalD = float(getTreeDepth(inTree))      #计算树的深度并赋给totalD    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;#给定xOff的初始值    plotTree(inTree, (0.5,1.0), '')    plt.show()

注释:在这里解释一下这个式子的由来,在下面列了两种情形,通过图解的方式来解释一下这个式子。
这里写图片描述
先来看这样一种情形,当我们画好叶节点1准备递归调用plotTree生成根节点2时,我们首先需要确定根节点2的坐标,从图上可以看出根节点2的横坐标与叶节点1的横坐标之间相差2.5个1/totalW,即目前(叶节点个数/2)/totalW + 0.5*(1/totalW),即根节点2的坐标为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW。
如果当前叶节点为奇数个,则为下面一种情形。
这里写图片描述
从图上可以看出根节点2的横坐标与叶节点1的横坐标之间相差2个1/totalW,即目前(叶节点个数/2)/totalW + 0.5*(1/totalW),即根节点2的坐标为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW。

决策树可视化算法的本质就是二叉树的实现,而且决策树本身是一个完全n叉树,首先计算出决策树的叶节点的数目(即决策树的宽度)和树的层数(即深度)。在画决策树的过程中采用两个变量xOff和yOff来跟踪当前位置,遇到树节点就递归调用plotTree函数,遇到叶节点就将xOff向右移一格并画出节点。

{‘no surfacing’: {0: ‘no’, 1: {‘flippers’: {0:{‘old’:{0: ‘no’,1: ‘yes’}} , 1: {‘new’:{0: ‘no’,1: ‘yes’}}}}}}
对于这样的一棵决策树,运行上面代码可以得到下图。
这里写图片描述

原创粉丝点击