Decision Trees in Python (Machine Learning in Action)


    • Principle
    • Step-by-step breakdown
      • Traverse the dataset, compute each feature's Shannon entropy and information gain, split on the feature with the largest gain, recursively order the remaining features, and serialize the resulting classifier to disk
      • Recursively build the decision tree
      • Draw the tree diagram with Matplotlib annotations
    • Complete code

Principle

(Figure: decision tree diagram)
By asking a question at each node and following a different branch depending on the answer, the tree arrives at a classification.
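The code below stores such a tree as a nested dict: each key is a question (a feature label) and each value maps an answer to either a class label or a subtree. As a minimal sketch of the lookup (the tree is the one returned by retrieveTree(0) further down; walk is a hypothetical helper for illustration, not part of the book's code):

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

def walk(node, answers):
    # follow the answers down the branches until we reach a class label (a string)
    while isinstance(node, dict):
        question = next(iter(node))                 # the feature this node asks about
        node = node[question][answers[question]]
    return node

print(walk(tree, {'no surfacing': 1, 'flippers': 0}))   # 'no'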

Step-by-step breakdown

1. Traverse the dataset: for each feature, compute the Shannon entropy and information gain; split on the feature with the largest information gain; recursively determine the order of the remaining features; finally, serialize the resulting classifier to disk.

def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split on
    """
    :param dataSet: the original dataset
    :return: index of the best splitting feature
    """
    numFeatures = len(dataSet[0]) - 1   # number of features (the last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)   # entropy of the whole dataset
    bestInfoGain = 0.0      # best information gain so far
    bestFeature = -1        # index of the best splitting feature so far
    for i in range(numFeatures):
        # build the list of unique values this feature takes
        featList = [example[i] for example in dataSet]   # the i-th feature of every example
        uniqueVals = set(featList)  # deduplicate (how many distinct values the feature has)
        newEntropy = 0.0
        # entropy of this candidate split
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))    # probability, i.e. the weight of this subset
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     # the lower the new entropy (less disorder), the larger the gain
        # keep track of the best information gain
        if infoGain > bestInfoGain:     # if the new gain beats the previous best, replace it
            bestInfoGain = infoGain
            bestFeature = i     # index of the best splitting feature
    return bestFeature
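For intuition: the sample dataset used throughout (see createDataSet in the complete code) has two 'yes' and three 'no' labels, so its Shannon entropy is -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971. A minimal standalone check, using only the standard library:

from math import log

dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
classLabels = [row[-1] for row in dataSet]
counts = {c: classLabels.count(c) for c in set(classLabels)}
entropy = -sum(n / len(classLabels) * log(n / len(classLabels), 2) for n in counts.values())
print(round(entropy, 3))    # 0.971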

2. Recursively build the decision tree

def createTree(dataSet, labels):
    """
    :param dataSet: the dataset
    :param labels: list of labels for every feature in the dataset
    :return: the decision tree as a nested dict
    """
    classList = [example[-1] for example in dataSet]
    # stop splitting when every example has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features exhausted but the group is still mixed: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)    # the best feature to split on
    bestFeatLabel = labels[bestFeat]    # its label
    myTree = {bestFeatLabel: {}}    # the tree is stored as a nested dict
    # collect every value this feature takes
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]   # copy, because the recursive call mutates the list it is given
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
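On the sample dataset, createTree produces the same nested dict that retrieveTree(0) returns. A quick check, assuming trees.py (from the complete code below) is importable:

import trees

myDat, labels = trees.createDataSet()
myTree = trees.createTree(myDat, labels[:])   # pass a copy: createTree deletes entries from the list it receives
print(myTree)   # {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}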

3. Draw the tree diagram with Matplotlib annotations

import matplotlib.pyplot as plt
import trees

# text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")    # box style and background color
arrow_args = dict(arrowstyle="<-")  # arrow style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):    # draw an annotation with an arrow
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # background color
    fig.clf()   # clear the canvas
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)     # 1 row, 1 column, first subplot, no frame
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # first point: the annotation; second point: the arrow tip
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString):   # fill in text between parent and child nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # compute the width and height
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # the root node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # label the branch with the child's attribute value
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # move the y offset down one level
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def retrieveTree(i):    # return a pre-built tree for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]
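To try the plotter on one of the pre-built trees (assumes trees.py and treePlotter.py sit in the same directory, as in the complete code below):

import treePlotter

treePlotter.createPlot(treePlotter.retrieveTree(0))   # opens a window with the two-level fish tree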

Complete code

trees.py

from math import log
import operator
import treePlotter

def calcShannonEnt(dataSet):    # compute the Shannon entropy of a dataset
    numEntries = len(dataSet)
    labelCounts = {}
    # build a dict counting every possible class
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # if currentLabel not in labelCounts.keys():
        #     labelCounts[currentLabel] = 0
        # labelCounts[currentLabel] += 1
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        # logarithm base 2
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

def splitDataSet(dataSet, axis, value):     # split the dataset on a given feature
    """
    :param dataSet: the dataset to split
    :param axis: the feature to split on
    :param value: the feature value to keep
    :return: the matching subset, with the splitting feature removed
    """
    # build a new list object
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:  # extract
            reducedFratVec = featVec[:axis]
            reducedFratVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFratVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):  # choose the best feature to split on
    """
    :param dataSet: the original dataset
    :return: index of the best splitting feature
    """
    numFeatures = len(dataSet[0]) - 1   # number of features (the last column is the class label)
    baseEntropy = calcShannonEnt(dataSet)   # entropy of the whole dataset
    bestInfoGain = 0.0      # best information gain so far
    bestFeature = -1        # index of the best splitting feature so far
    for i in range(numFeatures):
        # build the list of unique values this feature takes
        featList = [example[i] for example in dataSet]   # the i-th feature of every example
        uniqueVals = set(featList)  # deduplicate (how many distinct values the feature has)
        newEntropy = 0.0
        # entropy of this candidate split
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))    # probability, i.e. the weight of this subset
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     # the lower the new entropy (less disorder), the larger the gain
        # keep track of the best information gain
        if infoGain > bestInfoGain:     # if the new gain beats the previous best, replace it
            bestInfoGain = infoGain
            bestFeature = i     # index of the best splitting feature
    return bestFeature

def majorityCnt(classList):     # majority vote to decide a leaf's class
    """
    :param classList: list of class labels
    :return: the name of the most frequent class
    """
    classCount = {}
    for vote in classList:  # count how often each class appears
        # if vote not in classCount.keys(): classCount[vote] = 0
        # classCount[vote] += 1
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by count
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """
    :param dataSet: the dataset
    :param labels: list of labels for every feature in the dataset
    :return: the decision tree as a nested dict
    """
    classList = [example[-1] for example in dataSet]
    # stop splitting when every example has the same class
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # all features exhausted but the group is still mixed: return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)    # the best feature to split on
    bestFeatLabel = labels[bestFeat]    # its label
    myTree = {bestFeatLabel: {}}    # the tree is stored as a nested dict
    # collect every value this feature takes
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]   # copy, because the recursive call mutates the list it is given
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # test whether the node is a dict (an internal node)
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createDataSet():    # build the sample dataset
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def classify(inputTree, featLabels, testVec):   # the classifier
    """
    :param inputTree: the decision tree
    :param featLabels: feature labels
    :param testVec: the vector to classify
    :return: its class
    """
    firstStr = list(inputTree.keys())[0]
    # translate the label string into an index
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)     # not a leaf yet: keep recursing
            else:
                classLabel = secondDict[key]        # a leaf: take the value of the current key directly
    return classLabel

def storeTree(inputTree, filename):     # serialize the tree (the classifier) to disk
    import pickle
    fw = open(filename, 'wb+')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):     # load a serialized tree
    import pickle
    fr = open(filename, "rb+")
    return pickle.load(fr)

if __name__ == "__main__":
    myDat, labels = createDataSet()
    # myTree = createTree(myDat, labels)
    # print(myTree)
    print(myDat)
    myTree = treePlotter.retrieveTree(0)
    print(myTree)
    print(classify(myTree, labels, [1, 0]))
    print(classify(myTree, labels, [1, 1]))
    print("===========store tree============")
    storeTree(myTree, 'classifierStorafe.txt')
    print(grabTree('classifierStorafe.txt'))
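Tracing the __main__ block by hand (both classify calls walk the tree from 'no surfacing' to 'flippers'), running trees.py should print something like:

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
no
yes
===========store tree============
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}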

treePlotter.py

import matplotlib.pyplot as plt
import trees

# text-box and arrow styles
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")    # box style and background color
arrow_args = dict(arrowstyle="<-")  # arrow style

def plotNode(nodeTxt, centerPt, parentPt, nodeType):    # draw an annotation with an arrow
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # background color
    fig.clf()   # clear the canvas
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops)     # 1 row, 1 column, first subplot, no frame
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # first point: the annotation; second point: the arrow tip
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString):   # fill in text between parent and child nodes
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    # compute the width and height
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # the root node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # label the branch with the child's attribute value
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # move the y offset down one level
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def retrieveTree(i):    # return a pre-built tree for testing
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

if __name__ == "__main__":
    # reTree = retrieveTree(1)
    # leafs = trees.getNumLeafs(reTree)
    # depth = trees.getTreeDepth(reTree)
    # print(reTree)
    # print(leafs)
    # print(depth)
    myTree = retrieveTree(0)
    myTree['no surfacing'][3] = 'maybe'     # graft on a third branch before plotting
    createPlot(myTree)
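The live __main__ code grafts a third branch ('maybe') onto retrieveTree(0) before plotting, a handy way to confirm that the layout adapts when the leaf count changes. For reference, the commented-out checks exercise the counting helpers on the second stored tree; tracing getNumLeafs and getTreeDepth by hand gives:

import treePlotter
import trees

reTree = treePlotter.retrieveTree(1)
print(trees.getNumLeafs(reTree))    # 4
print(trees.getTreeDepth(reTree))   # 3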