决策树的算法描述和源码分析

来源:互联网 发布:人工智能伏羲觉醒种子 编辑:程序博客网 时间:2024/06/05 04:13

前言:

决策树是一种分类算法,简单而有效。它与KNN不同,KNN可以简单地得到结果,但是无法理解数据本身的含义。决策树的优势在于可以以树状的形式理解数据。现在我们希望可以判断邮件箱中的邮件属性,即是必须要处理的,还是属于垃圾邮件。决策树的做法如图一所示
这里写图片描述
我们将根据邮件数据中的两个属性进行判断,首先判断邮件域名地址,如果结果为是,则此封邮件属于无聊时需要阅读的邮件,如果结果为否,则继续判断数据的第二个属性即是否包含单词”曲棍球”,若结果为是,则此封邮件为需要及时处理的朋友邮件,为否则属于无需阅读的垃圾邮件。
可以看到得到的决策树简洁明了,如果配合adaboost、以及级联使用,可以更进一步地改善分类效果,决策树也是处理数据中非常常用的一种算法。决策树的算法有很多种,这里主要讲的是ID3算法,接下来将是算法描述和源码分析部分。

算法描述

划分数据集的大原则是:将无序的数据集变得更加有序。基于此原则,我们首先得确定何为有序数据,何为无序数据。
这里引入一个概念,香农熵或者简称为熵:它是由天才科学家香农提出的,表示随机变量不确定的度量。
计算公式如下:
这里写图片描述
其中这里写图片描述它表示了待分类的事物可能划分在多个分类中,则信息xi的定义即为l(xi),其中p(xi)是选择该分类的概率。很显然可以看到,熵即为信息的期望。当各分类的概率相同时,熵最大,当只有一种可能时,熵最小,此时熵为0。
根据我们划分数据的原则,我们希望选择特征的时候能把数据集的熵尽可能降低,进而使数据变得有序,更好地处理。根据这种自然而然的想法, 就有了下个定义,信息增益:
这里写图片描述
看起来经验条件熵非常复杂, 其实很简单,就是根据特征将一个大的数据集划分为几个小的数据集,然后分别计算小数据集的熵,然后取均值。

ID3算法的核心是在决策树各个结点上应用信息增益准则选择特征,递归地构建决策树。具体方法是:从根结点开始,对结点计算所有可能的特征的信息增益,选择信息增益最大的特征作为结点的特征,由该特征的不同取值建立子结点:再对子结点递归地调用以上方法,构建决策树,最后所得到的决策树如上图一所示。

得到决策树后就可以对新来的数据进行分类了,具体做法是,从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。

源码分析

对一个算法描述,难免由于不同人的不同理解产生误差,所以我个人认为了解一个算法的最好途径是通过代码——这种不会产生误差的方式。因为我已经在代码中加以了注释,就不再对各个部分进行讲解了,如果有问题欢迎讨论O(∩_∩)O~。

matlab代码

function readtextreadname='F:\机器学习实战\机器学习实战\源码+数据\Ch03\lenses.txt';text=textread(readname,'%s','delimiter','');%一行行地读取数据num=zeros(size(text,1),5);%开辟空间存储特征for i=1:size(text,1)text{i}=regexp(text{i}, '\w+','match');%采用正则表达式解析字符串if(size(text{i},2))==6    text{i}{5}=strcat(text{i}{5},text{i}{6});%因为最后一个特征有no lenses、soft、hard                                             %为了方便处理将第一种情况的两个字符串合在一起    text{i}(6)=[];end% 因为matlab没有类似python中的set操作,无法删除字符数组中重复出现的字符串,所以只好事先将字符转为数字switch text{i}{1}    case 'young'         num(i,1)=1;    case 'pre'           num(i,1)=2;    case 'presbyopic'         num(i,1)=3;endswitch text{i}{2}    case 'myope'         num(i,2)=1;    case 'hyper'           num(i,2)=2;endswitch text{i}{3}    case 'no'         num(i,3)=1;    case 'yes'           num(i,3)=2;endswitch text{i}{4}    case 'reduced'         num(i,4)=1;    case 'normal'           num(i,4)=2;endswitch text{i}{5}    case 'nolenses'         num(i,5)=1;    case 'hard'           num(i,5)=2;    case 'soft'           num(i,5)=3;endendlabel={'age','prescript','astigmatic','tearRate'};mytree=CreateTree(num,label);function  mytree=CreateTree(dataset,label)% function :构建决策树% Input    :dataset :数据集%           label:类别标签% Output   :输出记录决策树结点的结构体mytree.FatherNode='root';mytree.property='stem';databest=ChooseBestFeatureToSplit(dataset);mytree.NodeName=label{databest};Feature=unique(dataset(:,databest));FatherNode=label{databest};label(databest)=[];for i=1:length(Feature)    mytree=[CreateChildTree(splitDataSet(dataset,databest,Feature(i)),i,FatherNode,label);mytree];endfunction mytree=CreateChildTree(dataset,property,FatherNode,label)mytree.FatherNode=FatherNode;mytree.property=property;if length(unique(dataset(:,end)))==1    %如果剩下的所有样本皆为同一类则跳出    mytree.NodeName=unique(dataset(:,end));    return endif size(dataset,2)==1           %分割完所有的特征,则选取剩下样本中类别数最多的一类返回    mytree.NodeName=MajorClass(dataset);    returnenddatabest=ChooseBestFeatureToSplit(dataset);mytree.NodeName=label{databest};Feature=unique(dataset(:,databest));FatherNode=label{databest};label(databest)=[];for i=1:length(Feature)    mytree=[CreateChildTree(splitDataSet(dataset,databest,Feature(i)),i,FatherNode,label);mytree];endfunction class=MajorClass(classlist)% function : 选择类别中数量最多的类别返回% Input    : 给定的类别列表%% Output   :返回最后的类别 tabClasslist=tabulate(classlist);[~,index]=max(tabClasslist(:,3));class=tabClasslist(1,index);function BestFeature=ChooseBestFeatureToSplit(dataset)% funciton: 选择最好的特征来分割样本% Input   : 一个表示样本的m*N+1矩阵,前N列为特征,最后一列为类别% % Output  : 返回最好的特征索引BestFeature=0;BestEnt=1/eps;for i=1:size(dataset,2)-1    Feature=dataset(:,i);    FeatureClass=unique(Feature);    newEnt=0;    for j=1:length(FeatureClass)       subData=splitDataSet(dataset,i,FeatureClass(j));       prob=length(subData)/length(Feature);       newEnt=newEnt+prob*CalShannonEnt(subData(:,end));    end    if newEnt<BestEnt        BestEnt=newEnt;        BestFeature=i;    endendfunction  Ent=CalShannonEnt(datavector)% function: 计算香农熵% Input    datavector: 代表类别的向量% % output   Ent:返回计算的熵值% tabclass=tabulate(datavector);tabclass(:,3)=tabclass(:,3)/100;tabclass(tabclass(:,2)==0,:)=[];Ent=-sum(tabclass(:,3).*log2(tabclass(:,3)));function reData=splitDataSet(dataset,axis,value)% function: 根据给定的特征划分样本% Input   dataset:给定的样本集,为一个n*p矩阵%         axis   :给定的划分特征%         value  :需符合的条件% Output  reData :根据给定条件所得到的划分index=find(dataset(:,axis)==value);dataset(:,axis)=[];reData=dataset(index,:);

Python代码

'''Created on Oct 12, 2010Decision Tree Source Code for Machine Learning in Action Ch. 3@author: Peter Harrington'''from math import logimport operatordef createDataSet():    #创建新的数据集    dataSet = [[1, 1, 'yes'],               [1, 1, 'yes'],               [1, 0, 'no'],               [0, 1, 'no'],               [0, 1, 'no']]    labels = ['no surfacing','flippers']    #change to discrete values    return dataSet, labelsdef calcShannonEnt(dataSet):    #计算香农熵    numEntries = len(dataSet)    labelCounts = {}    for featVec in dataSet: #the the number of unique elements and their occurance        currentLabel = featVec[-1]        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    shannonEnt = 0.0    for key in labelCounts:        prob = float(labelCounts[key])/numEntries #统计每一个类出现的概率        shannonEnt -= prob * log(prob,2) #根据公式计算    return shannonEntdef splitDataSet(dataSet, axis, value):    #划分数据集    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:#把处于同一值下的所有样本取出            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting            reducedFeatVec.extend(featVec[axis+1:])            retDataSet.append(reducedFeatVec)    return retDataSetdef chooseBestFeatureToSplit(dataSet):    #根据信息增益准则,选择最好的特征进行分割    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels    baseEntropy = calcShannonEnt(dataSet)    bestInfoGain = 0.0; bestFeature = -1    #计算以不同特征的不同值进行分割所得到的信息增益,选最大的那个进行分割    for i in range(numFeatures):        #iterate over all the features        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature        uniqueVals = set(featList)       #get a set of unique values        newEntropy = 0.0        for value in uniqueVals:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet)/float(len(dataSet))            newEntropy += prob * calcShannonEnt(subDataSet)             infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy        if (infoGain > bestInfoGain):       #compare this to the best gain so far            bestInfoGain = infoGain         #if better than current best, set to best            bestFeature = i    return bestFeature                      #returns an integerdef majorityCnt(classList):    #如果特征筛选完毕,则选择所余样本中类别数最大的一个作为叶子节点的类别    classCount={}    for vote in classList:        if vote not in classCount.keys(): classCount[vote] = 0        classCount[vote] += 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createTree(dataSet,labels):    #构造字典树    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len(classList):         return classList[0]#stop splitting when all of the classes are equal    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)#得到最好的分类特征    bestFeatLabel = labels[bestFeat]    myTree = {bestFeatLabel:{}}    del(labels[bestFeat])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)    return myTree                            def classify(inputTree,featLabels,testVec):    #进行分类    firstStr = inputTree.keys()[0]    secondDict = inputTree[firstStr]    featIndex = featLabels.index(firstStr)    key = testVec[featIndex]    valueOfFeat = secondDict[key]    if isinstance(valueOfFeat, dict):         #反复迭代,直到抵达叶子节点        classLabel = classify(valueOfFeat, featLabels, testVec)    else: classLabel = valueOfFeat    return classLabeldef GetData():    #获取text文本中的数据    fr=open('lenses.txt')    lenses=[rows.strip().split('\t') for rows in fr]    lenseslabel=['age','prescript','astigmatic','tearRate']    return lenses,lenseslabel'''def storeTree(inputTree,filename):    import pickle    fw = open(filename,'w')    pickle.dump(inputTree,fw)    fw.close()def grabTree(filename):    import pickle    fr = open(filename)    return pickle.load(fr)'''myData,label=createDataSet()labeltemp=label[:]lenses,lenseslabel=GetData()tree=createTree(lenses,lenseslabel)print tree

包括数据集的完整链接,点这

1 0