机器学习基础--决策树

来源:互联网 发布:现实中的美人知乎 编辑:程序博客网 时间:2024/05/16 17:00

决策树是很基础很经典的一个分类方法,基本上很多工业中很使用且常用的算法基础都是决策树,比如boost,GBDT,CART(分类回归树),我们后需会慢慢分析,决策时定义如下:
决策树:决策树模型是一种描述对实例进行分类的树形结构,其算法思想是分治法,由节点(node)和有向边组成,节点分两种类型,内部节点和叶子节点,内部节点标示一个特征或者树形,叶子节点表示一个类。
比如下面就是一个根据西瓜一些特征(来自周志华-机器学习)来分类好瓜坏瓜的决策树:
这里写图片描述

决策树难点在于划分选择,我们总是希望决策树的分支节点所包含的样本尽可能的属于同一类别,既节点的纯度高!
定义纯度最常用的是信息熵,信息熵越大标示混乱度越大,而越小则标示纯度越高
这里写图片描述
pk表示当前样本集合D中的第k类样本所占的比例,Ent(D)表示信息熵
同时我们基于信息熵定义信息增益
这里写图片描述
Dv表示属性a上取值为av的属性的样本数目,比如上面西瓜里面属性纹理里面纹理清晰的西瓜数目,我们可以看个例子
有这样一个西瓜数据集:
这里写图片描述

这个数据集中正例p1=8/17,反例占p2=9/17,我们计算根节点的信息熵
这里写图片描述

然后我们就可以计算各个属性的信息增益,例如我们计算“色泽”,包括(D1青绿,D2乌黑,D3浅白),青绿编号为{1,4,6,10,13,17},其中正例3/6,反例3/6计算如下
这里写图片描述

则“色泽”的信息增益:
这里写图片描述

我们同时可以计算出其他属性的信息增益,这里计算完之后发现“纹理”的增益为0.381,增益最大,所以我们选择先按纹理划分,最终的划分结果解释刚开始给出的样例决策树。

根据以上介绍我们可以知道决策树的优缺点:
优点

计算简单,易于理解,可解释性强;比较适合处理有缺失属性的样本;能够处理不相关的特征;在相对短的时间内能够对大型数据源做出可行且效果良好的结果。

缺点

容易发生过拟合(随机森林可以很大程度上减少过拟合);忽略了数据之间的相关性;对于那些各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征(只要是使用了信息增益,都有这个缺点,如RF,可用信息增益比解决)。

可以看出它很容易过拟合,原因是他会依照数据去不断划分子树,只要可划分就会一直划分下去,所以我们会想到用剪枝去避免过拟合。
剪枝分为预剪枝和后剪枝。
剪枝我们还需要一个验证数据集,通过训练数据集简历决策树,然后用验证数据集去验证,如果当前分支我们直接剪掉换成叶子节点,会不会提升验证集合的精度,例如上面决策树通过触感硬滑和软粘分好瓜和坏瓜,如果验证集合到这个子树下面不管硬滑还是软粘都是好瓜,那么我们可以直接剪掉换成好瓜叶子节点,其精度提升了50%
预剪枝就是从树建树开始,每一个节点评估剪掉是否提升精度,如果是则剪枝。
后剪枝则是建树完成之后,通过递归先看最下层叶子节点剪掉是否提升,依次向上递归。根本就是区别是遍历树的方式,我们可以猜出来相对来说后剪枝更好,因为预剪枝剪枝太早,有可能当前剪掉有提升,但是后续的划分会提高精度。后剪枝会比预剪枝保留更多分支,欠拟合风险小,泛化性能优于预剪枝,由于后剪枝是建树完成之后自底向上进行,所以训练时间更大

另一个问题就是对连续值的处理,因为实际问题中大多数值是连续值,对此最常用的方法是二分法,首先对当前属性连续值从小到大排序,然后相邻两个划分成一组(a-, a+),包括区间 [ ai, a(i+1) ),前闭后开区间。我们就可以得到n-1个划分集合:
这里写图片描述

用区间的中位点作为划分点,然后我们就可以求二分后的信息增益,同样选择信息增益最大的作为划分点:
这里写图片描述

下面是一个决策树代码,普通的ID3算法,连续的算法是C4.5:

from math import logimport operatordef createDataSet():    dataSet = [[1, 1, 'yes'],               [1, 1, 'yes'],               [1, 0, 'no'],               [0, 1, 'no'],               [0, 1, 'no']]    labels = ['no surfacing','flippers']    #change to discrete values    return dataSet, labelsdef calcShannonEnt(dataSet):    numEntries = len(dataSet)    labelCounts = {}    for featVec in dataSet: #the the number of unique elements and their occurance        currentLabel = featVec[-1]        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    shannonEnt = 0.0    for key in labelCounts:        prob = float(labelCounts[key])/numEntries        shannonEnt -= prob * log(prob,2) #log base 2    return shannonEntdef splitDataSet(dataSet, axis, value):    retDataSet = []    for featVec in dataSet:        if featVec[axis] == value:            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting            reducedFeatVec.extend(featVec[axis+1:])            retDataSet.append(reducedFeatVec)    return retDataSetdef chooseBestFeatureToSplit(dataSet):    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels    baseEntropy = calcShannonEnt(dataSet)    bestInfoGain = 0.0; bestFeature = -1    for i in range(numFeatures):        #iterate over all the features        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature        uniqueVals = set(featList)       #get a set of unique values        newEntropy = 0.0        for value in uniqueVals:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet)/float(len(dataSet))            newEntropy += prob * calcShannonEnt(subDataSet)             infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy        if (infoGain > bestInfoGain):       #compare this to the best gain so far            bestInfoGain = infoGain         #if better than current best, set to best            bestFeature = i    return bestFeature                      #returns an integerdef majorityCnt(classList):    classCount={}    for vote in classList:        if vote not in classCount.keys(): classCount[vote] = 0        classCount[vote] += 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createTree(dataSet,labels):    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len(classList):         return classList[0]#stop splitting when all of the classes are equal    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet        return majorityCnt(classList)    bestFeat = chooseBestFeatureToSplit(dataSet)    bestFeatLabel = labels[bestFeat]    myTree = {bestFeatLabel:{}}    del(labels[bestFeat])    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for value in uniqueVals:        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)    return myTree                            def classify(inputTree,featLabels,testVec):    firstStr = inputTree.keys()[0]    secondDict = inputTree[firstStr]    featIndex = featLabels.index(firstStr)    key = testVec[featIndex]    valueOfFeat = secondDict[key]    if isinstance(valueOfFeat, dict):         classLabel = classify(valueOfFeat, featLabels, testVec)    else: classLabel = valueOfFeat    return classLabeldef storeTree(inputTree,filename):    import pickle    fw = open(filename,'w')    pickle.dump(inputTree,fw)    fw.close()def grabTree(filename):    import pickle    fr = open(filename)    return pickle.load(fr)if __name__ == '__main__':    myDat, labels = createDataSet()    print calcShannonEnt(myDat) #计算信息熵    #myDat[0][-1]='maybe'    #print calcShannonEnt(myDat)    print myDat    print splitDataSet(myDat, 0, 1)     print chooseBestFeatureToSplit(myDat) #选择最优增益最大划分点    fr = open('lenses.txt')    lenses=[inst.strip().split('\t') for inst in fr.readlines()]    print "lenses", lenses    lenseLabels=['age', 'prescript', 'astigmatic', 'tearRate']    lensesTree = createTree(lenses, lenseLabels)    print lensesTree

隐形眼镜数据集合(lenses.txt):

young   myope   no  reduced no lensesyoung   myope   no  normal  softyoung   myope   yes reduced no lensesyoung   myope   yes normal  hardyoung   hyper   no  reduced no lensesyoung   hyper   no  normal  softyoung   hyper   yes reduced no lensesyoung   hyper   yes normal  hardpre myope   no  reduced no lensespre myope   no  normal  softpre myope   yes reduced no lensespre myope   yes normal  hardpre hyper   no  reduced no lensespre hyper   no  normal  softpre hyper   yes reduced no lensespre hyper   yes normal  no lensespresbyopic  myope   no  reduced no lensespresbyopic  myope   no  normal  no lensespresbyopic  myope   yes reduced no lensespresbyopic  myope   yes normal  hardpresbyopic  hyper   no  reduced no lensespresbyopic  hyper   no  normal  softpresbyopic  hyper   yes reduced no lensespresbyopic  hyper   yes normal  no lenses

通过数据建立如下决策树:
这里写图片描述

原创粉丝点击