Decision Trees -- The ID3 Algorithm


1. Basic theory: entropy and information gain
http://www.cnblogs.com/wentingtu/archive/2012/03/24/2416235.html
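
For reference, the standard definitions used throughout this post: for a dataset $D$ whose classes occur with proportions $p_k$, the Shannon entropy is

$$H(D) = -\sum_{k} p_k \log_2 p_k$$

and the information gain from splitting on an attribute $A$ with values $v$ is

$$\mathrm{Gain}(D, A) = H(D) - \sum_{v} \frac{|D_v|}{|D|} H(D_v)$$

where $D_v$ is the subset of $D$ in which $A$ takes the value $v$. At each node, ID3 greedily selects the attribute with the largest information gain.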

2. ID3 algorithm steps:
Input: a dataset dataset (the attribute values of all samples) and a label set labels (the set of decision outcomes)
Output: a decision tree
(1) If all samples in dataset belong to the same class (for example, every remaining sample says we go out to play only when the weather is sunny, so they all share one outcome),
return a leaf node labeled with that class.
(2) Else if the attribute set is empty,
return a leaf node labeled with the most common value among the labels.
(3) Else choose the attribute with the highest information gain as the root of the (sub)tree. For each value of that attribute, check whether any samples have it: if none do, create a leaf node labeled with the most common class; otherwise recurse on steps (1)-(3) with that subset.
http://blog.csdn.net/liema2000/article/details/6118384
A worked example: http://zc0604.iteye.com/blog/1462825

3. Python implementation:
3.1 Computing the Shannon entropy of a dataset:

from math import log

# Compute the Shannon entropy of a dataset (class label in the last column)
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*log(prob,2)
    return shannonEnt

3.2 Preparing the data:
[Figure: the sample dataset -- can the animal survive without surfacing? does it have flippers? is it a fish?]

def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels
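
A quick sanity check, run after defining both functions: with two 'yes' and three 'no' labels, the entropy should be $-\frac{2}{5}\log_2\frac{2}{5} - \frac{3}{5}\log_2\frac{3}{5} \approx 0.971$.

myDat,labels = createDataSet()
print calcShannonEnt(myDat)   # 0.9709505944546686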

3.3 Splitting the dataset

# Split the dataset on a given feature: return the subset of samples whose
# feature at index `axis` equals `value`, with that feature column removed
def splitDataSet(dataSet,axis,value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
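
For example, splitting the sample data on feature 0 keeps only the matching rows and drops the tested column:

myDat,labels = createDataSet()
print splitDataSet(myDat,0,1)   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print splitDataSet(myDat,0,0)   # [[1, 'no'], [1, 'no']]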

3.4 Choosing the best way to split the dataset, i.e., the feature with the largest information gain:

# Choose the best feature to split on (the one with the largest information gain)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
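
On the sample data this returns feature 0 ('no surfacing'): its information gain is 0.971 - 0.551 = 0.420, versus 0.971 - 0.800 = 0.171 for 'flippers'.

myDat,labels = createDataSet()
print chooseBestFeatureToSplit(myDat)   # 0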

The functions above are the building blocks for constructing the decision tree. chooseBestFeatureToSplit finds the best feature to split the dataset on; splitting on that feature yields several branches, and we continue splitting the data within each branch. This is where recursion comes in.

In any recursive algorithm, the most important part is the termination condition. The recursion in decision tree construction terminates when:
(1) all features used for splitting have been exhausted, or (2) all instances in a branch belong to the same class.

If all instances share the same class, we get a leaf node (terminal block). Any data that reaches a leaf node necessarily belongs to that leaf's class.

If all features have been processed but the class labels are still not unique, we usually take a majority vote to decide the class of that leaf node.

3.5 Majority voting

import operator

# Return the class label that occurs most often in classList
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
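
A quick check:

print majorityCnt(['yes','no','no','yes','no'])   # 'no'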

3.6 Building the decision tree

# Recursively build the decision tree
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):   # all samples belong to the same class
        return classList[0]
    if (len(dataSet[0]) == 1):                            # all features exhausted: return the majority label
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

Add the following code and run it:

myDat,labels = createDataSet()
myTree = createTree(myDat,labels)
print myTree

The output is: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}. The nested dictionaries encode the tree: each key is a feature label, its sub-keys are that feature's values, and each value maps to either a subtree or a class label.

4. Plotting the decision tree
4.1 An example:

import matplotlib.pyplot as plt

# Define box and arrow styles
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotated node with an arrow pointing from parentPt to centerPt
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon=False)
    # plotNode(text to display, text position, arrow start (parent) position, box style)
    plotNode(U'decisionNode',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(U'leafNode',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

createPlot()

[Figure: the two test nodes drawn by createPlot()]

4.2 To place the nodes, we need the depth of the tree (which determines the height of the plot, y) and the number of leaf nodes (which determines its width, x):

# To lay out the annotated tree we need the number of leaf nodes
# (x-axis extent) and the number of levels (y-axis extent)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Return a canned tree for testing
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
                   ]
    return listOfTrees[i]

numLeafs = getNumLeafs(retrieveTree(0))
depth = getTreeDepth(retrieveTree(0))
print numLeafs
print depth
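
For retrieveTree(0) this prints 3 and 2: the tree has three leaf nodes ('no', 'no', 'yes') and a depth of 2 (two stacked decision nodes).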

4.3 Now we rewrite the createPlot() function defined above and add the plotTree() routine that does the actual layout:

# Plot text between a child node and its parent (the edge label, i.e., the feature value)
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # a dict is a subtree; anything else is a leaf
            plotTree(secondDict[key],cntrPt,str(key))  # recurse into the subtree
        else:                                          # leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)   # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5,1.0), '')
    plt.show()

myTree = retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][2] = 'maybe'   # add a third branch and replot
createPlot(myTree)

[Figure: the plotted decision tree for retrieveTree(0)]

[Figure: the same tree after adding the 'maybe' branch]
5. Testing: classifying with the decision tree
5.1 The classification function:

# Classify a sample by walking down the decision tree
def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)   # index of this feature in the label list
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else: classLabel = secondDict[key]
    return classLabel

myDat,labels = createDataSet()
myTree = retrieveTree(0)
result = classify(myTree,labels,[1,1])
print result
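
Tracing the sample tree: [1,1] follows the 'no surfacing'=1 branch to 'flippers', then the 1 branch, so the script prints 'yes'; [1,0] would print 'no'.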

5.2 Persisting the decision tree:
A constructed decision tree can be serialized to disk and loaded later for classification without rebuilding it.

# Persist the decision tree with pickle
import pickle

def storeTree(inputTree,filename):
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    fr = open(filename,'r')
    return pickle.load(fr)

storeTree(myTree,'classifierStorage.txt')
storageTree = grabTree('classifierStorage.txt')
print "storageTree: %r" % storageTree
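
Note: pickle's default protocol 0 is text-based, so opening the files in text mode works on Unix-like systems; with a binary pickle protocol, or on Windows, open the files in binary mode ('wb'/'rb') instead.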

6. A practical application: predicting the type of contact lens a patient should wear

# Read in the data
fr = open('lenses.txt')
# Preprocess: one tab-separated sample per line, class label last
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age','prescript','astigmatic','tearRate']
lensesTree = createTree(lenses,lensesLabels)
print "the constructed decision tree: %r" % lensesTree
createPlot(lensesTree)
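
One caveat: createTree deletes entries from the label list passed to it (del(labels[bestFeat])), so lensesLabels is mutated after training and must be rebuilt before classifying. Assuming lenses.txt follows the UCI lenses encoding (age: young/pre/presbyopic, prescript: myope/hyper, astigmatic: yes/no, tearRate: reduced/normal), a minimal classification sketch:

# Rebuild the label list, since createTree mutated lensesLabels above
lensesLabels = ['age','prescript','astigmatic','tearRate']
# A hypothetical patient; the attribute values assume the UCI lenses encoding
patient = ['young','myope','no','normal']
print classify(lensesTree,lensesLabels,patient)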

[Figure: the decision tree learned from the lenses dataset]
