《机器学习实战》——第3章代码详解(决策树)

来源:互联网 发布:交换机镜像端口配置 编辑:程序博客网 时间:2024/06/05 05:12

from math import log

import operator

 

def createDataSet():  # 创建数据集 

    dataSet = [[1,1,'yes'],  

               [1,1,'yes'],  

               [1,0,'no'],  

               [0,1,'no'],  

               [0,1,'no']]  

    labels=['no surfacing','flippers']  

    return dataSet,labels  

 

def CalcShannonEnt(dataSet): # 计算信息熵       

    numEntries = len(dataSet) #计算数据集的输入个数        

    labelCounts = {}   # 创建存储标签的元字典  #[]列表,{}元字典,()元组      

    for featVec in dataSet:  #对数据集dataSet中的每一行featVec进行循环遍历        

        currentLabels =featVec[-1]   # currentLabelsfeatVec的最后一个元素           

        if currentLabels not in labelCounts.keys():

                 # 如果标签currentLabels不在元字典对应的key中                

            labelCounts[currentLabels] = 0

                 #将标签currentLabels放到字典中作为key,并将值赋为0            

        labelCounts[currentLabels] += 1 # currentLabels对应的值加1      

    shannonEnt = 0.0  # 定义香农熵shannonEnt     

    for key in labelCounts:   # 遍历元字典labelCounts中的key,即标签

        prob = float(labelCounts[key])/numEntries   # 计算每一个标签出现的频率,即概率        

        shannonEnt -= prob*log(prob,2)  # 根据信息熵公式计算每个标签信息熵并累加到shannonEnt上     

return shannonEnt # 返回求得的整个标签对应的信息熵

 

def splitDataSet(dataSet,axis,value):  # 分割数据集

# dataSet数据集,axis是对应的要分割数据的列,value是要分割的列按哪个值分割,即找到含有该值的数据      

    retDataSet = []  # 定义要返回的数据集      

    for featVec in dataSet:  # 遍历数据集中的每个特征,即输入数据         

        if featVec[axis] == value:

                # 如果列标签对应的值为value,则将该条()数据加入到retDataSet中                

            reducedFeatVec = featVec[:axis]

                # featVec0-axis个数据,不包括axis,放到reducedFeatVec中               

            reducedFeatVec.extend(featVec[axis+1:])

                # featVecaxis+1到最后的数据,放到reducedFeatVec的后面  

            retDataSet.append(reducedFeatVec) # reducedFeatVec添加到分割后的数据集retDataSet中,同时reducedFeatVecretDataSet中没有了axis列的数据     

    return retDataSet # 返回分割后的数据集  

 

 

def chooseBestFeatureToSplit(dataSet):  # 选择使分割后信息增益最大的特征,即对应的列       

    numFeatures = len(dataSet[0]) - 1  # 获取特征的数目,从0开始,dataSet[0]是一条数据     

    baseEntropy = CalcShannonEnt(dataSet)  # 计算数据集当前的信息熵    

    bestInfoGain = 0.0   # 定义最大的信息增益     

    bestFeature = -1   # 定义分割后信息增益最大的特征  

    # 遍历特征,即所有的列,计算每一列分割后的信息增益,找出信息增益最大的列  

    for i in range(numFeatures):            

        featList = [example[i] for example in dataSet]# 取出第i列特征赋给featList  

        # 将特征对应的值放到一个集合中,使得特征列的数据具有唯一性  

        uniqueVals = set(featList)            

        newEntropy = 0.0  # 定义分割后的信息熵        

        for value in uniqueVals:   # 遍历特征列的所有值(值是唯一的,重复值已经合并),分割并计算信息增益             

          subDataSet = splitDataSet(dataSet, i, value)# 按照特征列的每个值进行数据集分割                

          prob = len(subDataSet) / float(len(dataSet))  # 计算分割后的每个子集的概率                # 计算分割后的子集的信息熵并相加,得到分割后的整个数据集的信息熵                

          newEntropy +=prob * CalcShannonEnt(subDataSet)         

       infoGain = baseEntropy - newEntropy    # 计算分割后的信息增益         

       if(infoGain > bestInfoGain):   # 如果分割后信息增益大于最好的信息增益             

            bestInfoGain = infoGain   # 将当前的分割的信息增益赋值为最好信息增益           

            bestFeature = i   # 分割的最好特征列赋为i     

    return bestFeature   # 返回分割后信息增益最大的特征列  

 

  

def majorityCnt(classList):  # 对类标签进行投票 ,找标签数目最多的标签   

    classCount = {}   # 定义标签元字典,key为标签,value为标签的数目     

    for vote in classList:   # 遍历所有标签        

        if vote not in classCount.keys(): #如果标签不在元字典对应的key中           

            classCount[vote] = 0     # 将标签放到字典中作为key,并将值赋为0        

        classCount[vote] += 1  # 对应标签的数目加1   

    # 对所有标签按数目排序  

    sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)      

    return sortedClassCount[0][0]  # 返回数目最多的标签  

   

 

def createTree(dataSet,labels):  # 创建决策树

    # dataSet的最后一列数据(标签)取出赋给classListclassList存储的是标签列  

    classList = [example[-1] for example in dataSet]       

    if classList.count(classList[0]) == len(classList):  # 判断是否所有的列的标签都一致          

        return classList[0]   # 直接返回标签列的第一个数据      

    if len(dataSet) == 1:  # 判断dataSet是否只有一条数据        

        return majorityCnt(classList)  # 返回标签列数据最多的标签       

    bestFeat = chooseBestFeatureToSplit(dataSet)   # 选择一个使数据集分割后最大的特征列的索引    

    bestFeatLabel = labels[bestFeat]   # 找到最好的标签       

    myTree = {bestFeatLabel:{}}  # 定义决策树,keybestFeatLabelvalue为空    

    del(labels[bestFeat])   # 删除labels[bestFeat]对应的元素      

    featValues = [example[bestFeat] for example in dataSet]  # 取出dataSetbestFeat列的所有值      

    uniqueVals = set(featValues) # 将特征对应的值放到一个集合中,使得特征列的数据具有唯一性       

    for value in uniqueVals:# 遍历uniqueVals中的值            

        subLabels = labels[:]  # 子标签subLabelslabels删除bestFeat标签后剩余的标签

        # myTreekeybestFeatLabel时的决策树  

        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat, value), subLabels)      

    return myTree  # 返回决策树  

 

 

  

def classify(inputTree,featLabels,testVec):  # 决策树分类函数  

    firstStr = inputTree.keys()[0]   # 得到树中的第一个特征      

    secondDict = inputTree[firstStr] # 得到第一个对应的值   

    featIndex = featLabels.index(firstStr)  # 得到树中第一个特征对应的索引    

    for key in secondDict.keys():  # 遍历树     

        if testVec[featIndex] == key:    # 如果在secondDict[key]中找到testVec[featIndex]             

            if type(secondDict[key]).__name__ == 'dict':  # 判断secondDict[key]是否为字典  

                # 若为字典,递归的寻找testVec  

                classLabel = classify(secondDict[key], featLabels, testVec)  

            else:    # secondDict[key]为标签值,则将secondDict[key]赋给classLabel

                classLabel = secondDict[key]   

    return classLabel  # 返回类标签

 

  

def storeTree(inputTree,filename):# 决策树的序列化        

    import pickle  # 导入pyton模块      

    fw = open(filename,'w') # 以写的方式打开文件      

    pickle.dump(inputTree,fw)  # 决策树序列化         

          

def grabTree(filename):  # 读取序列化的树

    import pickle  

    fr = open(filename)  # 导入pyton模块      

    return pickle.load(fr)  # 返回读到的树

 

 

 

 

 

 

import matplotlib.pyplot  as plt  

    

# 定义决策树决策结果的属性,用字典来定义  

# 下面的字典定义也可写作 decisionNode={boxstyle:'sawtooth',fc:'0.8'}

 

decisionNode = dict(boxstyle="sawtooth",fc="0.8")  # boxstyle为文本框的类型,sawtooth是锯齿形,fc是边框线粗细

leafNode = dict(boxstyle="round4",fc="0.8")  # 定义决策树的叶子结点的描述属性  

arrow_args = dict(arrowstyle="<-")   # 定义决策树的箭头属性

 

 

def plotNode(nodeTxt,centerPt,parentPt,nodeType):  # 绘制结点

    # annotate是关于一个数据点的文本  

# nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点  

    createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axesfraction',xytext=centerPt,textcoords='axesfraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args)  

 

def getNumLeafs(myTree):  # 获得决策树的叶子结点数目      

    numLeafs = 0   # 定义叶子结点数目   

    firstStr = myTree.keys()[0]    # 获得myTree的第一个键值,即第一个特征,分割的标签  

    secondDict = myTree[firstStr]  # 根据键值得到对应的值,即根据第一个特征分类的结果  

    for key in secondDict.keys():   # 遍历得到的secondDict          

        if type(secondDict[key]).__name__ == 'dict': #如果secondDict[key]为一个字典,即决策树结点              

            numLeafs += getNumLeafs(secondDict[key])   # 则递归的计算secondDict中的叶子结点数,并加到numLeafs上         

        else:                       # 如果secondDict[key]为叶子结点               

            numLeafs += 1      # 则将叶子结点数加1     

    return numLeafs   # 返回求的叶子结点数目  

 

 

def getTreeDepth(myTree):  # 获得决策树的深度  

    maxDepth = 0     # 定义树的深度  

    firstStr = myTree.keys()[0]   # 获得myTree的第一个键值,即第一个特征,分割的标签  

    secondDict = myTree[firstStr]  # 根据键值得到对应的值,即根据第一个特征分类的结果

    for key in secondDict.keys():   

        if type(secondDict[key]).__name__ == 'dict':   # 如果secondDict[key]为一个字典

            thisDepth = 1 + getTreeDepth(secondDict[key])  

               # 则当前树的深度等于1加上secondDict的深度,只有当前点为决策树点深度才会加1    

        else:                 # 如果secondDict[key]为叶子结点   

            thisDepth = 1    # 则将当前树的深度设为1

        if thisDepth > maxDepth:   # 如果当前树的深度比最大数的深度

            maxDepth = thisDepth  

    return maxDepth   # 返回树的深度

 

 

def plotMidText(cntrPt,parentPt,txtString): # 绘制中间文本   

    xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]   # 求中间点的横坐标

    yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]  # 求中间点的纵坐标

    createPlot.ax1.text(xMid,yMid,txtString)  # 绘制树结点  (函数)

 

 

def plotTree(myTree,parentPt,nodeTxt): # 绘制决策树  

    numLeafs = getNumLeafs(myTree)   # 定义并获得决策树的叶子结点数

    depth = getTreeDepth(myTree)  # 定义并获得决策树的深度

    firstStr = myTree.keys()[0]  # 得到第一个特征

    # 计算坐标,x坐标为当前树的叶子结点数目除以整个树的叶子结点数再除以2y为起点  

    cntrPt = (plotTree.xOff + (1.0 +float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)  

    # 绘制中间结点,即决策树结点,也是当前树的根结点,这句话没感觉出有用来,  

    plotMidText(cntrPt, parentPt, nodeTxt)        

    plotNode(firstStr,cntrPt,parentPt,decisionNode) # 绘制决策树结点      

    secondDict = myTree[firstStr]   # 根据firstStr找到对应的值     

    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD  # 因为进入了下一层,所以y的坐标要变 ,图像坐标是从左上角为原点    

    for key in secondDict.keys():   # 遍历secondDict  

        if type(secondDict[key]).__name__ == 'dict': # 如果secondDict[key]为一棵子决策树,即字典             

            plotTree(secondDict[key],cntrPt,str(key)) # 递归的绘制决策树           

        else:     # secondDict[key]为叶子结点            

            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW   # 计算叶子结点的横坐标  

            plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt, leafNode)  #绘制叶子结点

            # 这句注释掉也不影响决策树的绘制,自己理解的浅陋了,这行代码是特征的值  

            plotMidText((plotTree.xOff,plotTree.yOff), cntrPt, str(key))  

    plotTree.yOff = plotTree.yOff +1.0/plotTree.totalD    # 计算纵坐标

 

def createPlot(inTree):      

    fig = plt.figure(1,facecolor='white')  # 定义一块画布(画布是自己的理解)    

    fig.clf()  # 清空画布     

    axprops = dict(xticks=[],yticks=[])  # 定义横纵坐标轴,无内容

    createPlot.ax1 = plt.subplot(111,frameon=True,**axprops) # 绘制图像,无边框,无坐标轴      

    plotTree.totalW = float(getNumLeafs(inTree))# plotTree.totalW保存的是树的宽        

    plotTree.totalD = float(getTreeDepth(inTree))  # plotTree.totalD保存的是树的高  

     

    plotTree.xOff = - 0.5 / plotTree.totalW #0开始会偏右  #决策树起始横坐标

    print  plotTree.xOff    

    plotTree.yOff = 1.0 # 决策树的起始纵坐标  

    plotTree(inTree,(0.5,1.0),'')  # 绘制决策树

    plt.show() # 显示图像

 

 

 

 

 

 

 

 

 

原创粉丝点击