决策树(Python实现)
来源:互联网 发布:淘宝定制包 编辑:程序博客网 时间:2024/06/06 00:30
这篇文章是《机器学习实战》(Machine Learning in Action)第三章 决策树算法的Python实现代码。
1 参考链接
机器学习实战
2 实现代码
2.1 treePlotter.py
# treePlotter.py -- draw an annotated decision-tree figure with matplotlib.
# Reformatted and repaired from a scraped copy of "Machine Learning in
# Action", ch. 3 (the scrape had collapsed all newlines).
import matplotlib.pyplot as plt

# Box/arrow styles used by plotNode. NOTE(review): "desicion" is a historical
# typo from the book; the name is kept so external importers don't break.
desicionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


def getNumLeafs(myTree):
    """Return the number of leaf nodes in a nested-dict decision tree.

    A tree is ``{featureLabel: {value: subtree-or-leaf, ...}}``; anything
    that is not a dict counts as one leaf.
    """
    numLeafs = 0
    firstStr = next(iter(myTree))  # root label; Py2-only keys()[0] fixed
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the depth (number of decision levels) of a nested-dict tree."""
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth


def retrieveTree(i):
    """Return one of two canned example trees for exercising the plot code."""
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers':
            {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one node box at centerPt with an arrow pointing from parentPt.

    Coordinates are in axes-fraction units (0..1). Requires createPlot to
    have set ``createPlot.ax1`` first.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write txtString at the midpoint of the parent->child edge."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively plot a subtree rooted below parentPt.

    Uses function attributes (totalW, totalD, xOff, yOff) as the plotting
    cursor, as initialised by createPlot.
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree in leaf slots
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, desicionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # descend one level
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # pop back up


def createPlot(inTree):
    """Render *inTree* (a nested-dict decision tree) in a matplotlib window.

    NOTE(review): the scraped source defined a second, parameterless
    createPlot above this one; it was dead code (immediately shadowed by
    this definition) and has been removed.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # BUG FIX: the scraped copy reset xOff to 0.0 right after computing it,
    # shifting every leaf half a slot to the right; the book's initial value
    # of -0.5/totalW is restored here.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

# TEST
# myTree = retrieveTree(0)
# myTree['no surfacing'][3] = 'maybe'
# createPlot(myTree)
2.2 trees.py
# trees.py -- ID3 decision-tree learning ("Machine Learning in Action", ch. 3),
# reformatted and repaired from a scraped copy that had collapsed all newlines.
from math import log
import operator


def createDataSet():
    """Return the toy "is it a fish?" dataset and its feature labels."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of the class labels (last column)."""
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = count / float(numEntries)
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the rows whose feature ``axis`` equals ``value``,
    with that feature column removed from each returned row."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            retDataSet.append(featVec[:axis] + featVec[axis + 1:])
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain,
    or -1 when no split improves on the base entropy."""
    numFeatures = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0  # renamed from the misleading "baseInfoGain"
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the most common class label in classList (majority vote)."""
    classCount = {}
    for vote in classList:
        # BUG FIX: the original did ``classCount += 1`` on the dict itself,
        # which raises TypeError; count per label instead.
        classCount[vote] = classCount.get(vote, 0) + 1
    # iteritems() (Py2-only) replaced with items().
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Build an ID3 decision tree as nested dicts.

    Returns either a class label (leaf) or ``{featureLabel: {value: subtree}}``.
    NOTE: mutates ``labels`` in place (deletes the chosen feature's label),
    matching the book's contract -- pass a copy if you need the list intact.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]            # node is pure
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # features exhausted
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):
        subLabels = labels[:]  # copy so sibling branches see the same labels
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify ``testVec`` against a tree built by createTree.

    Returns the predicted class label, or None when the test vector holds a
    feature value the tree never saw (the original raised UnboundLocalError).
    """
    firstStr = next(iter(inputTree))  # Py2-only keys()[0] fixed
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


def storeTree(inputTree, filename):
    """Pickle ``inputTree`` to ``filename``."""
    import pickle
    # BUG FIX: pickle requires binary mode on Python 3; also close via with.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously saved with storeTree."""
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


if __name__ == '__main__':
    # Demo from the book, guarded so importing this module has no side
    # effects (the scraped copy read lenses.txt and plotted on import).
    import treePlotter  # local import: only needed for the demo plot
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenseTree = createTree(lenses, lensesLabels)
    treePlotter.createPlot(lenseTree)
3 运行结果
0 0
- 决策树(Python实现)
- 决策树算法实现(python)
- 详解决策树、python实现决策树
- python实现决策树(ID3算法)
- 决策树(ID3,C4.5)Python实现
- 机器学习之决策树实现(Python)
- 决策树原理实例(python代码实现)
- ID3决策树算法(python实现)
- C4.5决策树算法(Python实现)
- python实现决策树分类(一)
- python实现决策树分类(二)
- python实现决策树分类(三)
- 决策树(ID3算法)Python实现
- Python实现决策树(ID3、C4.5)
- Python实现决策树算法
- 决策树--Python实现
- 决策树及其python实现
- 决策树原理-python实现
- linux 自旋锁
- web前端的优化方案
- C#中基于GDI+(Graphics)图像处理系列之前言
- 规范化理论-函数依赖-范式-简单粗暴
- 关于C++中的指针应用
- 决策树(Python实现)
- web项目由jetty启动转成tomcat启动
- 为什么java没有多继承
- 检查tar版本的shell脚步checktar.sh
- 数据分析与数据挖掘在常规工作中的应用——基本统计量描述
- linux top 命令详解
- 栈的插入 删除
- DOS for命令详解
- 通信中使用的数据格式(xml,json,pb.msgpack)