决策树绘图(二)

来源:互联网 发布:淘宝如何返利 编辑:程序博客网 时间:2024/05/19 03:27

由于最近在看机器学习实战,所以自己利用python3去完成里面的代码,此代码衔接着http://blog.csdn.net/xueyunf/article/details/9223865。

在这个基础上进行修改完成了这篇文章的代码,我们知道了决策树的简单构建,ID3算法完成,当然这都很基础,画图呢,只是为了让其更加形象化;我们添加几个函数,一个是输出一棵我们可以利用ID3算法生成的树,一个获取树的叶子节点,一个获取树的深度,这些我想这里就不用讲解了,学过数据结构的童鞋,可以在非常短的时间内实现这些算法;当然我先把这3个函数的代码贴出来:

def  getNumLeafs(myTree):    numLeafs = 0    firstStr = list(myTree.keys())[0]    secondDict  =  myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            numLeafs += getNumLeafs(secondDict[key])        else:            numLeafs += 1    return numLeafs    def getTreeDepth(myTree):    maxDepth = 0    firstStr = list(myTree.keys())[0]    secondDict = myTree[firstStr]    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            thisDepth = 1 + getTreeDepth(secondDict[key])        else:            thisDepth = 1        if thisDepth>maxDepth:            maxDepth = thisDepth    return maxDepthdef retrieveTree(i):    listOfTrees = [{'no surfacing':{0:'no', 1:{'flippers':\                    {0:'no', 1:'yes'}}}},                   {'no surfacing':{0:'no', 1:{'flippers':\                      {0:{'head':{0:'no', 1:'yes'}}, 1:'no'}}}}                ]                     return listOfTrees[i]

我们不难看出根据定义第一个函数完成获取所有叶子节点的个数,第二个函数完成获取树的高度,第三个函数完成输出树。

然后我们放出这次的主要函数,修改后的绘图函数:

def plotMidText(cntrPt, parentPt, txtString):    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]    createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):    numLeafs = getNumLeafs(myTree)    getTreeDepth(myTree)    firstStr = list(myTree.keys())[0]    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW,\    plotTree.yOff)    plotMidText(cntrPt, parentPt, nodeTxt)    plotNode(firstStr, cntrPt, parentPt, decisionNode)    secondDict = myTree[firstStr]    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            plotTree(secondDict[key],cntrPt,str(key))        else:            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),            cntrPt, leafNode)            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef createPlot(inTree):    fig = plt.figure(1, facecolor='white')    fig.clf()    axprops = dict(xticks=[], yticks=[])    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    plotTree.totalW = float(getNumLeafs(inTree))    plotTree.totalD = float(getTreeDepth(inTree))    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;    plotTree(inTree, (0.5,1.0), '')    plt.show()

最后当然也是截个图给大家看看程序的运行情况:


好了,这里面的函数我想大家可以通过名字也知道每个函数干了些什么。

原创粉丝点击