机器学习实战---决策树

来源:互联网 发布:excel数据清洗 编辑:程序博客网 时间:2024/05/17 03:34

一、决策树的原理

参考:http://blog.csdn.net/suipingsp/article/details/41927247

1.1基本概念
信息熵:一个数据集信息量的度量
信息增益:在划分数据集之前和之后,信息熵的变化
我们假定每一个特征对于分类的重要性是相同的。
1.2决策树原理
a、决策树学习是以实例为基础的归纳学习。
b、决策树是一种树型结构,其中每个内部结点表示在一个属性上的测试,每个分支代表一个测试输出,每个叶结点代表一种类别。
c、决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一棵熵值下降最快的树,到叶子节点处的熵值为零(或者最低,因为有的时候特征都用完了,但是还存在不同的分类),此时每个叶节点中的实例都属于同一类。
d、我们训练的决策树参数就是各个特征的信息增益,然后根据信息增益的大小按照贪心策略对数据集划分。在测试决策树的时候,就是看测试集的数据属于哪个划分。
注:
决策树的关键是在当前状态下,选择哪个属性作为分类的依据。根据目标函数不同,有以下三种算法:

ID3算法:用信息增益最大的特征来划分数据集。
C4.5:信息增益率 g r (D,A) = g(D,A) / H(A)
CART:基尼指数

二、算法优缺点

(刚开始学,还是不理解啊!!后面加深理解了一定要补充一下!!!)
优点:
1、计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据
2、在学习的过程中,不需要使用者了解过多背景知识,只需要对训练实例进行较好的标注,就能够进行学习。从一类无序、无规则的事物(概念)中推理出决策树表示的分类规则。
缺点:决策树对训练属于有很好的分类能力,但对未知的测试数据未必有好的分类能力,泛化能力弱,即可能发生过拟合现象。(解决办法:剪枝,去掉一些信息增益少的属性;随机森林)
使用数据类型:数值型和标称型

三、Python实现

《机器学习实战》这一章用matplotlib画的图太蛋疼了,就不画了。
ID3算法:

# -*- coding:utf-8 -*-from math import logimport operator# 构建数据集def createDataSet():    dataSet = [        [1, 1, 'yes'],        [1, 1, 'yes'],        [1, 0, 'no'],        [0, 1, 'no'],        [0, 1, 'no']    ]    labels = ['no surfacing', 'flippers']    return dataSet, labels# 计算信息熵def calcShannonEnt(dataSet):    numEntries = len(dataSet)    labelCounts = {}    for featVec in dataSet:        currentLabel = featVec[-1]        if currentLabel not in labelCounts.keys():            labelCounts[currentLabel] = 0        labelCounts[currentLabel] += 1    shannonEnt = 0.0    for key in labelCounts:        prob = float(labelCounts[key]) / numEntries        shannonEnt -= prob * log(prob, 2)    return shannonEnt# 划分数据集:dataSet--数据集 axis--根据数据集的哪个轴(特征)来划分 value--根据这个特征的哪个值来划分def splitDataSet(dataSet, axis, value):    retDataSet = []  # 保存切割后的数据集    for featVec in dataSet:        if featVec[axis] == value:            reduceFeatVec = featVec[:axis]  # 获取这个轴左侧的数据            reduceFeatVec.extend(featVec[axis + 1:])  # 获取这个轴右侧的数据(包含最右侧的标签)            retDataSet.append(reduceFeatVec)    return retDataSet# 选择信息增益最大的特征,返回它的索引值def chooseBestFeature2Split(dataSet):    # 获得特征的数量(列数-1)    numFeatures = len(dataSet[0]) - 1    # 数据集的信息熵    baseEntropy = calcShannonEnt(dataSet)    # 最大信息增益    bestInfoGain = 0.0    # 产生最大信息增益对应的特征    bestFeature = -1    for i in range(numFeatures):        # 生成这一特征值(列)的list        featLIst = [example[i] for example in dataSet]        # 装换成set,去重        valueList = set(featLIst)        # 子集的信息熵        newEntropy = 0.0        for value in valueList:            subDataSet = splitDataSet(dataSet, i, value)            prob = len(subDataSet) / float(len(dataSet))            newEntropy = prob * calcShannonEnt(subDataSet)        # 信息增益        infoGain = baseEntropy - newEntropy        if infoGain > bestInfoGain:            bestInfoGain = infoGain            bestFeature = i    return bestFeature# 如果最后属性都遍历了一遍,但还是有类标签不唯一的情况,那么就进行投票,取多数的标签为叶子节点的标签# classList--类标签列表def majorityCnt(classList):    classCount = {}    for vote in classList:        if vote not in classCount.keys():            classCount[vote] = 0        classCount[vote] += 1    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)    return sortedClassCount[0][0]def createTree(dataSet, labels):    classList = [example[-1] for example in dataSet]    if classList.count(classList[0]) == len(classList):        return classList[0]    if len(dataSet[0]) == 1:        return majorityCnt(classList)    bestFeat = chooseBestFeature2Split(dataSet)    bestFeatLabel = labels[bestFeat]    del (labels[bestFeat])    # 用字典构建决策树    myTree = {bestFeatLabel: {}}    # 获取bestFeat的那一列的值    featValues = [example[bestFeat] for example in dataSet]    uniqueVals = set(featValues)    for val in uniqueVals:        sublabels = labels[:]        myTree[bestFeatLabel][val] = createTree(splitDataSet(dataSet, bestFeat, val), sublabels)    return myTreeif __name__ == '__main__':    dataSet, labels = createDataSet()    # shannonEnt = calcShannonEnt(dataSet)    # print '信息熵:%f' % shannonEnt    # retDataSet = splitDataSet(dataSet, 0, 1)    # for i in retDataSet:    #     print i    # print chooseBestFeature2Split(dataSet)    myTree = createTree(dataSet, labels)    print myTree
原创粉丝点击