Decision Trees in Practice: From Principles to Implementation
Main contents:
1. Introduction to decision trees
2. The decision tree construction algorithm
3. Implementing a decision tree in Python
4. Extensions
5. Appendix: code
6. References
A decision tree is a basic method for classification and regression. Learning a decision tree involves three steps: feature selection, tree generation, and tree pruning. This article covers the principles of decision trees for classification and a Python implementation of the ID3 algorithm.
1. Decision Trees
A decision tree model is a tree structure that describes how instances are classified. It consists of nodes and directed edges. Nodes are either internal nodes or leaf nodes: an internal node represents a feature, a leaf node represents a class (label), and each directed edge represents one value of the corresponding feature. A typical example is a decision tree for judging whether an applicant has the ability to repay a loan.
2. Building a Decision Tree from Historical Data
The key to building the tree step by step is finding the right nodes and edges. The first problem to solve is: which feature of the current dataset is most decisive when splitting the data into classes? To find that decisive feature and obtain the best split, every feature must be evaluated. Information gain is an effective evaluation criterion.
2.1 Information Gain
Information gain measures how much the uncertainty about the class Y is reduced by knowing the value of a feature X. For a dataset D and a feature A:

g(D, A) = H(D) - H(D|A)

where H(D) is the empirical entropy of D and H(D|A) is the empirical conditional entropy of D given feature A. This quantity is also called mutual information; in decision tree learning, the information gain equals the mutual information between the classes and the feature in the training set.
Feature selection by the information gain criterion then works as follows: for the training dataset D, compute the information gain of every feature and select the feature whose information gain is largest.
The information gain algorithm: (1) compute the empirical entropy H(D) of the dataset D; (2) compute the empirical conditional entropy H(D|A) for each feature A; (3) compute the information gain g(D, A) = H(D) - H(D|A) and compare across features.
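Written out in the notation of the cited 《统计学习方法》 (with K classes C_k, and feature A taking n distinct values that partition D into subsets D_1, ..., D_n), the two entropies are:

H(D) = -∑_{k=1}^{K} (|C_k| / |D|) · log2(|C_k| / |D|)
H(D|A) = ∑_{i=1}^{n} (|D_i| / |D|) · H(D_i)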
2.2 Generating the Decision Tree
The algorithm flow (this is exactly what the createTree function in section 3 implements): if all instances in the current subset belong to one class, return that class as a leaf; if no features remain to split on, return the majority class as a leaf; otherwise select the feature with the largest information gain, split the dataset by its values, and recursively build a subtree for each split.
3. Decision Trees in Practice
First, obtain the data:
def createDataSet():
    # a toy dataset from Machine Learning in Action: the first two columns are
    # the features 'no surfacing' and 'flippers', the last column is the class label
    # dataSet = pd.read_csv(datafile)   # (alternatively, load real data with pandas)
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels
Recursively grow the tree from the data:
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]            # all examples in this subset share one class: return it
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # no features left: return the majority class as a leaf
    bestFeat = chooseBestFeatureToSplit(dataSet)   # pick the best feature to split on
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        # recurse on each split; the tree is stored as nested dictionaries
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
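Once the helper functions below are in place, running createTree on the toy dataset produces a nested dictionary; it matches the first canned tree returned by retrieveTree(0) in the appendix's treePlotter.py:

myDat, labels = createDataSet()
myTree = createTree(myDat, labels)
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}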
The helper functions used above:
Compute the Shannon entropy of a given dataset:
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]   # tally the frequency of every class label
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0                 # accumulate the Shannon entropy
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
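As a sanity check: the toy dataset has two 'yes' and three 'no' labels, so H(D) = -(2/5)·log2(2/5) - (3/5)·log2(3/5) ≈ 0.971, and the function agrees:

myDat, _ = createDataSet()
print(calcShannonEnt(myDat))  # 0.9709505944546686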
Split the dataset on a given feature value:
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # keep every column except the one we split on
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
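For example, splitting the toy dataset on feature 0 with value 1 keeps the three matching rows and strips the split column:

print(splitDataSet(myDat, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]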
Choose the best feature to split the data on:
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1   # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0                # conditional entropy of splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy   # information gain
        if infoGain > bestInfoGain:           # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
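On the toy dataset this yields an information gain of about 0.420 for feature 0 versus about 0.171 for feature 1, so feature 0 ('no surfacing') is chosen:

print(chooseBestFeatureToSplit(myDat))  # 0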
Return the label that occurs most often in a class list:
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # sort by count, descending; iteritems() is Python 2 only, use items() in Python 3
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
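The post builds the tree but stops short of using it to label new data. A minimal classification routine (a sketch in the spirit of Machine Learning in Action, not part of the original code) simply walks the nested dictionary:

def classify(inputTree, featLabels, testVec):
    # find which feature this node tests, then follow the branch
    # that matches the test vector's value for that feature
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    valueOfFeat = secondDict[testVec[featIndex]]
    if isinstance(valueOfFeat, dict):
        return classify(valueOfFeat, featLabels, testVec)  # internal node: recurse
    return valueOfFeat                                     # leaf node: the predicted class

# classify(myTree, ['no surfacing', 'flippers'], [1, 1])  -> 'yes'

Note that createTree deletes entries from labels as it recurses, so pass classify a fresh copy of the original label list.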
5. Appendix: Code
5.1 trees.py
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 03 09:13:58 2015
@author: beta
"""
from math import log
import operator
import treePlotter


def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]   # tally the frequency of every class label
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0                 # accumulate the Shannon entropy
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0             # conditional entropy of splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy   # information gain
        if infoGain > bestInfoGain:           # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # Python 3: dict.iteritems() no longer exists, use items()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


if __name__ == '__main__':
    myDat, labels = createDataSet()
    myTree = createTree(myDat, labels)
    treePlotter.createPlot(myTree)
5.2 treePlotter.py

'''
Created on Oct 14, 2010
@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]  # Python 3: keys() is a view, so wrap it in list()
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict is an internal node; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    # label the edge between parent and child at its midpoint
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)      # determines the x width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:                                            # a leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    # canned trees for testing the plotting code
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
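To exercise the plotting code without rebuilding a tree, the canned trees can be plotted directly (assuming matplotlib is installed and a display is available):

import treePlotter
treePlotter.createPlot(treePlotter.retrieveTree(1))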
6. References
1. 李航, 《统计学习方法》 (Statistical Learning Methods)
2. Peter Harrington, 《机器学习实战》 (Machine Learning in Action)