Decision Trees
General workflow for decision trees:
- Collect data: any method will do
- Prepare data: the tree-construction algorithm only handles nominal values, so numeric features must be discretized first
- Analyze data: any method; once the tree is built, inspect the plot to check that it matches expectations
- Train the algorithm: build the tree data structure
- Test the algorithm: compute the error rate with the learned tree
- Use the algorithm: this step applies to any supervised learning task; a decision tree has the added benefit of making the inherent meaning of the data easier to understand
Computing the Shannon entropy of a data set (a way to measure the information in a set: it quantifies how disordered the data is).
Entropy is defined as the expected value of the information; the higher the entropy, the more mixed the data.
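In symbols (notation assumed here, not from the original post): if the class label takes the value $x_i$ with probability $p(x_i)$, the entropy is the expected information content of a label,

```latex
H = -\sum_{i=1}^{n} p(x_i) \log_2 p(x_i)
```

which is exactly the quantity the function below accumulates over the label counts.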
```python
from math import log

def calcShannonEnt(dataSet):
    """Shannon entropy of a data set (a measure of its disorder)."""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count how often each class label occurs
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def createDataSet():  # build a simple toy data set
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
```
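As a quick sanity check (a standalone sketch, not part of the book's code), the entropy of the toy data set can be computed directly from its label counts: two 'yes' and three 'no' out of five rows.

```python
from math import log

# Class labels of the toy data set above: 2 'yes', 3 'no'.
labels = ['yes', 'yes', 'no', 'no', 'no']

# Tally the labels, then accumulate -p * log2(p) over the frequencies.
counts = {}
for label in labels:
    counts[label] = counts.get(label, 0) + 1

entropy = 0.0
for count in counts.values():
    prob = count / len(labels)
    entropy -= prob * log(prob, 2)

print(round(entropy, 4))  # 0.971
```

This matches what `calcShannonEnt` returns for the data set produced by `createDataSet()`.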
Splitting the data set (measure the entropy of each candidate split to judge whether the data set is being split well)
For every feature, compute the information entropy of the resulting split, then decide which feature gives the best split of the data set.
```python
# splitDataSet(data set to split, index of the feature to split on, feature value to match)
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # keep rows whose column `axis` equals `value`, with that column cut out;
            # e.g. splitDataSet(dataSet, 0, 0) keeps rows where featVec[0] == 0
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# m, l = createDataSet()
# calcShannonEnt(m)
# splitDataSet(m, 0, 0)   # -> [[1, 'no'], [1, 'no']]

# Iterate over the whole data set, computing the entropy of each candidate split
# with splitDataSet(), to find the best feature to split on.
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # number of features (last column is the label)
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1  # index of the best feature
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]  # column i of every row
        uniqueVals = set(featList)  # unordered collection of the unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))  # probability of this branch
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted entropy after the split
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index of the best feature
```
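To see why `chooseBestFeatureToSplit` picks feature 0 ('no surfacing') on the toy data, the information gain of each split can be worked out directly (a standalone sketch mirroring that function, not the book's code):

```python
from math import log

def entropy(rows):
    # Shannon entropy of the class labels (last column) in `rows`
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return -sum((c / len(rows)) * log(c / len(rows), 2) for c in counts.values())

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
base = entropy(dataSet)

def infoGain(dataSet, i):
    # gain = base entropy minus the entropy of each branch,
    # weighted by the branch's share of the rows
    newEntropy = 0.0
    for value in set(row[i] for row in dataSet):
        branch = [row for row in dataSet if row[i] == value]
        newEntropy += len(branch) / len(dataSet) * entropy(branch)
    return base - newEntropy

print(round(infoGain(dataSet, 0), 4))  # 0.42   gain for 'no surfacing'
print(round(infoGain(dataSet, 1), 4))  # 0.171  gain for 'flippers'
```

Feature 0 yields the larger gain, so it is chosen as the root split.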
Building the decision tree recursively
How it works: start from the full data set, split it on the best feature, and since a feature may take more than two values, a split can produce more than two branches.
The recursion stops when either (a) every feature available for splitting has been used up, or (b) all instances under a branch share the same class.
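The recursion described above can be sketched end to end in a condensed, standalone variant (using `collections.Counter` for the majority vote; this is an illustration, not the book's code) that reproduces the tree for the toy data set:

```python
from collections import Counter
from math import log

def entropy(rows):
    counts = Counter(row[-1] for row in rows)
    return -sum((c / len(rows)) * log(c / len(rows), 2) for c in counts.values())

def split(rows, i, v):
    # rows matching value v on feature i, with column i removed
    return [row[:i] + row[i+1:] for row in rows if row[i] == v]

def bestFeature(rows):
    gains = []
    for i in range(len(rows[0]) - 1):
        rem = sum(len(split(rows, i, v)) / len(rows) * entropy(split(rows, i, v))
                  for v in set(r[i] for r in rows))
        gains.append(entropy(rows) - rem)
    return gains.index(max(gains))

def buildTree(rows, labels):
    classes = [row[-1] for row in rows]
    if classes.count(classes[0]) == len(classes):
        return classes[0]                             # pure node: return its label
    if len(rows[0]) == 1:
        return Counter(classes).most_common(1)[0][0]  # features exhausted: majority vote
    best = bestFeature(rows)
    tree = {labels[best]: {}}
    subLabels = labels[:best] + labels[best+1:]       # drop the used feature
    for v in set(row[best] for row in rows):
        tree[labels[best]][v] = buildTree(split(rows, best, v), subLabels)
    return tree

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
tree = buildTree(dataSet, ['no surfacing', 'flippers'])
print(tree)  # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```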
```python
import operator

def majorityCnt(classList):  # return the most frequent class label
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        # all class labels are identical: return that label directly
        return classList[0]
    if len(dataSet[0]) == 1:
        # all features are used up but the classes are still mixed:
        # return the most frequent label
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy labels so recursion does not mutate the original list
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
```

Plotting the tree with Matplotlib annotations
1. Matplotlib annotations
```python
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # plt.annotate() draws a text annotation with an arrow from parent to child
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

# subplot(row, col, plotNum): divide the figure into row*col subregions, numbered
#   left-to-right, top-to-bottom, and create axes in region plotNum; when row, col
#   and plotNum are all below 10, a single 3-digit number (e.g. 111) works too.
# plot(x, y): plot x against y with the default line style and colour;
#   plot(x, y, 'bo') uses blue circle markers.
# figure(...): create a new figure; clf(): clear the current figure.
# createPlot.ax1: in Python everything is an object, so attributes can be attached
#   even to a function; here the axes object is hung on createPlot itself.

def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()
```
2. Constructing the annotated tree
```python
def getNumLeafs(myTree):  # count the leaf nodes
    numLeafs = 0
    # Python 2.x: D.keys() -> list; Python 3.x: D.keys() -> a set-like view
    # of D's keys, so wrap it in list() before indexing
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):  # number of levels in the tree
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def retrieveTree(i):  # return a pre-built tree, mainly for testing
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]

# myTree = retrieveTree(0)
# getNumLeafs(myTree)   # 3
# getTreeDepth(myTree)  # 2
```
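The two traversals can also be written compactly with `isinstance` (a standalone variant for illustration, not the book's code); they give the same counts on the stored test tree:

```python
# A stored test tree, as returned by retrieveTree(0).
myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

def numLeafs(tree):
    # leaves are the non-dict values hanging under the single root key
    sub = tree[list(tree)[0]]
    return sum(numLeafs(v) if isinstance(v, dict) else 1 for v in sub.values())

def depth(tree):
    # a leaf branch contributes 1 level; a subtree contributes 1 + its own depth
    sub = tree[list(tree)[0]]
    return max(1 + depth(v) if isinstance(v, dict) else 1 for v in sub.values())

print(numLeafs(myTree), depth(myTree))  # 3 2
```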
------------------------ Combining everything above to plot a complete tree ------------------------
```python
def plotMidText(cntrPt, parentPt, txtString):
    # label the branch between parent and child with the feature value
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:  # it's a leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no tick marks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
```

Testing and storing the classifier
```python
# 1. Test the algorithm: a classifier that uses the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # map the label string to a feature index
    for key in secondDict.keys():
        if testVec[featIndex] == key:  # compare the test vector against the node value
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

# In[]:  m, l = trees.createDataSet()
# In[]:  l
# Out[]: ['no surfacing', 'flippers']
# In[]:  myTree = treePlotter.retrieveTree(0)
# In[]:  myTree
# Out[]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
# In[]:  trees.classify(myTree, l, [1, 0])
# Out[]: 'no'

# 2. Use the algorithm: persist the decision tree with pickle
def storeTree(inputTree, filename):
    import pickle
    with open(filename, 'wb') as fw:  # pickle requires binary mode in Python 3
        pickle.dump(inputTree, fw)

def grabTree(filename):
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
```
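Putting the two pieces together, a compact standalone sketch (an `isinstance`-based variant of `classify`, with the same binary-mode pickle round-trip as `storeTree`/`grabTree`; the temp-file path is an assumption for the example):

```python
import os
import pickle
import tempfile

myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']

def classify(tree, featLabels, testVec):
    # walk down the tree, following the branch that matches the test vector
    root = list(tree)[0]
    branch = tree[root][testVec[featLabels.index(root)]]
    return classify(branch, featLabels, testVec) if isinstance(branch, dict) else branch

print(classify(myTree, labels, [1, 0]))  # no
print(classify(myTree, labels, [1, 1]))  # yes

# Round-trip the tree through pickle, using binary mode as in storeTree/grabTree.
path = os.path.join(tempfile.gettempdir(), 'classifierStorage.txt')
with open(path, 'wb') as fw:
    pickle.dump(myTree, fw)
with open(path, 'rb') as fr:
    restored = pickle.load(fr)
print(restored == myTree)  # True
```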