[Machine Learning] Decision Tree Study Notes
Source: Internet | Editor: 程序博客网 | Date: 2024/06/13 16:11
Tags (space-separated): machine learning
Introduction to Decision Trees
A decision tree is a tree structure (binary or non-binary). Each internal (non-leaf) node represents a test on one feature attribute, each branch represents the outcome of that test for some range of values, and each leaf node stores a class label. To make a decision, start at the root, test the corresponding feature attribute of the item to be classified, follow the branch matching its value, and repeat until a leaf is reached; the class stored at that leaf is the decision result.
This post uses the ID3 algorithm: whenever a node needs to be split, ID3 computes the information gain of every attribute and splits on the attribute with the largest information gain. (Gain *ratio* is the criterion used by C4.5; ID3 uses plain information gain.)
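Before walking through the full implementation, the two quantities ID3 relies on can be computed by hand. The sketch below (standalone Python 3; the helper name `entropy` is mine, not from the book's code) computes the base Shannon entropy of the toy data set used later in this post, and the information gain of splitting on its first feature.

```python
from math import log2

def entropy(labels):
    """Shannon entropy of a list of class labels."""
    n = len(labels)
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return -sum(k / n * log2(k / n) for k in counts.values())

# The toy data set used later in the post has labels: 2 'yes', 3 'no'.
labels = ['yes', 'yes', 'no', 'no', 'no']
base = entropy(labels)

# Splitting on the first feature ('no surfacing') groups the rows as:
#   value 1 -> ['yes', 'yes', 'no'],  value 0 -> ['no', 'no']
groups = [['yes', 'yes', 'no'], ['no', 'no']]
remainder = sum(len(g) / len(labels) * entropy(g) for g in groups)
gain = base - remainder
print(round(base, 3), round(gain, 3))  # 0.971 0.42
```

These are exactly the values the code below produces for the same data, so the hand computation is a useful sanity check.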
For a more detailed introduction, see this blog post: 算法杂货铺——分类算法之决策树(Decision tree)
and this one: 机器学习——决策树算法原理及案例
whose content comes from the book 《机器学习实战》 (Machine Learning in Action).
This post focuses on the Python implementation of decision trees, working through every line of the code.
Decision Tree Code Implementation
The code below is split into two files: tree.py and treePlotter.py. tree.py contains functions for computing Shannon entropy, splitting the data set, choosing the best feature, majority-voting a leaf's label, building the tree, classifying test data, storing and loading the tree, plus an example that classifies contact lenses. treePlotter.py is the code that draws the decision tree.
tree.py
from math import log
import operator

import treePlotter


def calcShannonEnt(dataSet):
    """
    Compute the Shannon entropy of a data set.
    :param dataSet: the input data set
    :return: the entropy
    """
    numEntries = len(dataSet)  # total number of instances
    labelCounts = {}  # maps each class label (the last column) to its count
    for featVec in dataSet:  # loop over every instance
        currentLabel = featVec[-1]  # the class label is the last column
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1  # count this label
    shannonEnt = 0.0
    for key in labelCounts:  # for every label
        prob = float(labelCounts[key]) / numEntries  # probability of the label
        shannonEnt -= prob * log(prob, 2)  # H -= p(x_i) * log2(p(x_i))
    return shannonEnt


def createDataSet():
    """
    Create a toy data set.
    :return: data set, feature labels
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """
    Split the data set on a feature.
    :param dataSet: the data set to split
    :param axis: index of the feature to split on
    :param value: the feature value to keep
    :return: the matching rows with the feature column removed
    """
    retDataSet = []
    for featVec in dataSet:  # walk every row of the data set
        if featVec[axis] == value:  # this row matches the feature value
            reducedFeatVec = featVec[:axis]  # the columns before the feature
            reducedFeatVec.extend(featVec[axis + 1:])  # the columns after it
            # Together the two slices drop the column at index axis.
            # We must not delete in place, or the original dataSet would change.
            retDataSet.append(reducedFeatVec)  # keep the reduced row
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the label, not a feature
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    bestInfoGain = 0.0  # best information gain so far
    bestFeature = -1  # index of the best feature
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of feature i
        uniqueVals = set(featList)  # the unique values of this feature
        print("uniqueVals", uniqueVals)
        newEntropy = 0.0  # conditional entropy after splitting on this feature
        for value in uniqueVals:  # iterate over every unique value
            print("value", value)
            subDataSet = splitDataSet(dataSet, i, value)  # split on this value
            print("subDataSet", subDataSet)
            prob = len(subDataSet) / float(len(dataSet))  # weight of this subset
            print("prob", prob)
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted sum over subsets
            print("newEntropy", newEntropy)
        infoGain = baseEntropy - newEntropy  # information gain = reduction in entropy
        print("infoGain", infoGain)
        if infoGain > bestInfoGain:  # compare to the best gain so far
            bestInfoGain = infoGain
            print("bestInfoGain", bestInfoGain)
            bestFeature = i
    return bestFeature  # returns an integer index


def majorityCnt(classList):
    """
    When every feature has been used for splitting but the class labels
    are still mixed, decide the leaf's label by majority vote.
    :param classList: all class labels reaching the leaf
    :return: the label chosen for the leaf
    """
    classCount = {}  # vote tally for the leaf
    for vote in classList:  # count the votes
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    print("sortedClassCount", sortedClassCount)  # sorted by count, descending
    return sortedClassCount[0][0]  # the most frequent label


def createTree(dataSet, labels):
    """
    Build the tree recursively.
    :param dataSet: the data set
    :param labels: the list of feature labels
    :return: the tree as nested dicts
    """
    classList = [example[-1] for example in dataSet]  # all class labels
    print("classList", classList)
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when no features are left
        return majorityCnt(classList)  # label the node by majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature
    print("bestFeat", bestFeat)
    bestFeatLabel = labels[bestFeat]  # name of the best feature
    print("bestFeatLabel", bestFeatLabel)
    myTree = {bestFeatLabel: {}}  # the tree rooted at this node
    del labels[bestFeat]  # remove the used feature label from the list
    featValues = [example[bestFeat] for example in dataSet]  # all values of the best feature
    print("featValues", featValues)
    uniqueVals = set(featValues)  # deduplicate them
    print("uniqueVals", uniqueVals)
    for value in uniqueVals:  # one subtree per unique value
        subLabels = labels[:]  # copy labels so recursion doesn't clobber them
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """
    Classify a vector with the decision tree.
    :param inputTree: the tree
    :param featLabels: the feature labels
    :param testVec: the vector to classify
    :return: the predicted class label
    """
    firstStr = list(inputTree.keys())[0]  # feature tested at the root of this subtree
    print("firstStr", firstStr)
    secondDict = inputTree[firstStr]  # the subtree under that feature
    print("secondDict", secondDict)
    featIndex = featLabels.index(firstStr)  # map the feature name to a column index
    print("featIndex", featIndex)
    key = testVec[featIndex]  # the test vector's value at that index
    print("key", key)
    valueOfFeat = secondDict[key]  # follow the matching branch
    print("valueOfFeat", valueOfFeat)
    if isinstance(valueOfFeat, dict):  # not a leaf yet: recurse
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat  # reached a leaf
    return classLabel  # the predicted class label


def storeTree(inputTree, filename):
    """
    Store the decision tree on disk.
    :param inputTree: the tree to save
    :param filename: the file to save it to
    """
    import pickle
    with open(filename, 'wb') as fw:  # pickle requires binary mode
        pickle.dump(inputTree, fw)  # serialize the tree object


def grabTree(filename):
    """
    Read a decision tree back from disk.
    :param filename: the file name
    :return: the decision tree
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


if __name__ == '__main__':
    dataSet, labels = createDataSet()
    print("dataSet", dataSet)
    myTree = treePlotter.retrieveTree(0)
    print("myTree", myTree)
    treePlotter.createPlot(myTree)
    print(classify(myTree, labels, [1, 0]))
    storeTree(myTree, 'classifierStorage.txt')
    print(grabTree('classifierStorage.txt'))
treePlotter.py provides the plotting functionality.
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict values are subtrees, anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):  # dict values are subtrees, anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    # write the branch label midway between parent and child
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # determines the x width of this subtree
    firstStr = list(myTree.keys())[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))  # recurse into the subtree
        else:  # it's a leaf node: plot it directly
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
    ]
    return listOfTrees[i]
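The two tree-measuring helpers are the easiest part of treePlotter.py to check in isolation. Below is a compact standalone sketch of the same leaf-count and depth logic (function names shortened by me; an illustration, not the file's actual code), run on the toy tree used throughout this post:

```python
def num_leafs(tree):
    # a value that is itself a dict is a subtree; anything else is a leaf
    branches = tree[next(iter(tree))]
    return sum(num_leafs(v) if isinstance(v, dict) else 1
               for v in branches.values())

def tree_depth(tree):
    # depth of a leaf child is 1; of a subtree child, 1 + its own depth
    branches = tree[next(iter(tree))]
    return 1 + max(tree_depth(v) if isinstance(v, dict) else 0
                   for v in branches.values())

my_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(num_leafs(my_tree), tree_depth(my_tree))  # 3 2
```

createPlot uses exactly these two numbers (totalW and totalD) to scale the x and y coordinates of the plot.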
To make the whole process clearer, here is the program's output.
dataSet [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
myTree {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
firstStr no surfacing
secondDict {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
featIndex 0
key 1
valueOfFeat {'flippers': {0: 'no', 1: 'yes'}}
firstStr flippers
secondDict {0: 'no', 1: 'yes'}
featIndex 1
key 0
valueOfFeat no
no
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
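The classify trace above can be reproduced with a compact standalone version of the same nested-dict traversal (a sketch of the idea, not the exact tree.py code):

```python
def classify(tree, feat_labels, test_vec):
    feat = next(iter(tree))  # feature tested at this node
    branch = tree[feat][test_vec[feat_labels.index(feat)]]
    if isinstance(branch, dict):  # internal node: keep descending
        return classify(branch, feat_labels, test_vec)
    return branch  # leaf: the class label

my_tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
labels = ['no surfacing', 'flippers']
print(classify(my_tree, labels, [1, 0]))  # prints: no (matching the trace)
```

The vector [1, 0] first follows branch 1 of 'no surfacing' into the 'flippers' subtree, then branch 0 to the leaf 'no', which is exactly the sequence of firstStr/key/valueOfFeat lines in the trace.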
The plotted decision tree:
Decision Tree in Practice: Predicting Contact Lens Type
The data set is this lenses.txt (tab-separated):
young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses
The code below applies the decision tree algorithm above to make the prediction and draws the structure of the resulting tree.
def classifyLenses():
    """
    Classify contact lens types.
    """
    with open('lenses.txt') as fr:
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    print("lenses", lenses)
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print("lensesTree", lensesTree)
    treePlotter.createPlot(lensesTree)


classifyLenses()
The structure of the resulting decision tree is shown below.
Decision tree classification has problems of its own, such as overfitting. Also, ID3 can only split nominal (categorical) data sets; it cannot handle numeric features directly.
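For comparison, the standard way around the numeric-feature limitation (used by C4.5 and CART rather than ID3) is to binarize a continuous feature at a threshold chosen to maximize information gain. A minimal sketch, with made-up numbers:

```python
from math import log2

def entropy(labels):
    n = len(labels)
    counts = {}
    for c in labels:
        counts[c] = counts.get(c, 0) + 1
    return -sum(k / n * log2(k / n) for k in counts.values())

def best_threshold(values, labels):
    """Candidate thresholds are midpoints between adjacent sorted values;
    return the one with the highest information gain."""
    base = entropy(labels)
    pairs = sorted(zip(values, labels))
    best_gain, best_t = 0.0, None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold fits between equal values
        t = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [l for v, l in pairs if v <= t]
        right = [l for v, l in pairs if v > t]
        rem = (len(left) * entropy(left) + len(right) * entropy(right)) / len(pairs)
        if base - rem > best_gain:
            best_gain, best_t = base - rem, t
    return best_t, best_gain

# made-up numeric feature: small values labeled 'no', large ones 'yes'
t, g = best_threshold([1.0, 2.0, 3.0, 8.0, 9.0], ['no', 'no', 'no', 'yes', 'yes'])
print(t, round(g, 3))  # 5.5 0.971
```

Once a threshold is chosen, "value <= 5.5" behaves like a two-valued nominal feature and the rest of the ID3 machinery above applies unchanged.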
This post is a set of study notes on 《机器学习实战》 (Machine Learning in Action); if anything is unclear, please consult the book.