决策树算法(六)——构建决策树

来源:互联网 发布:sql developer打不开 编辑:程序博客网 时间:2024/05/16 06:12

写在前面的话

我是花痴,我只喜欢长得好看的,恩,我很肤浅,但是没办法,我就是喜欢长得帅的身材好的,我就是一俗人.我俗我开心.

递归构建决策树

之前我们已经学习了怎么根据信息论的方法,把一个数据集从杂乱无章的数据集中划分出来,我们使用信息论来构建决策树一级一级分类的方法就是一个递归的过程。

它的工作原理如下:

  • 得到原始数据集,然后基于最好的属性值划分数据集。每一次划分数据集,我们都要消耗一个特征,根据某个特征将某些性质相同的元素剥离出来
  • 划分数据的时候我们根据香农熵,计算信息增益之后找到最好的属性值进行数据的划分。
  • 由于特征值可能有多于两个的,因此可能存在大于两个分支的数据集划分
  • 第一次划分数据将向下传递到树分支的下一个节点,在这个节点上,我们可以再次划分数据集,因此我们可以采用递归的原则来处理数据集。

我们都知道递归必须要有一个终止条件,1)如果程序已经遍历完了所有的特征属性,2)或者每个分支下的所有实例都具有相同的分类,我们得到一个叶子节点或者终止块.这个就是我们递归的终止条件.

出现1这种情况的特殊情况就是我们之前在决策树算法(五)——处理一些特殊的分类 这篇文中已经详细的分析过了.当已经遍历完所有的特征属性但是任然还有一些类别灭有找出,那么我们就根据选举投票的方法来进行分类.

当然对于第一个结束条件算法可以终止,我们还可以设置算法可以划分的最大分组的数目.

创建决策树代码

下面我们来构建决策树的代码,使用递归来进行.
我们还是打开我们之前的文件trees.py, 在这个文件中添加如下的代码:

def createTree(dataSet,labels):    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len (classList):        return classList[0]    if len(dataSet[0]) == 1:        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat]    myTree = {bestFeatLabel:{}}    del(labels[bestFeatLabel])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)    return myTree

之后我们来分析一下这段代码:

还是用我们之前几章的那个数据集:
这里写图片描述
在这里任然需要注意我们的数据集是前面每一项都是特征值,最后一项是我们的类别信息.如下所示

dataSet = [[1,1,'yes'],              [1,1,'yes'],              [1,0,'no'],              [0,1,'no'],              [0,1,'no']]
#这里的第一条语句就是获得dataSet中的所有数据的类别:classList = [example[-1] for example in dataSet]#这种写法是python语法的一个特色,简单明了快捷随意.就是喜欢python这么随性,哈哈,像我.# example中每次取出的是dataSet中的一个元素,e.g. [0,1,'no']#example[-1] 就是每个元素的最后一列.

我们来看下执行结果:
这里写图片描述

之后的两个if条件是递归终止条件.

if classList.count(classList[0]) == len (classList):        return classList[0]# 这个条件语句是表示所有的数据都已经划分完成,每个类别已经完全相同#这样递归可以结束#count()函数中接受一个参数,表示的是这个参数在某个序列中出现的次数#如果这个classList中的元素完全相同,那么这个参数的count(classList[0])应该是等于这个List的长度的.

 if len(dataSet[0]) == 1:        return majorityCnt(classList)        # 第二个递归条件表示的只剩最后的类别信息的数据集.        #因为决策树算法每做一次信息的划分,都会消耗一个特征,当特征        #消耗完之后还有类别不同那么我们就需要投票表决了

这里写图片描述

看这张图应该很清楚了.

注意 :上面这两个条件语句都是我们递归结束的条件.

bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat]    #之后我们调用chooseBestFeatureToSplit函数

chooseBestFeatureToSplit函数的原型如下(我们在之前已经讲过):

def chooseBestFeatureToSplit(dataSet):    numFeatures = len(dataSet[0])-1    baseEntropy = calcShannonEnt(dataSet)    bestInfoGain =0.0    bestFeature = -1    for i in range(numFeatures):        featList = [sample[i] for sample in dataSet]        uniqueVals = set(featList)        newEntropy = 0.0        for value in uniqueVals:            subDataSet = splitDataSet(dataSet,i,value)            prob = len(subDataSet)/float(len(dataSet))            newEntropy += prob * calcShannonEnt(subDataSet)        infoGain = baseEntropy - newEntropy        if(infoGain > bestInfoGain):            bestInfoGain = infoGain            bestFeature = i    return bestFeature

这个函数 返回的是最好的特征值,通过计算最大信息增益获得的.

# bestFeat 存储的最好的特征的下标.它和我们的label是一一对应的   bestFeatLabel = labels[bestFeat]

这里写图片描述

在这里我们可以看出,我们的数据集有两个特征,就是no surfing 和 flippers . 每个数据集的第一列表示的是有还是不需要no surfing, 第二列表示的有没有flippers. 1表示有,2表示没有.


这里写图片描述

我们第一次调用chooseBestFeatureToSplit函数,结果告诉我们选择第一个特征比较好

    bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat] # bestFeatLabel中存储了最佳特征的标签    myTree = {bestFeatLabel:{}} # 构建数据字典    del(labels[bestFeatLabel])# 删除最佳特征值
#找出最佳特征向量对应的所有特征值featValues = [example[bestFeat] for example in dataSet]# 除去重复的特征值    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)        #递归调用构建决策树

我们来测试一下我们的实验结果.

这里写图片描述

mytree 包含了很多代表结构信息的嵌套字典. 在代码中也可以看到我们实际上是用一个数据字典来构建我们的决策树.
第一个 no surfing 是第一个划分数据集的特征名称,在其下面有分为两类,特征是0的不是鱼类,是1的有被继续划分了.

这里写图片描述


到这里我们决策树算法算是讲完了,我们贴出整个分类的完整代码.

#!/usr/bin/env python# coding=utf-8# author: chicho# running: python trees.py# filename : trees.pyfrom math import logimport operatordef createDataSet():    dataSet = [[1,1,'yes'],              [1,1,'yes'],              [1,0,'no'],              [0,1,'no'],              [0,1,'no']]    labels = ['no surfacing','flippers']    return dataSet, labelsdef calcShannonEnt(dataSet):    countDataSet = len(dataSet)    labelCounts={}    for featVec in dataSet:        currentLabel=featVec[-1]        if currentLabel not in labelCounts.keys():            labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    print labelCounts    shannonEnt = 0.0    for key in labelCounts:        prob = float(labelCounts[key])/countDataSet        shannonEnt -= prob * log(prob,2)    return shannonEntdef splitDataSet(dataSet,axis,value):    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:            reduceFeatVec = featVec[:axis]            reduceFeatVec.extend(featVec[axis+1:])            retDataSet.append(reduceFeatVec)    return retDataSetdef chooseBestFeatureToSplit(dataSet):    numFeatures = len(dataSet[0])-1    baseEntropy = calcShannonEnt(dataSet)    bestInfoGain =0.0    bestFeature = -1    for i in range(numFeatures):        featList = [sample[i] for sample in dataSet]        uniqueVals = set(featList)        newEntropy = 0.0        for value in uniqueVals:            subDataSet = splitDataSet(dataSet,i,value)            prob = len(subDataSet)/float(len(dataSet))            newEntropy += prob * calcShannonEnt(subDataSet)        infoGain = baseEntropy - newEntropy        if(infoGain > bestInfoGain):            bestInfoGain = infoGain            bestFeature = i    return bestFeaturedef majorityCnt(classList):    classCount={}    for vote in classList:        if vote not in classCount.keys():            classCount[vote] = 0        classCount[vote] += 1        sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)    return sortedClassCount[0][0]def createTree(dataSet,labels):    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len (classList):        return classList[0]    if len(dataSet[0]) == 1:        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat]    myTree = {bestFeatLabel:{}}    del(labels[bestFeatLabel])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLabels)    return myTree



写在后面的话

你必须非常努力,才可以看起来毫不费力
加油~~~
要么就不做,要做就做最好

1 0
原创粉丝点击