Machine Learning 002: Decision Trees


Decision Trees

1 Decision Tree Concepts

A decision tree is a decision support tool that uses a tree-like graph or model of decisions and their possible consequences, including chance event outcomes, resource costs, and utility. It is one way to display an algorithm.
Decision trees are commonly used in operations research, specifically in decision analysis, to help identify a strategy most likely to reach a goal, but are also a popular tool in machine learning.



1.1 The workflow of an email-handling system can be represented by a decision tree:

(Figure: an email-classification decision tree.)
The figure checks a few features in turn to decide whether a message needs to be handled immediately or only needs routine handling.

The game of twenty questions works the same way: one player thinks of an answer, and the other asks up to 20 questions that can only be answered yes or no to pin it down. That questioning process can also be modeled as a decision tree.

Decision trees come up in many everyday situations and are among the most frequently used data-mining algorithms. A decision tree does not have to be binary: multiway splits are fine, and a node does not have to display only one piece of information; several can be shown together.

A decision tree lays the decision process out visually, which makes the problem easier to reason about.


1.2 Characteristics of decision trees

  • Pros: low computational cost, output that is easy to interpret, insensitivity to missing intermediate values, and the ability to handle irrelevant features
  • Cons: prone to overfitting (trees that match the training data too closely)
  • Works with: numeric and nominal values

2 Constructing a Decision Tree

The branches of a decision tree are built as follows:
1. Check whether every item in the (sub)dataset belongs to the same class; if so, return a leaf node (the class label) and stop, otherwise go to step 2.
2. Find the best feature for splitting the dataset, split on it, and create a branch node.
3. Repeat step 1 on each resulting subset.

Pseudocode for the branch-creating function createBranch():

Check if every item in the dataset belongs to the same class:
    If so, return the class label
    Else
        find the best feature to split the dataset
        split the dataset
        create a branch node
        for each split subset
            call createBranch() and add the result to the branch node
        return the branch node

A decision-tree algorithm can split the dataset using binary partitions or the ID3 algorithm.

3 Information Entropy, Information Gain, and the Gini Index

3.1 Information entropy

  • The measure of the information in a set is called Shannon entropy, or simply entropy.
  • Entropy measures the uncertainty of the information.
  • Entropy is defined as the expected value of the information.
  • The more ordered a system is, the lower its entropy; the more disordered a system is, the higher its entropy.

If the item being classified can fall into one of several classes, the information of the symbol x_i is defined as

l(x_i) = - log2 p(x_i)

where p(x_i) is the probability of that class. To compute the entropy we take the expected value of this information over all possible classes:

H = - sum_{i=1..n} p(x_i) * log2 p(x_i)

where n is the number of classes.
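As a quick sanity check (this snippet is my own illustration, not part of the original post; the helper name entropy_from_probs is made up for it), the formula can be evaluated directly from a list of class probabilities:

from math import log

def entropy_from_probs(probs):
    # Shannon entropy H = -sum(p * log2(p)) over the class probabilities
    return -sum(p * log(p, 2) for p in probs if p > 0)

print(entropy_from_probs([0.5, 0.5]))  # two equally likely classes: maximum uncertainty, 1.0 bit
print(entropy_from_probs([0.9, 0.1]))  # a skewed distribution carries less uncertainty, about 0.469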


3.2 Information gain

  • The guiding principle when splitting a dataset is to make the disordered data more ordered.
  • The change in information before and after splitting the dataset is called the information gain.

The classic weather (play tennis) example illustrates what information gain means in detail.
(Table: the 14 weather samples with their attributes, including Outlook, and the label play / not play.)
The learning target is play or not play.
There are 14 samples in total, 9 positive and 5 negative, so the entropy of the current set is:

Entropy(S) = - 9/14 * log2(9/14) - 5/14 * log2(5/14) ≈ 0.940

In decision-tree classification, the information gain is the difference in information before and after splitting on an attribute. Suppose we split on the attribute Outlook, as in the figure below.
(Figure: the 14 samples split into three branches by Outlook = sunny / overcast / rain.)

After the split the data falls into three parts, and the entropy of each branch is:

Entropy(sunny)    = - 2/5 * log2(2/5) - 3/5 * log2(3/5) ≈ 0.971
Entropy(overcast) = 0    (all 4 samples are positive, a pure branch)
Entropy(rain)     = - 3/5 * log2(3/5) - 2/5 * log2(2/5) ≈ 0.971

The entropy after the split is the weighted average over the branches:

Entropy(S|Outlook) = 5/14 * Entropy(sunny) + 4/14 * Entropy(overcast) + 5/14 * Entropy(rain) ≈ 0.694

The information gain is then:

Gain(S, Outlook) = Entropy(S) - Entropy(S|Outlook)

which in this example gives

Gain(S, Outlook) ≈ 0.247

Before splitting each non-leaf node, the algorithm computes the information gain of every attribute and splits on the one with the largest gain: the larger the gain, the better that attribute separates the samples and the more representative it is. This is a top-down greedy strategy, and it is the core idea of ID3.
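To make the greedy selection concrete, here is a small sketch (my own, not from the original post) that recomputes Entropy(S), the weighted entropy after splitting on Outlook, and the resulting gain from the counts in the table above:

from math import log

def entropy(counts):
    # entropy from raw class counts, e.g. [9, 5] for 9 play / 5 not play
    total = float(sum(counts))
    return -sum(c / total * log(c / total, 2) for c in counts if c > 0)

base = entropy([9, 5])                                        # Entropy(S), about 0.940
branches = [[2, 3], [4, 0], [3, 2]]                           # sunny, overcast, rain
after = sum(sum(b) / 14.0 * entropy(b) for b in branches)     # about 0.694
print(base - after)                                           # Gain(S, Outlook), about 0.247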

3.3 Gini index

  • The CART algorithm mainly uses the Gini index.
  • In CART, the Gini impurity is the probability that a randomly chosen sample from the subset would be misclassified.
  • It equals the probability of picking a sample of a given class times the probability of mislabeling it, summed over the classes.
  • Suppose y can take the values {1, 2, ..., m} and f_i is the probability that a sample is assigned label i; then the Gini index is computed as follows:

Gini(f) = sum_{i=1..m} f_i * (1 - f_i) = sum_{i=1..m} f_i - sum_{i=1..m} f_i^2 = 1 - sum_{i=1..m} f_i^2
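The same formula in Python, as a minimal sketch (my own; the helper gini takes raw class counts rather than probabilities):

def gini(counts):
    # Gini impurity 1 - sum(f_i^2), where f_i are the class frequencies
    total = float(sum(counts))
    return 1.0 - sum((c / total) ** 2 for c in counts)

print(gini([2, 3]))   # 2 positives and 3 negatives: 1 - (2/5)^2 - (3/5)^2 = 0.48
print(gini([4, 0]))   # a pure subset has impurity 0.0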


4 The C4.5, ID3, and CART Algorithms

4.1 The C4.5 algorithm

The idea behind ID3 was shown in the example above. C4.5 is another classification decision-tree algorithm; it is an important refinement of ID3, with the following main improvements (a sketch of the continuous-attribute handling follows the list):

  • It selects attributes by information gain ratio. ID3 selects attributes by the information gain of the split: it measures impurity with entropy and picks the attribute whose split reduces entropy the most; C4.5 instead uses the gain ratio, which normalizes that gain.
  • It prunes during tree construction: nodes covering very few samples can make the tree overfit, and the tree may be better off without them.
  • It can handle continuous (non-discrete) attributes.
  • It can handle incomplete data (missing values).
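The original example only has discrete attributes. As a rough sketch of how the continuous-attribute handling can work (my own illustration, not from the post): candidate thresholds are placed between consecutive sorted values, and the one with the largest information gain is kept.

from math import log

def entropy(labels):
    total = float(len(labels))
    return -sum(labels.count(c) / total * log(labels.count(c) / total, 2)
                for c in set(labels))

def best_threshold(values, labels):
    # try a binary split 'value <= t' at every midpoint between consecutive sorted values
    pairs = sorted(zip(values, labels))
    base = entropy(labels)
    best_gain, best_t = 0.0, None
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue
        t = (pairs[i - 1][0] + pairs[i][0]) / 2.0
        left = [l for v, l in pairs if v <= t]
        right = [l for v, l in pairs if v > t]
        after = (len(left) * entropy(left) + len(right) * entropy(right)) / len(pairs)
        if base - after > best_gain:
            best_gain, best_t = base - after, t
    return best_t, best_gain

# hypothetical humidity readings with play / not play labels
print(best_threshold([70, 90, 85, 95, 80], ['yes', 'no', 'yes', 'no', 'yes']))  # (87.5, about 0.97)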

Applying C4.5 to the example above, first compute the split-information measure H(V):

H(Outlook) = - 5/14 * log2(5/14) - 4/14 * log2(4/14) - 5/14 * log2(5/14) ≈ 1.577

The information gain ratio is then:

IGR(Outlook) = Gain(S, Outlook) / H(Outlook) ≈ 0.247 / 1.577 ≈ 0.16
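The same numbers in code (my own sketch, reusing an entropy helper over raw counts):

from math import log

def entropy(counts):
    total = float(sum(counts))
    return -sum(c / total * log(c / total, 2) for c in counts if c > 0)

gain = entropy([9, 5]) - (5/14.0*entropy([2, 3]) + 4/14.0*entropy([4, 0]) + 5/14.0*entropy([3, 2]))
split_info = entropy([5, 4, 5])       # H(Outlook): the sizes of the three branches
print(gain / split_info)              # gain ratio, about 0.247 / 1.577 = 0.16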

4.2 The CART algorithm

The CART computation for the weather example goes as follows:

Outlook   sunny   overcast   rain
YES       2       4          3
NO        3       0          2

Gini(sunny)    = 1 - (2/5)^2 - (3/5)^2
Gini(overcast) = 1 - (4/4)^2 - (0/4)^2
Gini(rain)     = 1 - (3/5)^2 - (2/5)^2
Gini(Outlook)  = 5/14 * Gini(sunny) + 4/14 * Gini(overcast) + 5/14 * Gini(rain)
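A quick numerical check of this weighted Gini index (my own sketch, not from the original post):

def gini(counts):
    total = float(sum(counts))
    return 1.0 - sum((c / total) ** 2 for c in counts)

print(5/14.0 * gini([2, 3]) + 4/14.0 * gini([4, 0]) + 5/14.0 * gini([3, 2]))   # about 0.343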


For a discrete attribute with values such as {x, y, z}, there are three possible binary partitions on that attribute (the empty-set / full-set split is excluded):

{x} | {y, z}     {y} | {x, z}     {z} | {x, y}

For the weather example, one such binary partition gives:

Outlook   sunny or overcast   rain
YES       6                   3
NO        3                   2

Then compute the weighted Gini index for this binary split in the same way:

Gini = 9/14 * Gini(sunny or overcast) + 5/14 * Gini(rain) ≈ 0.457

and compare it against the other candidate partitions; CART keeps the split with the smallest Gini index.


5 Working Through the Marine Animal Data

The dataset contains 5 marine animals. There are two features: whether the animal can survive without coming to the surface ("no surfacing") and whether it has flippers; the label is whether the animal is a fish. We can split the data on these features, but to decide which feature to split on first we need to compute the information gain.
No.   Can survive without surfacing?   Has flippers?   Is a fish?
1     yes                              yes             yes
2     yes                              yes             yes
3     yes                              no              no
4     no                               yes             no
5     no                               yes             no

(Figure: the paths along which the dataset is split on the two feature values.)

The code for trees.py:

# _*_ coding: UTF-8 -*-
from math import log
import operator


def createDataSet():
    """Create the fish-identification dataset."""
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels in dataSet."""
    # Count the number of instances in the dataset
    numEntries = len(dataSet)
    labelCounts = {}
    # Build a dictionary keyed by the values in the last column; if a key is
    # not yet in the dictionary, add it. Each key records how many times that
    # class label occurs.
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Use the frequency of each class label to estimate its probability,
    # then accumulate the entropy
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt


def splitDataSet(dataSet, axis, value):
    """Return the subset of dataSet whose feature 'axis' equals 'value',
    with that feature column removed."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]  # chop out the axis used for splitting
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Choose the best feature to split on by information gain."""
    numFeatures = len(dataSet[0]) - 1  # the last column is the label
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # all values of this feature
        uniqueVals = set(featList)  # the unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # the info gain, i.e. the reduction in entropy
        if infoGain > bestInfoGain:  # keep the best gain seen so far
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # returns an index


def majorityCnt(classList):
    """Return the class label that occurs most often in classList."""
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    """Recursively build the decision tree as nested dictionaries."""
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]  # copy labels so the recursion doesn't mess up the existing list
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """Classify testVec by walking down the tree."""
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else:
        classLabel = valueOfFeat
    return classLabel


def storeTree(inputTree, filename):
    """Serialize the tree to disk with pickle."""
    import pickle
    fw = open(filename, 'w')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    """Load a pickled tree from disk."""
    import pickle
    fr = open(filename)
    return pickle.load(fr)

The code for treePlotter.py:

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict means an internal node, otherwise a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # a dict means an internal node, otherwise a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # the first key tells you what feature was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # internal node: recurse
            plotTree(secondDict[key], cntrPt, str(key))
        else:  # leaf node: plot it
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


def retrieveTree(i):
    """Return a pre-built tree for testing, so we don't have to rebuild one every time."""
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]

The interactive session:

Path: /Users/shasha/PycharmProjects/shang/trees.py

Last login: Thu Jun 22 12:00:46 on ttys000
bogon:~ shasha$ cd PycharmProjects/shang
bogon:shang shasha$ python
Python 2.7.10 (default, Oct 23 2015, 19:19:21)
[GCC 4.2.1 Compatible Apple LLVM 7.0.0 (clang-700.0.59.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import trees
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)
0.9709505944546686
>>> myDat[0][-1]='maybe'
>>> myDat
[[1, 1, 'maybe'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)
1.3709505944546687
>>> a=[1,2,3]
>>> b=[4,5,6]
>>> a.append(b)
>>> a
[1, 2, 3, [4, 5, 6]]
>>> a=[1,2,3]
>>> a.extend(b)
>>> a
[1, 2, 3, 4, 5, 6]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.splitDataSet(myDat,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> trees.chooseBestFeatureToSplit(myDat)
0
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> reload(trees)
<module 'trees' from 'trees.pyc'>
>>> myDat,labels=trees.createDataSet()
>>> myTree = trees.createTree(myDat,labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> reload(treePlotter)
<module 'treePlotter' from 'treePlotter.py'>
>>> treePlotter.retrieveTree(0)
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
>>> reload(treePlotter)
<module 'treePlotter' from 'treePlotter.pyc'>
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.createPlot(myTree)
(matplotlib prints several "RuntimeWarning: invalid value encountered in double_scalars" messages from patches.py here)
>>> myTree['no surfacing'][3]='maybe'
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}}
>>> treePlotter.createPlot(myTree)
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])
'yes'
>>> trees.storeTree(myTree,'classifierStorage.txt')
>>> trees.grabTree('classifierStorage.txt')
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>>

(Figures: the tree drawn by treePlotter.createPlot(), before and after the extra 'maybe' branch is added.)

Link to the code and the data used: Decision Trees