决策树 学习

来源:互联网 发布:origin无法连接网络 编辑:程序博客网 时间:2024/05/21 10:23

特点

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型。

代码示例

# -*- coding: utf-8 -*-
# __author__ = 'wangbowen'
"""Decision-tree (ID3) learning demo ('Machine Learning in Action', ch. 3).

Builds a classification tree as nested dicts by repeatedly choosing the
feature split that maximizes information gain (Shannon-entropy reduction).
"""
from math import log
import operator


def calcShannonEnt(data_set):
    """Return the Shannon entropy (in bits) of *data_set*.

    Each example is a list whose LAST element is the class label.
    """
    num_entries = len(data_set)
    label_count = {}
    # Count occurrences of every class label.
    for val in data_set:
        curr_label = val[-1]
        label_count[curr_label] = label_count.get(curr_label, 0) + 1
    shannon_ent = 0.0
    for key in label_count:
        prob = float(label_count[key]) / num_entries
        shannon_ent -= prob * log(prob, 2)  # -p * log2(p)
    return shannon_ent


def createDataSet():
    """Return a tiny demo data set and the matching feature-name list."""
    data_set = [[1, 1, 'yes'],
                [1, 1, 'yes'],
                [1, 0, 'no'],
                [0, 1, 'no'],
                [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return data_set, labels


def splitDataSet(data_set, axis, value):
    """Return the examples whose feature *axis* equals *value*,
    with that feature column removed.

    A brand-new list is built: Python passes lists by reference, so
    mutating the input in place would corrupt the caller's data.
    """
    ret_data_set = []
    for val in data_set:
        # Extract only the rows matching the requested feature value.
        if val[axis] == value:
            reduced_val = val[:axis]
            reduced_val.extend(val[axis + 1:])
            ret_data_set.append(reduced_val)
    return ret_data_set


def chooseBestSplit(data_set):
    """Return the index of the feature giving the highest information
    gain, or -1 when no split improves on the unsplit entropy."""
    num = len(data_set[0]) - 1  # last column holds the class label
    base_ent = calcShannonEnt(data_set)  # entropy before any split
    best_gain = 0.0
    best_feature = -1
    for i in range(num):
        # Distinct values of feature i define the candidate partitions.
        uniq_vals = set(x[i] for x in data_set)
        new_ent = 0.0
        # Expected (weighted) entropy after splitting on feature i.
        for val in uniq_vals:
            sub_data_set = splitDataSet(data_set, i, val)
            prob = len(sub_data_set) / float(len(data_set))
            new_ent += prob * calcShannonEnt(sub_data_set)
        info_gain = base_ent - new_ent  # gain = entropy reduction
        if info_gain > best_gain:
            best_gain = info_gain
            best_feature = i
    return best_feature


def mayorityCnt(class_list):
    """Return the most frequent class label in *class_list*.

    Used as a fallback when all features are exhausted but the leaf is
    still impure. (Name kept as-is for caller compatibility.)
    """
    class_count = {}
    for vote in class_list:
        # BUGFIX: the original initialized a new key to 0 and only
        # incremented on LATER sightings, so every count was one short.
        class_count[vote] = class_count.get(vote, 0) + 1
    # Python 3 fix: dict.iteritems() no longer exists; use items().
    sorted_class_count = sorted(class_count.items(),
                                key=operator.itemgetter(1),
                                reverse=True)
    return sorted_class_count[0][0]


def createTree(data_set, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    NOTE: *labels* is mutated (the chosen feature name is deleted);
    pass a copy if the caller still needs the original list.
    """
    print('create tree', data_set, labels)
    class_list = [x[-1] for x in data_set]
    # Stop when every remaining example carries the same class.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Stop when all features are used up; fall back to majority vote.
    if len(data_set[0]) == 1:
        return mayorityCnt(class_list)
    best_feat = chooseBestSplit(data_set)
    best_label = labels[best_feat]
    my_tree = {best_label: {}}
    del labels[best_feat]
    uniq_values = set(x[best_feat] for x in data_set)
    for val in uniq_values:
        sub_labels = labels[:]  # copy so sibling branches share the same view
        my_tree[best_label][val] = createTree(
            splitDataSet(data_set, best_feat, val), sub_labels)
    return my_tree


if __name__ == '__main__':
    data_set, labels = createDataSet()
    tree = createTree(data_set, labels)
    print(tree)