《机器学习实战》学习笔记<二>决策树
来源:互联网 发布:sql insert date 编辑:程序博客网 时间:2024/05/18 01:05
trees.py
# -*- coding: utf-8 -*-
"""ID3 decision-tree utilities ("Machine Learning in Action", chapter 3).

Builds a decision tree as nested dicts ({feature: {value: subtree-or-label}}),
classifies vectors against it, and pickles trees to disk.
"""
from math import log
import operator


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of dataSet.

    Each example is a list whose last element is the class label.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # Count occurrences of each class label.
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # H = -sum(p * log2(p))
    return shannonEnt


def createDataSet():
    """Return the toy fish dataset and its feature-name labels."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']  # feature values are 0/1
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Return the examples whose feature `axis` equals `value`,
    with that feature column removed.

    (The original re-checked `featVec in dataSet` for every element —
    an O(n) membership test that is always true; removed.)
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Copy the example minus the split feature (slice + extend,
            # so the input examples are never mutated).
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain,
    or -1 if no split improves on the base entropy."""
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Unique values taken by feature i across the dataset.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most frequent class label in classList."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # .items() instead of Py2-only .iteritems() so this runs on Python 3.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    NOTE: `labels` is mutated (the chosen feature name is deleted) —
    pass a copy if the caller needs the list afterwards.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop: every example has the same label
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # stop: no features left to split on
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        subLabels = labels[:]  # copy so recursion cannot clobber siblings
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking inputTree; return the leaf label.

    Returns None when testVec's value for a feature does not appear in
    the tree (the original raised UnboundLocalError in that case).
    """
    firstStr = list(inputTree.keys())[0]  # Py3: keys() view is not indexable
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                # Internal node: recurse into the matching branch.
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]  # leaf: the class label
    return classLabel


def storeTree(inputTree, filename):
    """Pickle inputTree to filename."""
    import pickle
    # 'wb': pickle data is binary — text mode corrupts it on Py3/Windows.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Unpickle and return a tree previously saved by storeTree."""
    import pickle
    # 'rb' + `with` fixes the text-mode read and the leaked file handle.
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
treePlotter.py
# -*- coding: utf-8 -*-
"""Render a nested-dict decision tree with matplotlib annotations.

Coordinates are in axes-fraction space; plotTree tracks the pen position
via function attributes (plotTree.xOff / plotTree.yOff) set by createPlot.
"""
import matplotlib.pyplot as plt

# Node box and arrow styles passed to annotate().
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # internal (decision) node
leafNode = dict(boxstyle="round4", fc="0.8")        # leaf node
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw nodeTxt in a styled box at centerPt, arrowed from parentPt."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

# NOTE(review): the original file defined a zero-argument demo createPlot()
# here that was silently shadowed by the real createPlot(inTree) below —
# dead code, removed.


def getNumLeafs(myTree):
    """Return the number of leaf nodes in the nested-dict tree."""
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # Py3: keys() view is not indexable
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            # A dict child is a decision node: count its leaves recursively.
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of the tree."""
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def retrieveTree(i):
    """Return one of two hard-coded trees for testing the plot code."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the cntrPt-parentPt edge."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt, labelling edges with nodeTxt."""
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Center this node above the horizontal span of its leaves.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    # Restored: the original commented this out, dropping the edge label
    # between a decision node and its parent.
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # back up a level


def createPlot(inTree):
    """Create a figure and draw inTree filling the whole axes."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))  # tree width, in leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # tree depth, in levels
    # Start half a leaf-slot left of the axes so columns are centered.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
命令行测试
>>> import treePlotter
>>> reload(treePlotter)
<module 'treePlotter' from 'D:\python2.7\treePlotter.pyc'>
>>> myTree = treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
>>> import treePlotter
>>> reload(treePlotter)
<module 'treePlotter' from 'D:\python2.7\treePlotter.pyc'>
>>> myTree = treePlotter.retrieveTree(0)
>>> myTree['no surfacing'][3] = 'maybe'
>>> treePlotter.createPlot(myTree)
0 0
- 《机器学习实战》学习笔记<二>决策树
- 【机器学习实战】笔记二:决策树
- 机器学习实战(二)--决策树
- 机器学习实战笔记:决策树
- 机器学习实战笔记--决策树
- 机器学习实战笔记-决策树
- 机器学习实战-决策树笔记
- 机器学习实战笔记-决策树
- 机器学习笔记二------决策树
- 机器学习实战学习笔记-决策树
- 《机器学习实战》学习笔记---决策树
- 《机器学习实战》学习笔记 --chapter3 决策树
- 机器学习实战—决策树(二)
- Python机器学习实战(二)--决策树
- 机器学习实战---决策树
- 机器学习实战-决策树
- 机器学习实战---决策树
- 机器学习实战 决策树
- Redis-3.2.4集群配置(RedisCluster+SpringBoot+Jedis)
- Eclipse中实现同时多行注释
- 爬楼梯
- EasyPlayer声音自动停止、恢复,一键静音等功能
- Ubuntu配置环境变量
- 《机器学习实战》学习笔记<二>决策树
- 虚基类
- javascript中each方法的实现
- seleniumWebDriver自动化测试框架_02TestNG和txt文件进行数据驱动
- Java虚拟机以及跨平台原理
- 使用Spring Cloud Feign作为HTTP客户端调用远程HTTP服务
- D. String Game
- java学习日记1(HttpSession和Cookie)
- 欢迎使用CSDN-markdown编辑器