Implementing the Decision Tree Algorithm and Its Visualization
Source: Internet   Editor: 程序博客网   Time: 2024/05/19 19:41
Preface
This post studies a Python implementation of the decision tree algorithm and how to draw the resulting tree with matplotlib.
Algorithm description
(1) Least-squares regression tree generation algorithm
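In summary form (following reference (2), not its exact wording), the least-squares regression tree picks, at each node, the splitting variable j and split point s that minimize the total squared error of the two resulting regions, predicts the mean of each region, and repeats on each region until a stopping condition holds:

```latex
\min_{j,s}\Big[\min_{c_1}\sum_{x_i\in R_1(j,s)}(y_i-c_1)^2+\min_{c_2}\sum_{x_i\in R_2(j,s)}(y_i-c_2)^2\Big]

R_1(j,s)=\{x \mid x^{(j)}\le s\},\qquad R_2(j,s)=\{x \mid x^{(j)}>s\}

\hat c_m=\frac{1}{N_m}\sum_{x_i\in R_m(j,s)} y_i,\qquad m=1,2

f(x)=\sum_{m=1}^{M}\hat c_m\, I(x\in R_m)
```

Here N_m is the number of samples in region R_m, and the final tree partitions the input space into M regions.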
(2) CART generation algorithm
where formula (5.25) is as follows:
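Assuming (5.25) denotes the Gini index of a sample set D under a binary split on feature A, as defined in reference (2), it and the Gini index it builds on are:

```latex
\mathrm{Gini}(D)=1-\sum_{k=1}^{K}\left(\frac{|C_k|}{|D|}\right)^2

\mathrm{Gini}(D,A)=\frac{|D_1|}{|D|}\,\mathrm{Gini}(D_1)+\frac{|D_2|}{|D|}\,\mathrm{Gini}(D_2)
```

C_k is the subset of D belonging to class k, and D_1 = {(x, y) in D | A(x) = a}, D_2 = D - D_1 for a candidate value a. CART chooses the feature and value with the smallest Gini(D, A).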
The implementation below (adapted from reference (1)) selects splits by information gain computed from Shannon entropy, as in ID3, rather than by the Gini index used in CART.
1. Compute the Shannon entropy of a given dataset
from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:  # count the occurrences of each class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt
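As a quick check of the formula H = -sum(p_k * log2(p_k)) that calcShannonEnt implements, the sketch below recomputes the entropy of the lenses class column from its label counts (15 "no lenses", 5 "soft", 4 "hard", counted from the data in step 12). It uses collections.Counter instead of a hand-built dict; the function name entropy is hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    counts = Counter(labels)
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in counts.values())


# Class column of lenses.txt: 15 "no lenses", 5 "soft", 4 "hard"
lens_classes = ["no lenses"] * 15 + ["soft"] * 5 + ["hard"] * 4
h = entropy(lens_classes)
print(round(h, 4))  # 1.3261
```

A pure set gives 0 bits and a 50/50 split of two classes gives exactly 1 bit, which bounds what the information gain in step 3 can achieve.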
2. Split the dataset on a given feature
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out the column used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
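To make the behavior concrete: splitDataSet keeps only the rows whose column axis equals value and drops that column, so each recursive call works on one fewer feature. The sketch below is an equivalent list-comprehension version (split_rows is a hypothetical name) applied to two rows of the lenses data.

```python
def split_rows(rows, axis, value):
    """Rows whose feature `axis` equals `value`, with that feature column removed."""
    return [row[:axis] + row[axis + 1:] for row in rows if row[axis] == value]


rows = [
    ["young", "myope", "no", "reduced", "no lenses"],
    ["young", "myope", "no", "normal", "soft"],
]
# Keep the rows with tearRate (column 3) == "normal"; column 3 disappears
print(split_rows(rows, 3, "normal"))  # [['young', 'myope', 'no', 'soft']]
```

Slicing builds new lists, so the input rows are never modified.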
3. Choose the best way to split the dataset (the best feature)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column holds the class labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    # Default to feature 0: with -1, a dataset where no feature gains information
    # would make callers index the class-label column instead of a feature.
    bestFeature = 0
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of feature i
        uniqueVals = set(featList)  # the distinct values of feature i
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain, i.e. the reduction in entropy
        if infoGain > bestInfoGain:  # compare with the best gain so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index of the best feature
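The loop computes Gain(D, A) = H(D) - sum over values v of |D_v|/|D| * H(D_v). The self-contained sketch below prints the gain of every feature on a small made-up dataset (two binary features plus a yes/no label), so you can see which index chooseBestFeatureToSplit would return; the helper names are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Shannon entropy of the class label in the last column."""
    n = len(rows)
    counts = Counter(row[-1] for row in rows)
    return -sum((c / n) * log2(c / n) for c in counts.values())


def info_gains(rows):
    """Information gain H(D) - sum_v |D_v|/|D| * H(D_v) for each feature column."""
    base = entropy(rows)
    gains = []
    for i in range(len(rows[0]) - 1):
        cond = 0.0  # conditional entropy H(D | feature i)
        for v in set(row[i] for row in rows):
            subset = [row for row in rows if row[i] == v]
            cond += len(subset) / len(rows) * entropy(subset)
        gains.append(base - cond)
    return gains


toy = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
gains = info_gains(toy)
print([round(g, 3) for g in gains])  # [0.42, 0.171]
best = max(range(len(gains)), key=lambda i: gains[i])
print(best)  # 0
```

Feature 0 wins because its value 0 isolates a pure "no" subset, while feature 1 leaves a 50/50 mix in its larger branch.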
4. Majority vote: return the most common class label
import operator

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    # Python 3: dict.iteritems() no longer exists, so use items()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
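In Python 3 the same vote can be written with collections.Counter, whose most_common(1) returns the most frequent label together with its count. This is an alternative sketch (majority_class is a hypothetical name), not the code from reference (1).

```python
from collections import Counter


def majority_class(class_list):
    """Most frequent label in class_list; ties go to the label seen first."""
    return Counter(class_list).most_common(1)[0][0]


print(majority_class(["soft", "no lenses", "hard", "no lenses"]))  # no lenses
```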
5. Build the tree
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]  # note: this also removes the entry from the caller's list
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy the labels so recursive calls don't mess up each other's lists
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
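The recursion stops when a node is pure or no features remain; otherwise it nests one dict per split, {feature_label: {feature_value: subtree_or_class}}. The compact, self-contained Python 3 sketch below follows the same steps on a made-up dataset (all helper and feature names are hypothetical), but builds a new label list instead of deleting from the caller's list.

```python
from collections import Counter
from math import log2


def entropy(rows):
    """Shannon entropy of the class labels in the last column."""
    n = len(rows)
    counts = Counter(r[-1] for r in rows)
    return -sum((c / n) * log2(c / n) for c in counts.values())


def split(rows, axis, value):
    """Rows with feature `axis` == value, that column removed."""
    return [r[:axis] + r[axis + 1:] for r in rows if r[axis] == value]


def best_feature(rows):
    """Index of the feature with the largest information gain (first one on ties)."""
    base = entropy(rows)

    def gain(i):
        subsets = [split(rows, i, v) for v in set(r[i] for r in rows)]
        return base - sum(len(s) / len(rows) * entropy(s) for s in subsets)

    return max(range(len(rows[0]) - 1), key=gain)


def build_tree(rows, labels):
    """Nested dict {feature_label: {value: subtree_or_class}}, built ID3-style."""
    classes = [r[-1] for r in rows]
    if classes.count(classes[0]) == len(classes):
        return classes[0]                              # pure node: leaf
    if len(rows[0]) == 1:
        return Counter(classes).most_common(1)[0][0]   # no features left: majority vote
    i = best_feature(rows)
    sub_labels = labels[:i] + labels[i + 1:]           # new list; caller's labels untouched
    return {labels[i]: {v: build_tree(split(rows, i, v), sub_labels)
                        for v in set(r[i] for r in rows)}}


rows = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
labels = ["f0", "f1"]  # hypothetical feature names
tree = build_tree(rows, labels)
print(tree)
```

Because max returns the first maximal index, the sketch never produces an invalid feature index even when every gain is zero.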
6. Get the number of leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))  # Python 3: dict.keys() is not indexable, so myTree.keys()[0] fails
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict values are subtrees; anything else is a leaf node
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
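getNumLeafs treats every non-dict value as a leaf. Because the tree is just nested dicts, the count can also be computed without recursion; the sketch below (count_leaves is a hypothetical name) uses an explicit stack on a small hand-written tree in the createTree format. The tree is a simplified illustration, not the actual lenses result.

```python
def count_leaves(tree):
    """Number of leaf (non-dict) values in a nested {feature: {value: subtree}} tree."""
    stack, leaves = [tree], 0
    while stack:
        node = stack.pop()
        (branches,) = node.values()  # each decision node has exactly one feature key
        for child in branches.values():
            if isinstance(child, dict):
                stack.append(child)
            else:
                leaves += 1
    return leaves


tree = {"tearRate": {"reduced": "no lenses",
                     "normal": {"astigmatic": {"no": "soft", "yes": "hard"}}}}
print(count_leaves(tree))  # 3
```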
7. Get the depth (number of levels) of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))  # Python 3: dict.keys() is not indexable, so myTree.keys()[0] fails
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict values are subtrees; anything else is a leaf node
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
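Depth here counts decision levels: a node whose branches are all leaves has depth 1, which is what plotTree later divides the vertical space by. A quick check on two hypothetical trees in the same nested-dict format, using a compact Python 3 port (tree_depth is a hypothetical name):

```python
def tree_depth(tree):
    """Number of decision levels in a nested {feature: {value: subtree}} tree."""
    branches = next(iter(tree.values()))
    return max(1 + tree_depth(c) if isinstance(c, dict) else 1
               for c in branches.values())


flat = {"tearRate": {"reduced": "no lenses", "normal": "soft"}}
nested = {"tearRate": {"reduced": "no lenses",
                       "normal": {"astigmatic": {"no": "soft", "yes": "hard"}}}}
print(tree_depth(flat), tree_depth(nested))  # 1 2
```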
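8. Compute the midpoint between a parent and a child node and add a simple text label there
def plotMidText(cntrPt, parentPt, txtString):
    # midpoint of the edge from the child (cntrPt) to its parent (parentPt)
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)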
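The expression (parent - child)/2 + child is simply the midpoint (parent + child)/2 of the edge, computed per coordinate; plotMidText writes the branch value there. A small check with hypothetical coordinates in axes-fraction units (midpoint is a hypothetical name):

```python
def midpoint(child, parent):
    """Midpoint of the edge between child and parent, as computed in plotMidText."""
    x = (parent[0] - child[0]) / 2.0 + child[0]
    y = (parent[1] - child[1]) / 2.0 + child[1]
    return x, y


print(midpoint((0.25, 0.5), (0.5, 1.0)))  # (0.375, 0.75)
```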
9. Draw nodes using text annotations
# box and arrow styles for decision nodes, leaf nodes and edges
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # draw nodeTxt in a box at centerPt, with an arrow from parentPt (axes-fraction coordinates)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
10. Plot the tree
import matplotlib.pyplot as plt

def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you which feature was split on
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    firstStr = next(iter(myTree))  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # a dict is a subtree: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # a leaf node: advance x by one slot and draw it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # move back up after this subtree

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))  # total width = number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))  # total height = depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
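The layout arithmetic in plotTree is easier to follow without matplotlib. The sketch below replays the same xOff/yOff bookkeeping and returns node coordinates instead of drawing them (layout and its (label, x, y) return format are hypothetical). With W leaves and depth D, the i-th leaf lands at x = (i + 0.5)/W, each decision node is centered over its own leaves, and each level sits 1/D lower.

```python
def num_leaves(tree):
    branches = next(iter(tree.values()))
    return sum(num_leaves(c) if isinstance(c, dict) else 1 for c in branches.values())


def depth(tree):
    branches = next(iter(tree.values()))
    return max(1 + depth(c) if isinstance(c, dict) else 1 for c in branches.values())


def layout(tree):
    """Return [(label, x, y)] using the same arithmetic as plotTree/createPlot."""
    total_w, total_d = float(num_leaves(tree)), float(depth(tree))
    state = {"x": -0.5 / total_w, "y": 1.0}  # plays the role of plotTree.xOff / yOff
    nodes = []

    def place(subtree):
        feature = next(iter(subtree))
        center = (state["x"] + (1.0 + num_leaves(subtree)) / 2.0 / total_w, state["y"])
        nodes.append((feature, *center))
        state["y"] -= 1.0 / total_d  # step down one level
        for child in subtree[feature].values():
            if isinstance(child, dict):
                place(child)
            else:
                state["x"] += 1.0 / total_w  # next leaf slot
                nodes.append((child, state["x"], state["y"]))
        state["y"] += 1.0 / total_d  # step back up

    place(tree)
    return nodes


tree = {"a": {0: "x", 1: {"b": {0: "y", 1: "z"}}}}
for label, x, y in layout(tree):
    print(label, round(x, 3), round(y, 3))
```

For this three-leaf tree, the root "a" sits at (0.5, 1.0), the leaves x, y, z at x = 1/6, 1/2, 5/6, and "b" at x = 2/3, centered over y and z.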
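11. Main program
with open('lenses.txt') as fr:  # the with block closes the file afterwards
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses, lensesLabels)  # note: createTree deletes entries from lensesLabels
createPlot(lensesTree)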
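The lenses file has to be split on tabs rather than on arbitrary whitespace, because the class label "no lenses" contains a space. The sketch below parses two rows from an in-memory string with csv.reader; io.StringIO stands in for the real file here, an assumption made only so the example runs without lenses.txt.

```python
import csv
import io

sample = "young\tmyope\tno\treduced\tno lenses\nyoung\tmyope\tno\tnormal\tsoft\n"

with io.StringIO(sample) as fr:  # stands in for open('lenses.txt')
    lenses = list(csv.reader(fr, delimiter="\t"))

print(lenses[0])            # ['young', 'myope', 'no', 'reduced', 'no lenses']
print("no lenses".split())  # ['no', 'lenses']: whitespace splitting would break the label
```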
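12. Data file lenses.txt (one example per line; tab-separated columns age, prescript, astigmatic, tearRate, followed by the class label)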
young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses
13. Resulting figure
References
(1) Peter Harrington, Machine Learning in Action (机器学习实战)
(2) Li Hang, Statistical Learning Methods (统计学习方法)