决策树03——使用matplotlib绘制树形图并测试算法

来源:互联网 发布:js select 不选中 编辑:程序博客网 时间:2024/06/09 06:55

决策树02——决策树的构建中,我们将已经进行分类的数据存储在字典中,然而字典的表示形式非常不直观,也不容易理解,所以我们将字典中的信息绘制成树形图。

Matplotlib注解功能

  Matplotlib提供一个注解工具annotations,它可以在数据图形上添加文本注释。

  以下将使用Matplotlib的注解功能绘制树形图,它可以对文字着色,并提供多种形状以供选择,而且我们还可以反转箭头,将它指向文本框而不是数据点。

  新建名为treeplotter.py的新文件,将输入下面的程序代码:

# -*-coding=utf-8 -*-#使用文本朱姐绘制树节点import matplotlib.pyplot as plt#定义文本框和箭头格式#定义决策树决策结果的属性(决策节点or叶节点),用字典来定义#下面的字典定义也可以写作 decisionNode = {boxstyle:’sawtooth‘,fc=’0.8‘}decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")       #决策节点,boxstyle为文本框类型,sawtooth是锯齿形,fc是边框内填充的颜色leafNode = dict(boxstyle = "round",fc="0.8")                #叶节点,定义决策树的叶子结点的描述属性arrow_args = dict(arrowstyle = "<-")                         #箭头格式#绘制带箭头的注释def plotNode(nodeTxt,centerPt,parentPt,nodeType):           #nodeTxt是显示的文本,centerPt是文本的中心点,parentPt是箭头的起点坐标,nodeType是一个字典 注解的形状    createPlot.ax1.annotate(nodeTxt,xy = parentPt, xycoords = 'axes fraction',  #xy为箭头的起始坐标,0,0 is lower left of axes and 1,1 is upper right                            xytext = centerPt,textcoords = 'axes fraction', #xytext为注解内容的坐标                            va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args) #bbox注解文本框的形状,arrowprops是指箭头的形状def createPlot():    fig = plt.figure(1,facecolor='white')  #类似于matlab的figure,定义一个画布,其背景为白色    fig.clf()                 #把画布清空    createPlot.ax1 = plt.subplot(111,frameon=False) # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图,    plotNode(U'决策节点',(0.5,0.1),(0.1,0.5), decisionNode)    plotNode(U'叶节点',(0.8,0.1),(0.3,0.8), leafNode)    plt.show()

注意:以上程序运行时会出现中文变成小方框的现象,将以下几行代码添加到文件的开始处。

from pylab import *mpl.rcParams['font.sans-serif'] = ['SimHei']  #指定默认字体mpl.rcParams['axes.unicode_minus'] = False

在命令行输入:

In[70]: import treePlotterBackend TkAgg is interactive backend. Turning interactive mode on.In[71]: treePlotter.createPlot()

这里写图片描述

构造注解树

  我们虽然有x, y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定x轴的长度,我们还需要知道树有多少层,以便可以正确的确定y轴的高度。
  这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的输煤和树的层数。将下面的两个函数添加到treePlotter.py文件中。

#获取叶节点的数目和树的层次def getNumLeafs(myTree):    numLeaf = 0    firstStr = myTree.keys()[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__ =='dict':         #测试节点的数据类型是否为字典 ,type(secondDict[key]) ==dict 也是可以的            numLeaf += getNumLeafs(secondDict[key])        else: numLeaf += 1    return numLeafdef getTreeDepth(myTree):    maxDepth = 0    firstStr = myTree.keys()[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            thisDepth = 1 + getTreeDepth(secondDict[key])        else: thisDepth = 1        if thisDepth > maxDepth :maxDepth = thisDepth    return maxDepth

函数retrieveTree()输出预先存储的树信息,将 下面代码添加到文件treePlotter.py中:

def retrieveTree(i):    listOfTrees = [{'no surfacing':{0:'0',1:{'flippers':{0:'no',1:'yes'}}}},                   {'no surfacing': {0: '0', 1: {'flippers': {0: {'head':{0:'no',1:'yes'}}, 1: 'no'}}}}                   ]    return listOfTrees[i]

在命令行中输入:

In[2]: import treePlotterBackend TkAgg is interactive backend. Turning interactive mode on.In[3]: treePlotter.retrieveTree(0)Out[3]: {'no surfacing': {0: '0', 1: {'flippers': {0: 'no', 1: 'yes'}}}}In[4]: myTree = treePlotter.retrieveTree(0)In[5]: treePlotter.getNumLeafs(myTree)Out[5]: 3In[6]: treePlotter.getTreeDepth(myTree)Out[6]: 2

将下面代码添加到treePlotter.py中,注意前面已经定义了createPlot(),此时我们需要更新前面的代码。

#plotTree函数#在父子节点间填充文本信息def plotMidText(cntrPt,parentPt,txtString):    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]    createPlot.ax1.text(xMid,yMid,txtString)#自顶向下作图,绘制图形的x轴有效范围是0.0~1.0, y轴有效范围也是0.0~1.0def plotTree(myTree,parentPt,nodeTxt):    numLeafs = getNumLeafs(myTree)    #secondDict[key]的叶节点的数量    depth = getTreeDepth(myTree)      #secondDict[key]的树深度    print 'numLeafs,depth:',numLeafs,',',depth    firstStr = myTree.keys()[0]    # 全局变量plotTree.totalW 存储树的宽度,全局变量PlotTree.totalD 存储树的深度,使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。    cntrPt = (plotTree.xOff +(1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) #注释1    #标记子节点属性    plotMidText(cntrPt,parentPt,nodeTxt)        #这一次循环中的cntrPt(即上式)为cbtrPt,parentPt为上一轮计算出的cntrPt    plotNode(firstStr,cntrPt,parentPt,decisionNode)  #因还没画到叶节点,所以这里画的是决策节点,即此时筛选secondDict[key]还是字典    secondDict = myTree[firstStr]    #计算下一轮要用的y    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD    #下面的循环中要使用的y    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            plotTree(secondDict[key],cntrPt,str(key))        else:            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #注释2def createPlot(inTree):    fig = plt.figure(1, facecolor='white')    fig.clf()    axprops = dict(xticks=[],yticks=[])    #创建一个型为{'xticks': [], 'yticks': []}的字典    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)    plotTree.totalW = float(getNumLeafs(inTree))      #树的宽度    plotTree.totalD = float(getTreeDepth(inTree))     #输的深度    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #定义节点位置的初始值    plotTree(inTree,(0.5,1.0),'')     #(0.5,1.0)为初始化parentPt的值,注释3    plt.show()

在命令行输入:

In[35]: reload(treePlotter)Out[35]: <module 'treePlotter' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/treePlotter.py'>In[36]: myTree = treePlotter.retrieveTree(0)In[37]: myTreeOut[37]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}In[38]: treePlotter.createPlot(myTree)numLeafs,depth: 3 , 2numLeafs,depth: 2 , 1

这里写图片描述

注释:
1.cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
  在这行代码中,首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:
这里写图片描述
  其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,

  plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW.

2. plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
这行代码中是需要的,当分支最后一个不是字典的时候,字典循环完需要返回上一层继续进行函数
例如:

In[40]: myTree['no surfacing'][3] = 'maybe'In[41]: myTreeOut[41]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}In[42]: treePlotter.createPlot(myTree)numLeafs,depth: 4 , 2numLeafs,depth: 2 , 1

这里写图片描述
 
3.plotTree(inTree,(0.5,1.0),'')
在这行代码中,对于plotTree函数参数赋值为(0.5, 1.0),因为开始的根节点并不用划线,因此父节点和当前节点的位置需要重合,利用2中的确定当前节点的位置便为(0.5, 1.0)

总结:利用这样的逐渐增加x的坐标,以及逐渐降低y的坐标能能够很好的将树的叶子节点数和深度考虑进去,因此图的逻辑比例就很好的确定了,这样不用去关心输出图形的大小,一旦图形发生变化,函数会重新绘制,但是假如利用像素为单位来绘制图形,这样缩放图形就比较有难度了

测试和存储分类器

程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型。

#使用决策树的分类算法def classify(inputTree,featLabels,testVec):    #testVec即为需要分类的数据    firstStr = inputTree.keys()[0]    secondDict = inputTree[firstStr]    featIndex = featLabels.index(firstStr)        #将标签字符串转换为索引    print featIndex    for key in secondDict.keys():        if testVec[featIndex] == key:            if type(secondDict[key]).__name__ == 'dict':                classLabel = classify(secondDict[key],featLabels,testVec)            else:                classLabel = secondDict[key]    return classLabel

在命令行输入:

In[19]: reload(trees)Out[19]: <module 'trees' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/trees.py'>In[20]: myDat,labels = trees.createDataSet()In[21]: myDatOut[21]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]In[22]: labelsOut[22]: ['no surfacing', 'flippers']In[23]: myTree = trees.createTree(myDat,labels)最好的特征,最好的信息增益: 0 , 0.419973094022最好的特征,最好的信息增益: 0 , 0.918295834054In[24]: myDatOut[24]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]In[25]: labelsOut[25]: ['flippers']In[26]: myDat,labels = trees.createDataSet()In[27]: trees.classify(myTree,labels,[1,1])01Out[27]: 'yes'In[28]: trees.classify(myTree,labels,[1,0])01Out[28]: 'no'

决策树的存储

为了节省时间,最好能够在每次执行分类时调用已经构造好的决策树,使用Python的pickle模块可以在磁盘上保存对象,并在需要的时候读取出来。

#使用pickle模块存储决策树def storeTree(inputTree,filename):    import pickle    fw = open(filename,'w')    pickle.dump(inputTree,fw)    fw.close()def grabTree(filename):    import pickle    fr = open(filename)    return pickle.load(fr)

在命令行中输入:

In[29]: trees.storeTree(myTree,'classifierStorage.txt')In[30]: trees.grabTree('classifierStorage.txt')Out[30]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
原创粉丝点击