[Machine Learning] Decision Tree Learning Notes



Tags (space separated): machine learning


Introduction to Decision Trees

A decision tree is a tree structure (binary or non-binary). Each internal (non-leaf) node represents a test on one feature attribute, each branch represents one outcome of that test over a range of values, and each leaf node stores a class label. To classify with a decision tree, start at the root, test the corresponding feature of the item to be classified, follow the branch matching its value, and repeat until a leaf is reached; the class stored at that leaf is the decision result.
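In the implementation below, such a tree is stored as nested Python dicts: each decision node is a dict with a single key (the feature name), and its values map each feature value to either a subtree or a class label. As a concrete illustration (this is the same toy tree that is built later in this post), classifying the sample [1, 0] is just two dict lookups:

myTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

# classify the sample [1, 0], i.e. no surfacing = 1, flippers = 0
subtree = myTree['no surfacing'][1]   # -> {'flippers': {0: 'no', 1: 'yes'}}, another decision node
label = subtree['flippers'][0]        # -> 'no', a leaf, so this is the predicted class
print label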

This post uses the ID3 algorithm: every time a node needs to be split, ID3 computes the information gain of each attribute and splits on the attribute with the largest gain. (Gain ratio is the related criterion used by C4.5; ID3 and the code below use plain information gain.)
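For reference, the two quantities the code below computes are the Shannon entropy of a dataset D with class proportions p_k, and the information gain obtained by splitting D on attribute A:

H(D) = -\sum_k p_k \log_2 p_k

Gain(D, A) = H(D) - \sum_{v \in values(A)} \frac{|D_v|}{|D|} H(D_v)

calcShannonEnt computes H(D); chooseBestFeatureToSplit computes Gain(D, A) for every feature and returns the index of the feature with the largest gain.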

A more detailed introduction can be found in this blog post: 算法杂货铺——分类算法之决策树(Decision tree),
and in this one: 机器学习——决策树算法原理及案例.
The content of this post comes from the book 《机器学习实战》 (Machine Learning in Action).

This post focuses on the Python implementation of the decision tree, working through what every line of code does.

Decision Tree Code Implementation

The code below is split into two files: tree.py and treePlotter.py. tree.py contains the code for computing Shannon entropy, splitting the dataset, choosing the best feature, deciding a leaf's label by majority vote, building the tree, classifying test data, storing the tree, loading the tree, plus an example that classifies contact lens data. treePlotter.py contains the code that draws the decision tree.
tree.py

# coding=utf-8
from math import log
import operator

import treePlotter


def calcShannonEnt(dataSet):
    """
    Compute the Shannon entropy of a dataset.
    :param dataSet: input dataset
    :return: entropy
    """
    numEntries = len(dataSet)  # total number of instances in the dataset
    labelCounts = {}  # dict: key is the value in the last column, value counts how often that class appears
    for featVec in dataSet:  # loop over every instance
        currentLabel = featVec[-1]  # the last column is the class label
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1  # count occurrences of this label
    shannonEnt = 0.0  # Shannon entropy
    for key in labelCounts:  # for every class label
        prob = float(labelCounts[key]) / numEntries  # probability of this label
        shannonEnt -= prob * log(prob, 2)  # entropy -= p(xi) * log2(p(xi))
    return shannonEnt


def createDataSet():
    """
    Create the toy dataset.
    :return: dataset, feature labels
    """
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    # change to discrete values
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """
    Split the dataset on a given feature.
    :param dataSet: dataset to split
    :param axis: index of the feature to split on
    :param value: feature value to keep
    :return: the instances whose feature `axis` equals `value`, with that feature removed
    """
    retDataSet = []
    for featVec in dataSet:  # iterate over every instance in the dataset
        if featVec[axis] == value:  # this instance matches the required feature value
            reducedFeatVec = featVec[:axis]  # take the part before the feature
            reducedFeatVec.extend(featVec[axis + 1:])  # append the part after the feature
            # Together these two steps drop the element at index `axis`.
            # It cannot be deleted in place, or the original dataSet would be modified.
            retDataSet.append(reducedFeatVec)  # add the matching instance, minus the split feature
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # the last column is the class label, not a feature
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the original dataset
    bestInfoGain = 0.0  # best information gain so far
    bestFeature = -1  # best feature so far
    for i in range(numFeatures):  # iterate over all the features
        featList = [example[i] for example in dataSet]  # create a list of all the examples of this feature
        uniqueVals = set(featList)  # get a set of unique values
        print "uniqueVals", uniqueVals
        newEntropy = 0.0  # entropy after splitting on this feature
        for value in uniqueVals:  # iterate over every unique value of this feature
            print "value", value
            subDataSet = splitDataSet(dataSet, i, value)  # split the data on this value
            print "subDataSet", subDataSet
            prob = len(subDataSet) / float(len(dataSet))  # probability of this value
            print "prob", prob
            newEntropy += prob * calcShannonEnt(subDataSet)  # weighted sum of the subset entropies
            print "newEntropy", newEntropy
        infoGain = baseEntropy - newEntropy  # calculate the info gain; ie reduction in entropy
        print "infoGain", infoGain
        if (infoGain > bestInfoGain):  # compare this to the best gain so far
            bestInfoGain = infoGain  # if better than current best, set to best
            print "bestInfoGain", bestInfoGain
            bestFeature = i
    return bestFeature  # returns an integer


def majorityCnt(classList):
    """
    When every feature has been used but the class labels are still not unique,
    decide the label of the leaf node by majority vote.
    :param classList: all class labels at this leaf
    :return: the label assigned to this leaf
    """
    classCount = {}  # vote counts for this leaf
    for vote in classList:  # majority vote
        if vote not in classCount.keys(): classCount[vote] = 0  # initialise to 0 if unseen
        classCount[vote] += 1  # increment the count for this label
    # The two lines above could also be replaced by:
    # classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    print "sortedClassCount", sortedClassCount
    # sorted by label count
    return sortedClassCount[0][0]  # return the most frequent label


def createTree(dataSet, labels):
    """
    Build the tree.
    :param dataSet: dataset
    :param labels: list of feature names, used to name the decision nodes
    :return: the tree as nested dicts
    """
    classList = [example[-1] for example in dataSet]  # all class labels
    print "classList", classList
    if classList.count(classList[0]) == len(classList):  # all class labels are identical
        return classList[0]  # stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1:  # stop splitting when there are no more features in dataSet
        return majorityCnt(classList)  # no feature left to split on, use the majority label
    bestFeat = chooseBestFeatureToSplit(dataSet)  # index of the best feature to split on
    print "bestFeat", bestFeat
    bestFeatLabel = labels[bestFeat]  # name of the best feature
    print "bestFeatLabel", bestFeatLabel
    myTree = {bestFeatLabel: {}}  # holds all the information of the tree
    del (labels[bestFeat])  # remove the chosen feature from the label list
    featValues = [example[bestFeat] for example in dataSet]  # all values of the best feature
    print "featValues", featValues
    uniqueVals = set(featValues)  # unique values of the best feature
    print "uniqueVals", uniqueVals
    for value in uniqueVals:  # for every unique value of the best feature
        subLabels = labels[:]  # copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def classify(inputTree, featLabels, testVec):
    """
    Classify with the decision tree.
    :param inputTree: the tree
    :param featLabels: feature names
    :param testVec: the instance to classify
    :return: predicted class label
    """
    firstStr = inputTree.keys()[0]  # feature name at the root of the tree
    print "firstStr", firstStr
    secondDict = inputTree[firstStr]  # subtree under that feature
    print "secondDict", secondDict
    featIndex = featLabels.index(firstStr)  # convert the feature name to an index
    print "featIndex", featIndex
    key = testVec[featIndex]  # value of that feature in the test instance
    print "key", key
    valueOfFeat = secondDict[key]  # subtree (or label) reached through that value
    print "valueOfFeat", valueOfFeat
    if isinstance(valueOfFeat, dict):  # check whether a leaf has been reached
        classLabel = classify(valueOfFeat, featLabels, testVec)  # not a leaf, keep descending
    else:
        classLabel = valueOfFeat  # reached a leaf
    return classLabel  # return the predicted class label


def storeTree(inputTree, filename):
    """
    Store the decision tree on disk.
    :param inputTree: the tree to save
    :param filename: file name to save to
    :return:
    """
    import pickle
    fw = open(filename, 'w')  # open the file for writing
    pickle.dump(inputTree, fw)  # serialise the tree object
    fw.close()  # close the file


def grabTree(filename):
    """
    Read a decision tree back from disk.
    :param filename: file name
    :return: the decision tree
    """
    import pickle
    fr = open(filename)
    return pickle.load(fr)


dataSet, labels = createDataSet()
print "dataSet", dataSet
myTree = treePlotter.retrieveTree(0)
print "myTree", myTree
treePlotter.createPlot(myTree)
print classify(myTree, labels, [1, 0])
storeTree(myTree, 'classifierStorage.txt')
print grabTree('classifierStorage.txt')
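As a quick sanity check on chooseBestFeatureToSplit, the information gains for the toy dataset can be recomputed with a few lines of standalone code. This snippet is my own addition; it deliberately does not import tree.py, because importing that file would also execute the driver code at its bottom (including the plotting call):

from math import log

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

def entropy(rows):
    counts = {}
    for row in rows:
        counts[row[-1]] = counts.get(row[-1], 0) + 1
    return -sum(float(c) / len(rows) * log(float(c) / len(rows), 2) for c in counts.values())

baseEntropy = entropy(dataSet)                 # roughly 0.971
for i in range(2):                             # the two candidate features
    newEntropy = 0.0
    for v in set(row[i] for row in dataSet):
        subset = [row for row in dataSet if row[i] == v]
        newEntropy += float(len(subset)) / len(dataSet) * entropy(subset)
    print i, baseEntropy - newEntropy          # gains: about 0.420 for feature 0, 0.171 for feature 1

Feature 0 ('no surfacing') has the larger gain, so it becomes the root of the tree, which matches the myTree structure shown in the output further below.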

treePlotter.py mainly provides the plotting functionality.

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test to see if the nodes are dictionaries, if not they are leaf nodes
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test to see if the nodes are dictionaries, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):  # if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  # this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]  # the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # test to see if the nodes are dictionaries, if not they are leaf nodes
            plotTree(secondDict[key], cntrPt, str(key))  # recursion
        else:  # it's a leaf node, print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# if you do get a dictionary you know it's a tree, and the first element will be another dict


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    # createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False)  # ticks for demo purposes
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()


def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                   ]
    return listOfTrees[i]
    # createPlot(thisTree)
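A minimal usage sketch for treePlotter.py (assuming it is on the Python path): plot one of the hard-coded trees and query its size.

import treePlotter

myTree = treePlotter.retrieveTree(0)
print "leaves:", treePlotter.getNumLeafs(myTree), "depth:", treePlotter.getTreeDepth(myTree)
treePlotter.createPlot(myTree)   # opens a matplotlib window and draws the tree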

To make the whole run easier to follow, here is the output it prints:

dataSet [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
myTree {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
firstStr no surfacing
secondDict {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}
featIndex 0
key 1
valueOfFeat {'flippers': {0: 'no', 1: 'yes'}}
firstStr flippers
secondDict {0: 'no', 1: 'yes'}
featIndex 1
key 0
valueOfFeat no
no
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

The plotted decision tree:
[Figure: sample decision tree]

Decision Trees in Practice: Predicting Contact Lens Type with a Decision Tree

The dataset is lenses.txt; each row has four feature columns (age, prescript, astigmatic, tearRate) followed by the class label (no lenses, soft, or hard):

young   myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses

The code below uses the decision tree algorithm above to build the prediction model and plots the structure of the resulting tree.

def classifyLenses():
    """
    Classify the contact lens data.
    :return:
    """
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    print "lenses", lenses
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print "lensesTree", lensesTree
    treePlotter.createPlot(lensesTree)


classifyLenses()
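If rebuilding the tree on every run were expensive, the storeTree / grabTree helpers from tree.py could cache it on disk; a small sketch (the file name lensesTree.txt is just an example):

lenses = [inst.strip().split('\t') for inst in open('lenses.txt').readlines()]
lensesTree = createTree(lenses, ['age', 'prescript', 'astigmatic', 'tearRate'])
storeTree(lensesTree, 'lensesTree.txt')   # serialise the tree with pickle
print grabTree('lensesTree.txt')          # load it back and print it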

The structure of the resulting decision tree is shown below.

[Figure: decision tree for predicting contact lens type]

Decision tree classification has its own problems, such as overfitting. In addition, ID3 can only split nominal (categorical) data and cannot handle numeric features directly.
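One common workaround for the numeric-feature limitation is to discretize numeric columns into nominal bins before calling createTree. A minimal sketch (my own addition, not from the book; the thresholds are arbitrary examples):

def discretize(dataSet, col, thresholds):
    """Replace the numeric value in column `col` with a bin label such as 'bin0', 'bin1', ..."""
    binned = []
    for row in dataSet:
        row = row[:]                                  # copy so the original data is untouched
        binIdx = sum(1 for t in thresholds if row[col] >= t)
        row[col] = 'bin%d' % binIdx
        binned.append(row)
    return binned

# e.g. ages 12, 37 and 65 with thresholds [18, 60] become 'bin0', 'bin1' and 'bin2',
# after which createTree can treat the column like any other nominal feature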

This post is a set of study notes on the book 《机器学习实战》 (Machine Learning in Action); if anything here is unclear, please refer to the book.
