机器学习之决策树 Decision Tree(二)Python实现

来源:互联网 发布:mac涂层脱落修复 编辑:程序博客网 时间:2024/04/30 04:17

计算给定数据集的熵

from math import log


def calcShannonEnt(dataSet):
    """Compute the Shannon entropy (base 2) of the class labels in a data set.

    Args:
        dataSet: Sequence of feature vectors. The last element of each
            vector is used as that row's class label.

    Returns:
        The entropy as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    # Tally how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]  # the last column is the class label
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
划分数据集

def createDataSet():
    """Return the toy data set and the names of its feature columns.

    Returns:
        A ``(dataSet, labels)`` tuple. Each row of ``dataSet`` holds two
        feature values followed by a ``'yes'``/``'no'`` class label;
        ``labels`` names the two feature columns.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Select the rows whose feature ``axis`` equals ``value``.

    Args:
        dataSet: Sequence of feature vectors (lists or tuples).
        axis: Index of the feature column to test.
        value: Feature value a row must have to be kept.

    Returns:
        A new list of lists: the matching rows with column ``axis``
        removed. The input rows are not modified.
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Build a fresh list so tuple rows work and inputs stay untouched.
            reducedFeatVec = list(featVec[:axis])
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Pick the feature whose split gives the largest information gain.

    Args:
        dataSet: Non-empty sequence of feature vectors whose last element
            is the class label.

    Returns:
        The index of the best feature column, or -1 when no feature
        yields a strictly positive information gain.

    Raises:
        IndexError: If ``dataSet`` is empty.
    """
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    totalRows = float(len(dataSet))
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        # Expected entropy after splitting on feature i: the subsets'
        # entropies weighted by the fraction of rows in each subset.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / totalRows
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
构建决策树

import operator


def majorityCnt(classList):
    """Return the class label that occurs most often in ``classList``.

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: If ``classList`` is empty.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # items() (rather than the Python 2-only iteritems()) works on both
    # Python 2 and Python 3.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: Non-empty sequence of feature vectors whose last element
            is the class label.
        labels: Feature names, one per feature column of ``dataSet``.
            This list is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_label: {feature_value: subtree_or_label, ...}}``.
    """
    classList = [example[-1] for example in dataSet]
    # Stop when every row has the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # All features have been used up: fall back to the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # chooseBestFeatureToSplit returns -1 when no feature gives a positive
    # information gain; using -1 as an index would split on the class
    # column, so return the majority class instead.
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build the reduced label list without deleting from the caller's list,
    # so the caller can still pass the full labels to classify() afterwards.
    subLabels = list(labels[:bestFeat]) + list(labels[bestFeat + 1:])
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree
使用决策树进行分类
def classify(inputTree, featLabels, testVec):
    """Classify a feature vector by walking a decision tree.

    Args:
        inputTree: Tree as nested dicts,
            ``{feature_label: {feature_value: subtree_or_class_label}}``.
        featLabels: List of feature names; the position of a name gives
            the index of that feature in ``testVec``.
        testVec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached.

    Raises:
        KeyError: If the sample's value for the tested feature has no
            branch in the tree.
        ValueError: If the tree's feature label is not in ``featLabels``.
    """
    # list(...) makes the lookup work on Python 3, where keys() is a view
    # that cannot be indexed.
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict:
        if testVec[featIndex] == key:
            branch = secondDict[key]
            # A dict is another decision node; anything else is a leaf.
            if isinstance(branch, dict):
                return classify(branch, featLabels, testVec)
            return branch
    # Previously an unmatched value surfaced as an UnboundLocalError on
    # the return statement; fail with an explicit message instead.
    raise KeyError(
        "no branch for value %r of feature %r" % (testVec[featIndex], firstStr)
    )
绘制决策树

import matplotlib.pyplot as plt

# Text-box and arrow styles shared by the plotting helpers.
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node at ``centerPt`` with an arrow linking it to ``parentPt``.

    Args:
        nodeTxt: Text shown inside the box.
        centerPt: ``(x, y)`` position of the text, in axes-fraction coordinates.
        parentPt: ``(x, y)`` position of the parent, in axes-fraction coordinates.
        nodeType: Box style dict (``decisionNode`` or ``leafNode``).
    """
    createPlot.axl.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )


def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` at the midpoint between a child and its parent."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below ``parentPt``.

    Layout state is kept in function attributes set up by ``createPlot``:
    ``plotTree.totalW`` / ``plotTree.totalD`` (leaf count and depth of the
    whole tree) and ``plotTree.xOff`` / ``plotTree.yOff`` (current cursor).

    Args:
        myTree: Tree as nested dicts, ``{feature_label: {value: subtree_or_leaf}}``.
        parentPt: ``(x, y)`` position of the parent node.
        nodeTxt: Text for the edge leading from the parent to this node.
    """
    numLeafs = getNumLeafs(myTree)
    firstStr = list(myTree.keys())[0]
    # Centre this node horizontally over the leaves it spans.
    cntrPt = (
        plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
        plotTree.yOff,
    )
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step one level down before drawing the children. Without this step
    # the restore below would push every deeper level upwards instead.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        child = secondDict[key]
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))
        else:
            # Leaves are placed left to right, one slot each.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the vertical cursor for the caller's remaining siblings.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Create a figure, draw ``inTree`` on it and show it.

    Args:
        inTree: Tree as nested dicts, as consumed by ``plotTree``.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a slot to the left so the first leaf lands mid-slot.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def getNumLeafs(myTree):
    """Return the number of leaf nodes in ``myTree``.

    Args:
        myTree: Tree as nested dicts, ``{feature_label: {value: subtree_or_leaf}}``.
    """
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        # A dict child is itself a decision node; anything else is a leaf.
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of decision levels on the longest path in ``myTree``.

    Args:
        myTree: Tree as nested dicts, ``{feature_label: {value: subtree_or_leaf}}``.
    """
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


阅读全文
0 0
原创粉丝点击