Machine Learning in Action (《机器学习实战》) - Reading Notes: Chapter 3, Decision Trees


Content based on the book Machine Learning in Action (《机器学习实战》).

# -*- coding: utf-8 -*-
'''
Machine Learning in Action - reading notes: Chapter 3, Decision Trees

Key points:

1. Decision tree basics
Task: understand the knowledge hidden in the data and extract rules. Typical application: expert systems.
Pros: low computational complexity; insensitive to missing intermediate values.
Cons: prone to overfitting.
Works with: numeric and nominal values.
Construction: find the feature that best separates the current data and split on it.
If every resulting partition contains a single class, stop splitting and assign that class label;
otherwise repeat the process on each mixed partition.
Flowchart notation: rectangle = decision block, oval = terminating block, arrows = branches.

2. Information gain: the change in information before and after a split. The feature with the
highest information gain is the best choice; information gain measures the reduction in the
data's disorder. Entropy is the expected value of the information: the higher the entropy,
the more mixed the data, so in practice we look for the split that minimizes it.
The information of symbol x_i is l(x_i) = -log2 P(x_i), where P(x_i) is the probability of
choosing that class. Entropy: H = -sum(P(x_i) * log2 P(x_i)) for i = 1..n.

3. Choosing the best feature: iterate over every feature; for each one, collect the set of
its values, split the data set on each value, compute each partition's entropy, and sum the
weighted entropies to get that feature's entropy. If (base entropy - new entropy) is larger
than the best information gain seen so far, update the gain and remember the feature.

4. Creating the tree: if all class labels are identical, return that label directly.
If all features have been used up, return the most frequent class (here the classes are
'yes' or 'no'). Otherwise pick the feature with the largest information gain, get all of
its values, and recursively build a subtree for the data subset of each value, producing
a nested mapping keyed by the feature.

5. Storing the tree: building the tree is expensive, so build it once and reuse it for
every classification, serializing it to disk with pickle:

    # must be opened in binary mode
    fw = open(fileName, "wb")
    # pickle.dump(obj, file, protocol): writes the object to the file; pickle can
    # serialize and deserialize basic Python data
    pickle.dump(inputTree, fw)
    fw.close()

    fr = open(fileName, "rb")
    # pickle.load(file): reads from the file and reconstructs the original Python object
    return pickle.load(fr)
'''
from math import log
import operator
import matplotlib.pyplot as plt


# compute the Shannon entropy of the given data set
def calcShannonEnt(dataSet):
    rows = len(dataSet)
    # count how often each class label occurs
    labelToCount = dict()
    for data in dataSet:
        label = data[-1]
        if label in labelToCount:
            labelToCount[label] += 1
        else:
            labelToCount[label] = 1
    # Shannon entropy: H = -sum(P(x_i) * log2 P(x_i))
    result = 0.0
    for label, count in labelToCount.items():
        # // is integer division, / is float division; we want / here
        prob = count * 1.0 / rows
        result -= prob * log(prob, 2)
    return result


# split the data set on a given feature: keep the rows whose column columnNum equals
# value, and strip that column from each kept row
def splitDataSet(dataSet, columnNum, value):
    resultDatas = []
    for data in dataSet:
        if value == data[columnNum]:
            front = data[:columnNum]
            back = data[columnNum + 1:]
            front.extend(back)
            resultDatas.append(front)
    return resultDatas


def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def calcShannonEnt_test():
    dataSet, labels = createDataSet()
    result = calcShannonEnt(dataSet)
    print(result)
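
# Worked example (my own addition, not from the book): the fish data set above has
# 2 'yes' and 3 'no' labels, so computing the entropy by hand gives
#   H = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971
def calcShannonEnt_byHand():
    expected = -(2.0 / 5) * log(2.0 / 5, 2) - (3.0 / 5) * log(3.0 / 5, 2)
    dataSet, labels = createDataSet()
    # both values should print as 0.9709505944546686
    print(expected, calcShannonEnt(dataSet))
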
# choose the best feature to split on: for every feature, collect the set of its values,
# compute the weighted entropy of the resulting partitions, and keep the feature whose
# split yields the largest information gain (base entropy - new entropy)
def chooseBestFeature(dataSet):
    rows = len(dataSet)
    featureNum = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    # for each feature, compute the entropy after splitting on it
    for i in range(featureNum):
        features = [temp[i] for temp in dataSet]
        featureValues = set(features)
        # split the data set on each value of the feature and accumulate the
        # weighted entropy of the partitions
        newEntropy = 0.0
        for value in featureValues:
            subDatas = splitDataSet(dataSet, i, value)
            prob = len(subDatas) / float(rows)
            newEntropy += prob * calcShannonEnt(subDatas)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def chooseBestFeature_test():
    myDat, labels = createDataSet()
    bestFeature = chooseBestFeature(myDat)
    print(bestFeature)


# build a <class, count> mapping and return the most frequent class
def majorityCount(classList):
    labelToCount = {}
    for vote in classList:
        if vote in labelToCount:
            labelToCount[vote] += 1
        else:
            labelToCount[vote] = 1
    # note: the keyword argument is reverse, not reversed
    sortedResult = sorted(labelToCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedResult[0][0]


# create the decision tree: return the label directly if all labels agree; return the
# majority class if all features are used up; otherwise split on the feature with the
# largest information gain and recurse on each value's subset
def createTree(dataSet, labels):
    # the class information is the set of labels after splitting, e.g. yes/no
    # (or cat/dog/cow in another problem); here class and label mean the same thing
    classList = [temp[-1] for temp in dataSet]
    # if the partition produced by the previous split contains a single class,
    # return that class directly
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # if all features are used up (only the label column remains), return the
    # most frequent class
    if len(dataSet[0]) == 1:
        return majorityCount(classList)
    # pick the feature with the largest information gain and repeat the process
    # on the subset for each of its values
    bestFeature = chooseBestFeature(dataSet)
    features = [temp[bestFeature] for temp in dataSet]
    uniqueFeatures = set(features)
    bestFeatureLabel = labels[bestFeature]
    decisionTree = {bestFeatureLabel: {}}
    # the label of the chosen feature must be removed before recursing
    del labels[bestFeature]
    for value in uniqueFeatures:
        subLabels = labels[:]
        subdatas = splitDataSet(dataSet, bestFeature, value)
        decisionTree[bestFeatureLabel][value] = createTree(subdatas, subLabels)
    return decisionTree


def decisionTree_test():
    dataSet, labels = createDataSet()
    decisionTree = createTree(dataSet, labels)
    print(decisionTree)
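
# Expected result (my own check, not in the book's code): on the fish data the
# best first split is 'no surfacing', with 'flippers' nested beneath it.
def createTree_check():
    dataSet, labels = createDataSet()
    tree = createTree(dataSet, labels[:])  # pass a copy: createTree mutates labels
    assert tree == {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
    print(tree)
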
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


def plotNode(nodeText, centerPoint, parentPoint, nodeType):
    # annotate(text, xy=arrow target, xytext=text position, centered vertically and
    # horizontally, with the given box style and arrow style)
    createPlot.ax1.annotate(nodeText, xy=parentPoint, xycoords="axes fraction",
                            xytext=centerPoint, textcoords="axes fraction",
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


# demo version; redefined below with the full tree layout
def createPlot():
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # subplot(rows, cols, index); frameon toggles the frame around the axes
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode("decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode("leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


# count the leaves: for each value under the root key of {key: {value, ...}},
# recurse if the value is a dict; otherwise it is a leaf
def getNumLeafs(myTree):
    leafNum = 0
    # in Python 3.x keys() returns a view, so convert to a list first
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key, value in secondDict.items():
        if isinstance(value, dict):
            leafNum += getNumLeafs(value)
        else:
            leafNum += 1
    return leafNum


# depth of the tree: accumulate the depth along each branch and keep the maximum
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    depth = 0
    for key, value in secondDict.items():
        if isinstance(value, dict):
            depth = 1 + getTreeDepth(value)
        # a leaf contributes a depth of 1
        else:
            depth = 1
        if depth > maxDepth:
            maxDepth = depth
    return maxDepth


def getTreeDepth_test():
    dataSet, labels = createDataSet()
    decisionTree = createTree(dataSet, labels)
    #print(decisionTree)
    leafNum = getNumLeafs(decisionTree)
    depth = getTreeDepth(decisionTree)
    print("leaf num: %d , depth: %d" % (leafNum, depth))


# draw text midway between a parent and a child node
def plotMidText(centerPoint, parentPoint, textString):
    xMid = (parentPoint[0] - centerPoint[0]) / 2.0 + centerPoint[0]
    yMid = (parentPoint[1] - centerPoint[1]) / 2.0 + centerPoint[1]
    createPlot.ax1.text(xMid, yMid, textString)


# compute width and height, then draw the tree
def plotTree(myTree, parentPt, nodeTxt):
    # number of leaves of the current subtree
    numLeafs = getNumLeafs(myTree)
    #depth = getTreeDepth(myTree)
    # the first key is the feature this subtree splits on
    firstStr = list(myTree.keys())[0]
    # x coordinate: center the node over its span of leaves, scaled by the total
    # leaf count; y is the current level
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # draw the text between this node and its parent; at first this looked useless
    # because the tree is drawn fine without it, but it labels the branch with the
    # feature value (same for the call further below)
    plotMidText(cntrPt, parentPt, nodeTxt)
    # draw the decision node, the root of the current subtree
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # the children of this node
    secondDict = myTree[firstStr]
    # moving down one level, so decrease y (the origin is at the top left)
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        # if secondDict[key] is a dict, it is a subtree: recurse
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        # otherwise secondDict[key] is a leaf
        else:
            # advance the x coordinate by one leaf slot
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            # draw the leaf node
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # label the branch with the feature value
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # moving back up one level, so restore y
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
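
# Worked layout example (my own addition, not from the book): for the fish tree,
# totalW = 3 leaves and totalD = 2 levels, so xOff starts at -0.5/3 and the root
# lands at x = -0.5/3 + (1 + 3)/2.0/3 = 0.5, i.e. centered over its leaf span.
def plotTree_coordExample():
    totalW, numLeafs = 3.0, 3
    xOff = -0.5 / totalW
    rootX = xOff + (1.0 + numLeafs) / 2.0 / totalW
    print(rootX)  # 0.5
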
def createPlot(inTree):
    # create a figure (a canvas, as I understand it) with a white background
    fig = plt.figure(1, facecolor='white')
    # clear the figure
    fig.clf()
    # empty tick lists: no axis ticks
    axprops = dict(xticks=[], yticks=[])
    # draw without a frame and without axes
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # plotTree.totalW holds the width of the tree (its leaf count)
    plotTree.totalW = float(getNumLeafs(inTree))
    # plotTree.totalD holds the depth of the tree
    plotTree.totalD = float(getTreeDepth(inTree))
    # starting x coordinate (starting at 0 would shift the tree to the right)
    plotTree.xOff = -0.5 / plotTree.totalW
    #print(plotTree.xOff)
    # starting y coordinate
    plotTree.yOff = 1.0
    # draw the tree
    plotTree(inTree, (0.5, 1.0), '')
    # show the figure
    plt.show()


# classification with the decision tree: compare the test vector's feature values
# against the tree, recursing until a leaf is reached, and return the leaf's class
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key, value in secondDict.items():
        # find the branch matching the feature's value
        if testVec[featIndex] == key:
            if isinstance(value, dict):
                classLabel = classify(value, featLabels, testVec)
            else:
                classLabel = value
    return classLabel


# storing the tree: building it is expensive, so serialize it with pickle and reload
# the stored tree for classification (see the notes at the top of the file)
def storeTree(inputTree, fileName):
    import pickle
    # must be opened in binary mode
    fw = open(fileName, "wb")
    # pickle.dump(obj, file, protocol): writes the object to the file; pickle can
    # serialize and deserialize basic Python data
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(fileName):
    import pickle
    fr = open(fileName, "rb")
    # pickle.load(file): reads from the file and reconstructs the original Python object
    return pickle.load(fr)


# fish classification problem
def fishClassify():
    #calcShannonEnt_test()
    #chooseBestFeature_test()
    #decisionTree_test()
    #createPlot()
    #getTreeDepth_test()
    dataSet, labels = createDataSet()
    # building the tree mutates the original label list, so pass in a copy
    copyLabels = labels[:]
    decisionTree = createTree(dataSet, copyLabels)
    # serialize the tree and load it back
    fileName = "classifierStorage.txt"
    storeTree(decisionTree, fileName)
    decisionTree = grabTree(fileName)
    print(labels)
    result = classify(decisionTree, labels, [1, 0])
    print(result)
    result = classify(decisionTree, labels, [1, 1])
    print(result)
    createPlot(decisionTree)


# contact lens classification problem
def lenseClassify():
    fr = open('lenses.txt')
    lenses = [line.strip().split("\t") for line in fr.readlines()]
    lensesLabels = ["age", "prescript", "astigmatic", "tearRate"]
    lenseTree = createTree(lenses, lensesLabels)
    createPlot(lenseTree)


if __name__ == "__main__":
    fishClassify()
    lenseClassify()
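
For reference, my reading of what fishClassify() prints (derived from the code above, not quoted from the book): the label list ['no surfacing', 'flippers'], then 'no' for the test vector [1, 0] (surfaces but has no flippers) and 'yes' for [1, 1], before the tree plot window opens.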

