Decision Tree Source Code

Decision tree source code: trees.py
#!/usr/bin/python
# -*- coding:utf-8 -*-
from math import log
import operator

def calShannonEnt(dataSet):
    # Compute the Shannon entropy of the data set
    numEntries = len(dataSet)
    labelCounts = {}  # dictionary mapping each class label to its count
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createDataSet():
    # Create a toy data set
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def splitDataSet(dataSet, axis, value):
    # Return the samples whose feature `axis` equals `value`,
    # with that feature column removed
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    # Choose the feature with the largest information gain as the split feature
    numFeatures = len(dataSet[0]) - 1  # total feature count (last column is the class)
    baseEntropy = calShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # values of feature i across all samples
        uniqueVals = set(featList)  # all possible values of feature i
        newEntropy = 0.0
        for value in uniqueVals:
            # Weighted sum of the entropies of the subsets obtained
            # by splitting on feature i
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain
        if infoGain > bestInfoGain:  # keep the index of the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    # Decide the class by majority vote; input is a list of class labels
    classCount = {}  # class label -> count
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    # Build the decision tree recursively
    classList = [example[-1] for example in dataSet]  # class labels of all samples
    if classList.count(classList[0]) == len(classList):
        # Stop condition 1: all samples share one class, so return it
        return classList[0]
    if len(dataSet[0]) == 1:
        # Stop condition 2: every feature has been used (only the class column
        # remains), so decide by majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # best split feature index
    bestFeatLabel = labels[bestFeat]  # its human-readable label
    myTree = {bestFeatLabel: {}}  # the tree to return
    del(labels[bestFeat])  # remove the used feature from the label list
    featValues = [example[bestFeat] for example in dataSet]  # split-feature value per sample
    uniqueVals = set(featValues)  # all possible values of the split feature
    for value in uniqueVals:
        # Samples with different values of the split feature go to different branches
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # index of the current decision feature
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            # Found the matching branch: recurse if it is a subtree,
            # otherwise it is the final class label
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

def storeTree(inputTree, filename):
    # Persist a built decision tree with pickle
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)

def grabTree(filename):
    # Load a stored decision tree
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
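A minimal usage sketch (not part of the original post): build the tree from the toy data set, classify a sample, and round-trip the tree through pickle. It assumes trees.py sits in the working directory; labels is passed as a copy because createTree deletes entries from the list it receives, and the file name is a hypothetical placeholder.

import trees

dataSet, labels = trees.createDataSet()
print(trees.calShannonEnt(dataSet))            # entropy of the full set: ~0.971 bits
myTree = trees.createTree(dataSet, labels[:])  # pass a copy; createTree mutates its argument
print(myTree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(trees.classify(myTree, labels, [1, 0]))  # -> 'no'
trees.storeTree(myTree, 'classifierStorage.txt')   # hypothetical file name
print(trees.grabTree('classifierStorage.txt'))     # same dict as myTree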


File for drawing the decision tree: treePlotter.py

#!/usr/bin/python
# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt

# Node styles: box shape and face color (grayscale 0~1; larger is lighter)
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow = dict(arrowstyle="<-")  # edge (arrow) style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # Draw a node with the given text, position, and style,
    # plus the connecting line to its parent node
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow)

def createPlot(inTree):
    # Draw a whole tree
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[0.5, 1], yticks=[0.5])  # which axis tick values to show
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))   # total width of the tree
    plotTree.totalD = float(getTreeDepth(inTree))  # total depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

def getNumLeafs(myTree):
    # Count the leaf nodes
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    # Compute the tree depth
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):
    # Hard-coded test trees
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

def plotMidText(cntrPt, parentPt, txtString):
    # Write the branch label midway between two nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # Compute the child node's position: xOff tracks the last leaf drawn, and the
    # node is centered horizontally over the numLeafs leaf slots it spans
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the current decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            # The child is another decision node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Otherwise draw the leaf node directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # restore yOff after this level
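A minimal usage sketch (assumptions: treePlotter.py is importable and matplotlib is installed): render one of the hard-coded test trees and check the layout helpers it relies on.

import treePlotter

myTree = treePlotter.retrieveTree(0)
print(treePlotter.getNumLeafs(myTree))   # 3 leaves
print(treePlotter.getTreeDepth(myTree))  # depth 2
treePlotter.createPlot(myTree)           # opens a matplotlib window with the tree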

