决策树

来源:互联网 发布:淘宝店铺页头怎么换 编辑:程序博客网 时间:2024/06/17 01:59
from collections import Counter
from math import log

# Human-readable names of the feature columns, indexed by column position.
# The last column of every sample row is the class label, not a feature.
FEATURE = ["不浮出水面是否可以生存", "是否有脚蹼"]


def calculate_entropy(dataset):
    """Return the Shannon entropy (in bits) of the class labels in *dataset*.

    The last element of each row is treated as its class label.
    An empty dataset has entropy 0.
    """
    set_num = len(dataset)
    label_counts = Counter(item[-1] for item in dataset)
    entropy = 0
    for num in label_counts.values():
        prob = num / set_num
        entropy -= prob * log(prob, 2)
    return entropy


def split_dataset(dataset, index, value):
    """Return the rows of *dataset* whose column *index* equals *value*.

    Rows are returned unchanged (the split column is kept), so column
    positions -- and therefore indices into the feature-name list -- stay
    valid inside the subset.
    """
    return [item for item in dataset if item[index] == value]


def _ordered_values(dataset, index):
    """Return the distinct values of column *index* in first-appearance order."""
    # dict.fromkeys keeps insertion order, unlike a set of strings whose
    # iteration order changes between runs (hash randomization). This keeps
    # the entropy sums and the printed tree reproducible.
    return list(dict.fromkeys(item[index] for item in dataset))


def determine_feature(dataset):
    """Return the index of the feature column with the highest information gain.

    Returns None when no feature yields a strictly positive gain (for example
    when there are no feature columns, or every feature is constant). On ties
    the lowest index wins. Raises IndexError if *dataset* is empty.
    """
    feature_num = len(dataset[0]) - 1
    base_entropy = calculate_entropy(dataset)
    total = len(dataset)
    best_info_gain = 0
    best_feature_index = None
    for f in range(feature_num):
        # Weighted average entropy of the subsets produced by splitting on f.
        cur_entropy = 0
        for fa in _ordered_values(dataset, f):
            fa_dataset = split_dataset(dataset, f, fa)
            cur_entropy += len(fa_dataset) / total * calculate_entropy(fa_dataset)
        cur_info_gain = base_entropy - cur_entropy
        if cur_info_gain > best_info_gain:
            best_info_gain = cur_info_gain
            best_feature_index = f
    return best_feature_index


def get_dataset_labels(dataset):
    """Return the set of distinct class labels (last column) in *dataset*."""
    return {item[-1] for item in dataset}


def get_dataset_feature_attr(dataset, f):
    """Return the set of distinct values taken by feature column *f*."""
    return {item[f] for item in dataset}


def del_dataset_feature(dataset, f):
    """Return a copy of *dataset* with column *f* removed from every row."""
    # Filter by column *position*: iterating over the cell values and using
    # them as indices raises TypeError for non-integer data.
    return [[value for i, value in enumerate(item) if i != f] for item in dataset]


def tree(dataset, features=None):
    """Build an ID3 decision tree from *dataset*.

    Args:
        dataset: list of rows; every element but the last is a feature value,
            the last element is the class label.
        features: names of the feature columns, by position. Defaults to the
            module-level ``FEATURE``.

    Returns:
        The class label when all rows share one label; ``"maybe"`` when the
        labels differ but no feature separates them; otherwise a nested dict
        ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: if *dataset* is empty.
    """
    if features is None:
        features = FEATURE
    dataset_labels = get_dataset_labels(dataset)
    if len(dataset_labels) == 1:
        return next(iter(dataset_labels))
    f = determine_feature(dataset)
    # No feature columns, or none with positive information gain: the rows
    # cannot be separated further, so emit the "maybe" leaf rather than
    # indexing the feature names with None.
    if f is None:
        return "maybe"
    feature = features[f]
    cur_tree = {feature: {}}
    for fa in _ordered_values(dataset, f):
        # The split column is kept in the subset; it is constant there, so it
        # has zero gain and is never chosen again.
        cur_tree[feature][fa] = tree(split_dataset(dataset, f, fa), features)
    return cur_tree


random_dataset = [["是", "是", "是"], ["是", "是", "是"], ["是", "否", "否"], ["否", "是", "否"], ["否", "是", "否"]]
decision_tree = tree(random_dataset)
print(decision_tree)

（此处原有一张图片，未提供图片描述）

原创粉丝点击