决策树

来源:互联网 发布:网络直销流程 编辑:程序博客网 时间:2024/06/04 00:49

一, 信息增益

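香农熵(信息熵)的定义为:

$$H = -\sum_{i=1}^{n} p(x_i)\,\log_2 p(x_i)$$

其中 $p(x_i)$ 是第 $i$ 个类别在数据集中出现的概率。某个特征的信息增益 = 划分前数据集的熵 − 按该特征划分后各子集的熵按样本数比例加权之和;信息增益最大的特征就是最好的划分特征。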

from math import log
import operator


def createDataSet():
    """Return a small toy data set and the names of its feature columns.

    Each row holds two binary feature values followed by a 'yes'/'no'
    class label. A fresh list is built on every call, so callers may
    mutate the result freely.
    """
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfacing', 'flippers']
    return rows, featureNames


def calcShannonEnt(dataSet):
    """Return the base-2 Shannon entropy of the class labels in dataSet.

    The class label is taken from the last element of each row.
    An empty data set (or one with a single class) yields 0.0.
    """
    total = len(dataSet)
    tally = {}
    for row in dataSet:
        label = row[-1]
        tally[label] = tally.get(label, 0) + 1

    entropy = 0.0
    for count in tally.values():
        p = float(count) / total
        entropy -= p * log(p, 2)
    return entropy
>>> import trees>>> myDat,labels=trees.createDataSet()>>> myDat[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]>>> trees.calcShannonEnt(myDat)0.9709505944546686>>> myDat[0][-1]='maybe'>>> trees.calcShannonEnt(myDat)1.3709505944546687>>> myDat,labels=trees.createDataSet()>>> trees.calcShannonEnt(myDat)0.9709505944546686>>> myDat[0][0] = 0>>> trees.calcShannonEnt(myDat)0.9709505944546686

二,划分数据集

def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, with that column removed.

    Each returned row is a new list; the rows of dataSet are not modified.
    """
    def _dropColumn(row):
        # Slicing copies, so extending the slice never touches the input row.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        return trimmed

    return [_dropColumn(row) for row in dataSet if row[axis] == value]
>>> reload(trees)<module 'trees' from 'trees.py'>>>> myDat,labels=trees.createDataSet()>>> myDat[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]>>> trees.splitDataSet(myDat, 0, 1)[[1, 'yes'], [1, 'yes'], [0, 'no']]>>> trees.splitDataSet(myDat, 0, 0)[[1, 'no'], [1, 'no']]
选择最好的数据集划分方式

def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    The information gain of feature i is the entropy of the whole data set
    minus the size-weighted entropy of the subsets produced by splitting on
    feature i. The class label is the last column of each row. Ties go to the
    lowest feature index. Returns -1 only when the rows contain no feature
    columns (just the class label).
    """
    numFeatures = len(dataSet[0]) - 1
    numRows = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)
    # Start below any achievable gain so a real feature is always chosen, even
    # when every gain is zero (e.g. XOR-style data, where no single feature
    # helps on its own). Starting at 0.0 returned -1 in that case, and callers
    # then used -1 as an index into labels and rows.
    bestInfoGain = float('-inf')
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set([example[i] for example in dataSet])
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

>>> reload(trees)<module 'trees' from 'trees.py'>>>> myDat,labels=trees.createDataSet()>>> trees.chooseBestFeatureToSplit(myDat)0>>> myDat[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

三,递归构建决策树

递归构建决策树的结束条件是:程序遍历完所有用于划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点(终止块);如果所有属性都已用完而实例的类别仍不唯一,则采用多数表决(下面的 majorityCnt)来决定该叶子节点的分类。

def majorityCnt(classList):
    """Return the most frequent class label in classList.

    On a tie, the label seen first wins on Python 3.7+, where dicts keep
    insertion order and `sorted` is stable even with reverse=True. On older
    interpreters the tie-break is arbitrary.

    Raises IndexError if classList is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # dict.items() exists on both Python 2 and 3; iteritems() is Python 2 only.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

构建决策树

def createTree(dataSet, labels):
    """Build an ID3 decision tree from dataSet and return it.

    dataSet: list of rows; each row holds feature values followed by the
        class label in the last position.
    labels: feature names, one per feature column of dataSet. This list is
        NOT modified (the original implementation deleted entries from it,
        which broke later calls such as classify(tree, labels, ...)).

    Returns either a bare class label (when no split is needed) or a nested
    dict of the form {featureName: {featureValue: subtree_or_label}}.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # no features left: decide by majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat < 0:
        # No feature showed positive gain; split on the first remaining
        # feature rather than misusing -1 as an index (labels[-1],
        # example[-1] would select the wrong column / the class label).
        bestFeat = 0
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Private list without the consumed feature; the caller's list is left
    # intact. Recursive calls also copy, so branches can share this list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set([example[bestFeat] for example in dataSet])
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
测试

>>> reload(trees)<module 'trees' from 'trees.py'>>>> myDat,labels=trees.createDataSet()>>> myTree = trees.createTree(myDat,labels)>>> myTree{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
保存决策树

def storeTree(inputTree, filename):
    """Serialize inputTree to the file `filename` using pickle."""
    import pickle
    # Pickle output is bytes, so the file must be opened in binary mode:
    # text mode ('w') raises TypeError on Python 3 and can corrupt the data
    # on Windows. `with` closes the file even if dump() raises.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously written by storeTree.

    Security: pickle.load can execute arbitrary code embedded in the file;
    only load files from a trusted source.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)

四,使用Matplotlib注解绘制树形图

Matplotlib 的注解工具 annotations 可以在数据图上添加文本注解。

treePlotter.py

import matplotlib.pyplot as plt

# Box style for internal (decision) nodes, box style for leaves, and the
# arrow style used for the edge from a parent to its child.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
arrow_args = {"arrowstyle": "<-"}


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed text node at centerPt with an arrow to it from parentPt.

    Both points are in axes-fraction coordinates; nodeType is the bbox style
    dict. Draws on the axes stored in `createPlot.ax1`, so that attribute must
    have been set (by createPlot) before this is called.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )


def createPlot():
    """Demo: draw one decision node and one leaf node in a fresh figure."""
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks left on for demo purposes
    demoNodes = (
        (U'decisionNode', (0.5, 0.1), (0.1, 0.5), decisionNode),
        (U'leafNode', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, style in demoNodes:
        plotNode(text, center, parent, style)
    plt.show()

>>> import treePlotter>>> treePlotter.createPlot()
构造注解树

def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    The tree has the form {featureName: {featureValue: subtree_or_leaf}};
    any value that is not a dict is treated as a leaf.
    """
    numLeafs = 0
    # next(iter(d)) works on Python 2 and 3; d.keys()[0] fails on Python 3
    # because keys() returns a non-indexable view there.
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):  # nested dict = subtree; anything else = leaf
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest root-to-leaf path.

    A tree whose root only has leaf children has depth 1.
    """
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):
    """Return the i-th of two hard-coded sample trees, for testing the plotter."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]
>>> reload(treePlotter)<module 'treePlotter' from 'treePlotter.py'>>>> treePlotter.retrieveTree(1){'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}>>> myTree = treePlotter.retrieveTree(0)>>> treePlotter.getNumLeafs(myTree)3>>> treePlotter.getTreeDepth(myTree)2
绘制树

def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the edge from parentPt to cntrPt."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt, labelling the incoming edge with nodeTxt.

    Layout state is kept in function attributes set by createPlot:
    plotTree.totalW / plotTree.totalD (leaf count and depth of the whole tree)
    and plotTree.xOff / plotTree.yOff (the current cursor, in axes-fraction
    coordinates).
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaf slots
    # next(iter(d)) works on Python 2 and 3; d.keys()[0] fails on Python 3.
    firstStr = next(iter(myTree))  # feature this node splits on
    # Centre the decision node over the leaf slots its subtree will occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        child = secondDict[key]
        if isinstance(child, dict):  # nested dict = subtree: recurse
            plotTree(child, cntrPt, str(key))
        else:  # leaf: advance to the next leaf slot and draw it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up to this node's level


def createPlot(inTree):
    """Draw the decision tree inTree in a fresh matplotlib figure and show it."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot left of 0 so the first leaf lands at 0.5 / totalW.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
>>> reload(treePlotter)<module 'treePlotter' from 'treePlotter.py'>>>> myTree = treePlotter.retrieveTree(0)>>> treePlotter.createPlot(myTree)

五,使用决策树执行分类

def classify(inputTree, featLabels, testVec):
    """Return the class label that the decision tree assigns to testVec.

    featLabels gives the feature name for each position of testVec; it is
    used to find which element of testVec each tree node tests.

    Raises KeyError if testVec has a value for which the tree has no branch
    at that node, and ValueError if the tree tests a feature name that is
    missing from featLabels.
    """
    # next(iter(d)) works on Python 2 and 3; d.keys()[0] fails on Python 3.
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):  # internal node: keep descending
        return classify(valueOfFeat, featLabels, testVec)
    return valueOfFeat
>>> reload(trees)<module 'trees' from 'trees.pyc'>>>> myDat,labels=trees.createDataSet()>>> myTree = treePlotter.retrieveTree(0)>>> myTree{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}>>> trees.classify(myTree,labels,[1,0])'no'>>> trees.classify(myTree,labels,[1,1])'yes'

六,示例:使用决策树预测隐形眼镜类型

>>> fr=open('lenses.txt')>>> lenses=[inst.strip().split('\t') for inst in fr.readlines()]>>> lensesLabels=['age','prescript','astigmatic','tearRate']>>> lensesTree=trees.createTree(lenses,lensesLabels)>>> lensesTree{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}>>> treePlotter.createPlot(lensesTree)







原创粉丝点击