Decision Trees in Python (Machine Learning in Action)
Source: Internet · Editor: 程序博客网 · Date: 2024/04/30 03:16
- Principle
- Step-by-step breakdown
  - Traverse the dataset, compute the Shannon entropy and information gain for each feature, pick the feature with the largest gain, recurse on the remaining features, and serialize the resulting classifier to disk
- Building the tree recursively
- Drawing the tree with Matplotlib annotations
- Complete code
Principle
A decision tree classifies by asking a sequence of questions: each answer selects a branch, and following the branches leads to a final classification.
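To make the "ask questions, follow branches" idea concrete, here is a minimal sketch: the tree is a nested dict (the same toy 'no surfacing'/'flippers' tree used later in the article), and classification is just a walk from the root to a leaf. The `walk` helper is illustrative, not part of the book's code:

```python
# A decision tree stored as a nested dict: each inner dict asks about one
# feature; each key under it is a possible answer; leaves are class labels.
tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

def walk(tree, answers):
    """Follow the branches: ask the question at each node, descend along the answer."""
    while isinstance(tree, dict):
        question = next(iter(tree))               # the feature this node asks about
        tree = tree[question][answers[question]]  # take the branch for the answer
    return tree                                   # a leaf holds the class label

print(walk(tree, {'no surfacing': 1, 'flippers': 1}))  # -> 'yes'
print(walk(tree, {'no surfacing': 0}))                 # -> 'no'
```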
Step-by-step breakdown
1. Traverse the dataset, computing the Shannon entropy and information gain for each feature, and pick the feature with the largest information gain. Then recurse on the remaining features to determine the full splitting order. Finally, serialize the resulting classifier to disk.
def chooseBestFeatureToSplit(dataSet):
    """Choose the best feature to split on.
    :param dataSet: the original dataset
    :return: index of the best splitting feature
    """
    numFeatures = len(dataSet[0]) - 1                  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)              # entropy of the whole dataset
    bestInfoGain = 0.0                                 # best information gain so far
    bestFeature = -1                                   # index of the best splitting feature
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet] # the i-th feature of every sample
        uniqueVals = set(featList)                     # distinct values this feature takes
        newEntropy = 0.0                               # entropy after splitting on this feature
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))   # subset weight
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy            # the lower the new entropy (the less mixed
                                                       # the subsets), the larger the gain
        if infoGain > bestInfoGain:                    # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
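To see the entropy/information-gain arithmetic by hand, here is a self-contained sketch of the same computation on the book's toy fish dataset (defined later in `createDataSet`). The helper names differ from the article's functions; they are for illustration only:

```python
from math import log
from collections import Counter

def shannon_ent(rows):
    """H(D) = -sum(p * log2(p)) over the class frequencies."""
    n = len(rows)
    return -sum(c / n * log(c / n, 2) for c in Counter(r[-1] for r in rows).values())

def info_gain(rows, i):
    """Gain = H(D) - weighted entropy of the subsets split on feature i."""
    base = shannon_ent(rows)
    new_ent = 0.0
    for v in {r[i] for r in rows}:                    # each value feature i takes
        sub = [r for r in rows if r[i] == v]
        new_ent += len(sub) / len(rows) * shannon_ent(sub)
    return base - new_ent

data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(round(shannon_ent(data), 3))    # -> 0.971
print(round(info_gain(data, 0), 3))   # feature 0 ('no surfacing') -> 0.42
print(round(info_gain(data, 1), 3))   # feature 1 ('flippers')     -> 0.171
```

Feature 0 has the larger gain, so `chooseBestFeatureToSplit` picks it as the root split.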
2. Building the tree recursively
def createTree(dataSet, labels):
    """
    :param dataSet: the dataset
    :param labels: list of labels for every feature in the dataset
    :return: the tree as a nested dict
    """
    classList = [example[-1] for example in dataSet]
    # stop splitting when every sample has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features used up but the classes are still mixed: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)   # best feature to split on
    bestFeatLabel = labels[bestFeat]               # its label
    myTree = {bestFeatLabel: {}}                   # the tree is stored as a dict
    del(labels[bestFeat])
    # collect every value this feature takes
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                      # copy, because the recursive call mutates its argument
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
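The recursion above can be restated compactly. The sketch below is a self-contained rewrite of the same ID3 idea (not the article's exact functions; it copies the label list instead of using `del`, so the caller's list is never mutated) and reproduces the toy tree:

```python
from math import log
from collections import Counter

def ent(rows):
    n = len(rows)
    return -sum(c / n * log(c / n, 2) for c in Counter(r[-1] for r in rows).values())

def split(rows, i, v):
    # rows where feature i == v, with column i removed
    return [r[:i] + r[i + 1:] for r in rows if r[i] == v]

def best_feature(rows):
    base = ent(rows)
    gains = [(base - sum(len(s) / len(rows) * ent(s)
                         for v in {r[i] for r in rows}
                         for s in [split(rows, i, v)]), i)
             for i in range(len(rows[0]) - 1)]
    return max(gains)[1]                           # index with the largest gain

def build(rows, labels):
    classes = [r[-1] for r in rows]
    if classes.count(classes[0]) == len(classes):  # pure node -> leaf
        return classes[0]
    if len(rows[0]) == 1:                          # no features left -> majority vote
        return Counter(classes).most_common(1)[0][0]
    i = best_feature(rows)
    rest = labels[:i] + labels[i + 1:]             # copy instead of del: no mutation
    return {labels[i]: {v: build(split(rows, i, v), rest)
                        for v in {r[i] for r in rows}}}

data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
print(build(data, ['no surfacing', 'flippers']))
# -> {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```

This matches `retrieveTree(0)` in the treePlotter code below, which is why that hard-coded tree works as a test fixture.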
3. Drawing the tree with Matplotlib annotations
import matplotlib.pyplot as plt
import trees

# text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")   # decision-node box style and fill
leafNode = dict(boxstyle="round4", fc="0.8")         # leaf-node box style and fill
arrow_args = dict(arrowstyle="<-")                   # arrow style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # draw an annotation with an arrow from the parent to this node
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')           # white background
    fig.clf()                                        # clear the canvas
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)  # 1 row, 1 column, no frame
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)  # first point: the text; second: the arrow tail
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString):
    # write text midway between the parent and child nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # compute width and depth of this subtree
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]                # root of this subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)           # label the branch with the feature value
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    # move down one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD    # move back up after the recursion

def retrieveTree(i):
    # pre-built trees for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
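All of the drawing above is built on `Axes.annotate`. A minimal standalone sketch of one parent-to-child arrow, stripped of the layout bookkeeping (the `Agg` backend and the output filename are assumptions so the sketch renders without a display):

```python
import matplotlib
matplotlib.use('Agg')                      # render off-screen, no window needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_xticks([]); ax.set_yticks([])

node_style = dict(boxstyle="sawtooth", fc="0.8")   # same box style as decisionNode
ax.annotate('child node',
            xy=(0.2, 0.8), xycoords='axes fraction',        # arrow tail at the parent
            xytext=(0.6, 0.4), textcoords='axes fraction',  # boxed text at the child
            va="center", ha="center",
            bbox=node_style, arrowprops=dict(arrowstyle="<-"))
fig.savefig('node_demo.png')
```

`plotTree` simply issues one such annotation per node, using `xOff`/`yOff` to spread leaves evenly across the axes-fraction coordinate space.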
Complete code
trees.py
from math import log
import operator
import treePlotter

def calcShannonEnt(dataSet):
    # compute the Shannon entropy of the dataset
    numEntries = len(dataSet)
    labelCounts = {}                                   # count occurrences of every class label
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)              # log base 2
    return shannonEnt

def splitDataSet(dataSet, axis, value):
    """Split the dataset on a given feature.
    :param dataSet: dataset to split
    :param axis: index of the splitting feature
    :param value: feature value to match
    :return: the matching rows, with the splitting column removed
    """
    retDataSet = []                                    # build a new list; do not mutate the input
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]            # drop the splitting feature
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    """Choose the best feature to split on.
    :param dataSet: the original dataset
    :return: index of the best splitting feature
    """
    numFeatures = len(dataSet[0]) - 1                  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)              # entropy of the whole dataset
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet] # the i-th feature of every sample
        uniqueVals = set(featList)                     # distinct values this feature takes
        newEntropy = 0.0                               # entropy after splitting on this feature
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))   # subset weight
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy            # lower new entropy means larger gain
        if infoGain > bestInfoGain:                    # keep the feature with the largest gain
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    """Majority vote for a leaf's class.
    :param classList: list of class labels
    :return: the most frequent class
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """
    :param dataSet: the dataset
    :param labels: list of labels for every feature in the dataset
    :return: the tree as a nested dict
    """
    classList = [example[-1] for example in dataSet]
    # stop splitting when every sample has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features used up but the classes are still mixed: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)   # best feature to split on
    bestFeatLabel = labels[bestFeat]               # its label
    myTree = {bestFeatLabel: {}}                   # the tree is stored as a dict
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]                      # copy, because the recursive call mutates its argument
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # an internal node: recurse
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createDataSet():
    # the toy fish dataset
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def classify(inputTree, featLabels, testVec):
    """Run a sample down the tree.
    :param inputTree: the trained tree (nested dict)
    :param featLabels: feature labels
    :param testVec: the sample to classify
    :return: the predicted class
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)         # map the label back to a feature index
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)  # not a leaf yet: recurse
            else:
                classLabel = secondDict[key]       # a leaf: take its class directly
    return classLabel

def storeTree(inputTree, filename):
    # serialize the tree (the classifier) to disk
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    # load a serialized tree from disk
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

if __name__ == "__main__":
    myDat, labels = createDataSet()
    # myTree = createTree(myDat, labels)
    # print(myTree)
    print(myDat)
    myTree = treePlotter.retrieveTree(0)
    print(myTree)
    print(classify(myTree, labels, [1, 0]))
    print(classify(myTree, labels, [1, 1]))
    print("===========store tree============")
    storeTree(myTree, 'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))
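The `storeTree`/`grabTree` pair is plain pickle round-tripping. A minimal standalone sketch of the same idea (the temporary-directory path is an arbitrary choice for the demo):

```python
import os
import pickle
import tempfile

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

path = os.path.join(tempfile.mkdtemp(), 'classifier.pkl')
with open(path, 'wb') as fw:       # pickle requires binary mode
    pickle.dump(tree, fw)
with open(path, 'rb') as fr:
    restored = pickle.load(fr)

print(restored == tree)  # -> True
```

Serializing the trained tree means the (relatively expensive) build step runs once, and later classifications just reload the dict.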
treePlotter.py
import matplotlib.pyplot as plt
import trees

# text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")   # decision-node box style and fill
leafNode = dict(boxstyle="round4", fc="0.8")         # leaf-node box style and fill
arrow_args = dict(arrowstyle="<-")                   # arrow style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # draw an annotation with an arrow from the parent to this node
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')           # white background
    fig.clf()                                        # clear the canvas
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)  # 1 row, 1 column, no frame
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)  # first point: the text; second: the arrow tail
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString):
    # write text midway between the parent and child nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # compute width and depth of this subtree
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]                # root of this subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)           # label the branch with the feature value
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD    # move down one level
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD    # move back up after the recursion

def retrieveTree(i):
    # pre-built trees for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

if __name__ == "__main__":
    # reTree = retrieveTree(1)
    # leafs = trees.getNumLeafs(reTree)
    # depth = trees.getTreeDepth(reTree)
    # print(reTree)
    # print(leafs)
    # print(depth)
    myTree = retrieveTree(0)
    myTree['no surfacing'][3] = 'maybe'   # add a third branch to exercise the layout
    createPlot(myTree)