Decision Trees -- The ID3 Algorithm
1. Basic theory: entropy and information gain
http://www.cnblogs.com/wentingtu/archive/2012/03/24/2416235.html
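As a quick refresher (not taken from the linked article): the Shannon entropy of a class distribution is H = -Σ p_i·log2(p_i), and the information gain of an attribute is the entropy of the full set minus the size-weighted entropy of the subsets that attribute induces. A minimal sketch, using the classic weather dataset's 9 "play" / 5 "don't play" counts as an assumed example:

```python
from math import log

# Shannon entropy of a class-count distribution: H = -sum(p * log2(p))
def entropy(counts):
    total = float(sum(counts))
    return -sum((c / total) * log(c / total, 2) for c in counts if c)

# Entropy of the full set: 9 positive vs 5 negative samples
base = entropy([9, 5])
print(round(base, 4))  # ~0.9403

# Information gain of a hypothetical attribute that splits the 14 samples
# into three subsets with the (positive, negative) counts below
subsets = [(2, 3), (4, 0), (3, 2)]
weighted = sum(sum(s) / 14.0 * entropy(s) for s in subsets)
print(round(base - weighted, 3))  # ~0.247
```

The attribute with the largest such gain is the one ID3 splits on first.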
2. Steps of the ID3 algorithm:
Input: the dataset dataset (the attribute values of all samples) and the label set labels (the set of decision outcomes)
Output: a decision tree
(1) if all samples in dataset belong to the same class (e.g. we only go out to play when the weather is sunny and stay in otherwise, so every sample falls into that single class),
return a leaf node labeled with that class
(2) else if the attribute set is empty,
return a leaf node labeled with the most frequent value in the label set
(3) else choose the attribute with the highest information gain as the root node. For each value of that attribute, check whether any samples fall under it; if not, create a leaf node labeled with the most frequent class; if so, recursively apply steps (1)-(3) to those samples.
http://blog.csdn.net/liema2000/article/details/6118384
A worked example: http://zc0604.iteye.com/blog/1462825
3. Python implementation:
3.1 Computing the Shannon entropy of a dataset:
# Compute the Shannon entropy of a dataset
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]          # the class label is the last column
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob*log(prob,2)
    return shannonEnt
3.2 Preparing the data:
def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels
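As a quick sanity check (restating calcShannonEnt in self-contained, Python 3 form): the sample set has 2 'yes' and 3 'no' labels, so its entropy should be -0.4·log2(0.4) - 0.6·log2(0.6) ≈ 0.971:

```python
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        label = featVec[-1]                 # class label is the last column
        labelCounts[label] = labelCounts.get(label, 0) + 1
    ent = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        ent -= prob * log(prob, 2)
    return ent

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
           [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEnt(dataSet))  # ~0.9710
```

A more mixed label set gives a higher entropy (e.g. adding a third class pushes it above 1 bit), which is exactly what the splitting criterion exploits.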
3.3 Splitting the dataset
# Split the dataset on a given feature: return the samples whose value of
# that feature equals value, with the feature column removed
def splitDataSet(dataSet,axis,value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
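For example (a self-contained Python 3 restatement of the same logic), splitting the sample set on feature 0 with value 1 keeps the three samples whose first feature is 1 and strips that column:

```python
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # keep the matching sample, minus the split column
            reduced = featVec[:axis] + featVec[axis+1:]
            retDataSet.append(reduced)
    return retDataSet

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
           [0, 1, 'no'], [0, 1, 'no']]
print(splitDataSet(dataSet, 0, 1))
# [[1, 'yes'], [1, 'yes'], [0, 'no']]
```

Removing the split column is what shrinks each sample by one attribute per level, so the recursion in createTree eventually runs out of features.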
3.4 Choosing the best way to split the dataset, i.e. the attribute with the highest information gain:
# Choose the best feature to split the dataset on
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0])-1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob*calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
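On the sample set the two candidate gains can be worked out directly (a self-contained Python 3 sketch of the same computation): splitting on feature 0 gives a gain of about 0.420, splitting on feature 1 only about 0.171, so feature 0 is chosen:

```python
from math import log

def entropy(dataSet):
    # Shannon entropy of the class labels (last column)
    counts = {}
    for row in dataSet:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    n = float(len(dataSet))
    return -sum(c/n * log(c/n, 2) for c in counts.values())

def infoGain(dataSet, i):
    # Gain = H(S) - sum over values v of |S_v|/|S| * H(S_v)
    gain = entropy(dataSet)
    for value in set(row[i] for row in dataSet):
        subset = [row[:i] + row[i+1:] for row in dataSet if row[i] == value]
        gain -= len(subset)/float(len(dataSet)) * entropy(subset)
    return gain

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
           [0, 1, 'no'], [0, 1, 'no']]
print(round(infoGain(dataSet, 0), 3))  # 0.42
print(round(infoGain(dataSet, 1), 3))  # 0.171
```

This matches chooseBestFeatureToSplit returning index 0 ('no surfacing') for this dataset.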
The functions above are the building blocks for constructing the decision tree. chooseBestFeatureToSplit finds the best attribute to split the dataset on; splitting on that attribute yields several branches, and the data under each branch is then split further, which is where recursion comes in.
In a recursive algorithm, the most important part is the termination condition. The recursion for a decision tree terminates when:
(1) the program has used up all of the attributes for splitting, or (2) all instances under a branch share the same class
If all instances share the same class, we get a leaf node (a terminal block). Any data that reaches a leaf node necessarily belongs to that leaf's class.
If the dataset has used up all of its attributes but the class labels are still not unique, the leaf's class is usually decided by majority vote.
3.5 The majority-vote function
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(),
                              key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]   # return the most frequent class label
3.6 Building the decision tree
# Build the decision tree
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]          # all samples belong to the same class
    if len(dataSet[0]) == 1:         # all features used up: take the majority vote
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree
Add the following code and run it:
myDat,labels = createDataSet()
myTree = createTree(myDat,labels)
print myTree
The output is: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
4. Plotting the decision tree
4.1 Example:
import matplotlib.pyplot as plt

# Define the box and arrow styles
decisionNode = dict(boxstyle="sawtooth",fc="0.8")
leafNode = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")

# Draw an annotated node with an arrow from its parent
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',
                            xytext=centerPt,textcoords='axes fraction',
                            va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)

def createPlot():
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111,frameon=False)
    # plotNode arguments:
    #   u'decisionNode' : the text to display
    #   (0.5,0.1)       : coordinates of the text
    #   (0.1,0.5)       : coordinates of the point the arrow starts from
    #   decisionNode    : the custom box style
    plotNode(u'decisionNode',(0.5,0.1),(0.1,0.5),decisionNode)
    plotNode(u'leafNode',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()

createPlot()
4.2 To place the nodes we need the depth of the tree (which determines the figure's height, y) and the number of leaf nodes (which determines its width, x):
# To lay out the annotated tree we need the number of leaves (to size the
# x axis) and the number of levels (to size the y axis)
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Prebuilt trees for testing
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
                  ]
    return listOfTrees[i]

numLeafs = getNumLeafs(retrieveTree(0))
depth = getTreeDepth(retrieveTree(0))
print numLeafs
print depth
4.3 We now revise the createPlot() function defined above and add the plotTree() function:
# Label the branch between parent and child with the feature value
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid,yMid,txtString)

def plotTree(myTree,parentPt,nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt,parentPt,nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # a dict is a subtree: recurse
            plotTree(secondDict[key],cntrPt,str(key))
        else:
            # otherwise it is a leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)
            plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1,facecolor='white')
    fig.clf()
    axprops = dict(xticks=[],yticks=[])
    createPlot.ax1 = plt.subplot(111,frameon=False,**axprops)   # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree,(0.5,1.0),'')
    plt.show()

myTree = retrieveTree(0)
createPlot(myTree)
myTree['no surfacing'][2] = 'maybe'
createPlot(myTree)
5. Testing: classifying with the decision tree
5.1 The classification function:
# Classify a test sample with the decision tree
def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)   # index of this feature label
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key],featLabels,testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

myDat,labels = createDataSet()
myTree = retrieveTree(0)
result = classify(myTree,labels,[1,1])
print result
5.2 Storing the decision tree:
The constructed decision tree can be stored on disk and loaded later to classify directly, without rebuilding it.
# Serialize the decision tree with pickle
import pickle

def storeTree(inputTree,filename):
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    fr = open(filename,'r')
    return pickle.load(fr)

storeTree(myTree,'classifierStorage.txt')
storageTree = grabTree('classifierStorage.txt')
print "storageTree: %r" % storageTree
6. A real application: predicting which type of contact lenses a patient should wear
# Read the data
fr = open('lenses.txt')
# Preprocess: one tab-separated sample per line
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age','prescript','astigmatic','tearRate']
lensesTree = createTree(lenses,lensesLabels)
print "Constructed decision tree: %r" % lensesTree
createPlot(lensesTree)