Annotated Python Code for a Decision Tree


Source code download: http://download.csdn.net/detail/redhatforyou/9870168

1 Classes

1.1 The data class

The data class stores the data used to train the decision tree:
examples: all of the data records
attributes: the names of the data's features, taken from the header row
attr_types: flags indicating whether each attribute holds numeric values
classifier: the name of the attribute used as the class label
class_index: the column index of the class label

```python
class data():
    def __init__(self, classifier):
        self.examples = []
        self.attributes = []
        self.attr_types = []
        self.classifier = classifier
        self.class_index = None
```

1.2 The treeNode class

The treeNode class represents a node in the decision tree; its attributes are:
is_leaf: whether this node is a leaf (defaults to True)
classification: the class (0 or 1) assigned at a leaf
attr_split: the name of the attribute this node splits on
attr_split_index: the index of the split attribute
attr_split_value: the threshold value used for the split
parent: the parent node
upper_child: the subtree for examples whose split-attribute value is >= the split value
lower_child: the subtree for examples whose split-attribute value is < the split value
height: the depth of the node in the tree

```python
class treeNode():
    def __init__(self, is_leaf, classification, attr_split_index, attr_split_value,
                 parent, upper_child, lower_child, height):
        # note: the original constructor ignored every argument except parent and
        # always assigned fixed defaults; initializing each field from its
        # parameter is equivalent for the callers in this code and more idiomatic
        self.is_leaf = is_leaf
        self.classification = classification
        self.attr_split = None
        self.attr_split_index = attr_split_index
        self.attr_split_value = attr_split_value
        self.parent = parent
        self.upper_child = upper_child
        self.lower_child = lower_child
        self.height = height
```

2 Function Implementations

2.1 The read_data(dataset, datafile, datatypes) function

read_data() reads the data and stores it in dataset.
The datafile has the following format:
[Figure: sample rows of the training data file]
As the figure shows, the first line holds the names of the features, and every following line is one data record; a hypothetical sample is sketched below.
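A made-up illustration of the expected layout (the attribute names and values here are invented, not taken from the real btrain.csv):

```
x1,x2,x3,label
12.5,3,0.7,1
8.0,1,0.9,0
15.2,2,0.4,1
```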

```python
def read_data(dataset, datafile, datatypes):
    print "Reading data..."
    f = open(datafile)
    original_file = f.read()
    # split the data by lines
    rowsplit_data = original_file.splitlines()
    # split each row by ','
    dataset.examples = [rows.split(',') for rows in rowsplit_data]
    # the first row lists the attribute names
    dataset.attributes = dataset.examples.pop(0)
    # create array that indicates whether each attribute is a numerical value or not
    attr_type = open(datatypes)
    orig_file = attr_type.read()
    dataset.attr_types = orig_file.split(',')
```

As the figure and code show, this function does three things:
(1) Read datafile and build examples
Read the file; the fields of each record are separated by commas, and the split records are stored in examples.

```python
f = open(datafile)
original_file = f.read()
# split the data by lines
rowsplit_data = original_file.splitlines()
# split the data by ','
dataset.examples = [rows.split(',') for rows in rowsplit_data]
```

(2) Read datafile and extract attributes
The attribute names are taken from the first row of the file.

```python
# list attributes
dataset.attributes = dataset.examples.pop(0)
```

(3) Read the datatypes file
Read the data in the datatypes file, whose format is shown below; a guessed sample follows the figure placeholder.
[Figure: sample contents of the datatypes file]
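A sketch of its likely contents, assuming one comma-separated true/false flag per attribute marking whether that attribute is numeric (this layout is inferred from how attr_types is used, so the real file may differ):

```
true,true,true,false
```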

```python
# create array that indicates whether each attribute is a numerical value or not
attr_type = open(datatypes)
orig_file = attr_type.read()
dataset.attr_types = orig_file.split(',')
```

2.2 The preprocess2(dataset) function

(1) Get the class-label values and find the most frequent one
As the code below shows, dataset.class_index holds the column index of the class label, and class_mode ends up holding the most frequent class value among the examples.

```python
# get the class value of each example
class_values = [example[dataset.class_index] for example in dataset.examples]
class_mode = Counter(class_values)
# find the most common class value
class_mode = class_mode.most_common(1)[0][0]
```

(2) Use filter to collect, for each attribute index attr_index, the attribute values of the class-0 examples and of the class-1 examples.

```python
# get the attribute values at attr_index for examples of class 0
ex_0class = filter(lambda x: x[dataset.class_index] == '0', dataset.examples)
values_0class = [example[attr_index] for example in ex_0class]
# get the attribute values at attr_index for examples of class 1
ex_1class = filter(lambda x: x[dataset.class_index] == '1', dataset.examples)
values_1class = [example[attr_index] for example in ex_1class]
```

2.3 The compute_tree(dataset, parent_node, classifier) function

(1) Create a decision-tree node and initialize it
If the node has no parent, its height is set to 0; otherwise it is the parent's height plus 1.
Then the examples at the node are checked: if the node holds no examples, or all of its examples belong to one class, it is returned as a leaf node.

```python
node = treeNode(True, None, None, None, parent_node, None, None, 0)
# compute the node height
if (parent_node == None):
    node.height = 0
else:
    node.height = node.parent.height + 1
ones = one_count(dataset.examples, dataset.attributes, classifier)
if (len(dataset.examples) == ones):
    node.classification = 1
    node.is_leaf = True
    return node
elif (ones == 0):
    node.classification = 0
    node.is_leaf = True
    return node
else:
    node.is_leaf = False
```

(2) Initialize the split-search parameters and compute the dataset's entropy
attr_to_split: the index of the attribute to split on.
max_gain: the largest information gain found so far.
split_val: the value to split at.
min_gain: the threshold below which splitting stops.

```python
attr_to_split = None  # the index of the attribute we will split on
max_gain = 0          # the gain given by the best attribute
split_val = None
min_gain = 0.01
dataset_entropy = calc_dataset_entropy(dataset, classifier)
```

(3) For every attribute other than the class label, gather its distinct values as candidate split points, so that each candidate can be evaluated to find the best split.

```python
if (dataset.attributes[attr_index] != classifier):
    local_max_gain = 0
    local_split_val = None
    # these are the values we can split on; now we must find the best one
    attr_value_list = [example[attr_index] for example in dataset.examples]
    # remove duplicates from the list of attribute values
    attr_value_list = list(set(attr_value_list))
```

(4) When an attribute has more than 100 distinct values, the candidate list is thinned: the values are sorted and only the nine decile cut points (one value every tenth of the sorted list) are kept as candidates. A quick standalone check follows the code.

```python
if (len(attr_value_list) > 100):
    attr_value_list = sorted(attr_value_list)
    total = len(attr_value_list)
    ten_percentile = int(total / 10)
    new_list = []
    for x in range(1, 10):
        new_list.append(attr_value_list[x * ten_percentile])
    attr_value_list = new_list
```
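As a quick standalone check of this thinning step (the values here are invented), with 1000 distinct sorted values the loop keeps exactly the nine decile cut points:

```python
# 1000 distinct sorted values 0..999
attr_value_list = sorted(range(1000))
ten_percentile = int(len(attr_value_list) / 10)  # 100
new_list = [attr_value_list[x * ten_percentile] for x in range(1, 10)]
# new_list == [100, 200, 300, 400, 500, 600, 700, 800, 900]
```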

(5) Compute the information gain for each candidate value; whenever a value's gain beats the best seen so far for this attribute, record that gain and that split value.

```python
for val in attr_value_list:
    # calculate the gain if we split on this value;
    # if it is greater than local_max_gain, save this gain and this value
    local_gain = calc_gain(dataset, dataset_entropy, val, attr_index)
    if (local_gain > local_max_gain):
        local_max_gain = local_gain
        local_split_val = val
```

(6) Compare the attribute's best gain with the global best; if it is better, record the new split attribute and split value.

```python
if (local_max_gain > max_gain):
    max_gain = local_max_gain
    split_val = local_split_val
    attr_to_split = attr_index
```

(7) Check the resulting gain against the leaf conditions: if the best gain does not exceed min_gain, or the node's height exceeds 20, stop splitting, determine the leaf's class with classify_leaf, and return the node as a leaf.

```python
# attr_to_split is now the best attribute according to our gain metric
if (split_val is None or attr_to_split is None):
    print "Something went wrong. Couldn't find an attribute to split on or a split value."
elif (max_gain <= min_gain or node.height > 20):
    node.is_leaf = True
    node.classification = classify_leaf(dataset, classifier)
    return node
```

(8) Otherwise record the split attribute, index, and value on the node, partition the examples into an upper and a lower child dataset, and recursively call compute_tree to build each subtree.

```python
node.attr_split_index = attr_to_split
node.attr_split = dataset.attributes[attr_to_split]
node.attr_split_value = split_val
# currently doing one split per node, so only two datasets are created
upper_dataset = data(classifier)
lower_dataset = data(classifier)
upper_dataset.attributes = dataset.attributes
lower_dataset.attributes = dataset.attributes
upper_dataset.attr_types = dataset.attr_types
lower_dataset.attr_types = dataset.attr_types
for example in dataset.examples:
    if (attr_to_split is not None and example[attr_to_split] >= split_val):
        upper_dataset.examples.append(example)
    elif (attr_to_split is not None):
        lower_dataset.examples.append(example)
node.upper_child = compute_tree(upper_dataset, node, classifier)
node.lower_child = compute_tree(lower_dataset, node, classifier)
return node
```

2.4 The classify_leaf(dataset, classifier) function

Determines which class a leaf node belongs to: the majority class of its examples, with ties resolved to class 1.

```python
def classify_leaf(dataset, classifier):
    ones = one_count(dataset.examples, dataset.attributes, classifier)
    total = len(dataset.examples)
    zeroes = total - ones
    if (ones >= zeroes):
        return 1
    else:
        return 0
```

2.5 The calc_dataset_entropy(dataset, classifier) function

calc_dataset_entropy computes the entropy of the dataset; one_count counts the number of class-1 examples at the node.
The code assumes a binary classification problem. The entropy is computed with the formula

$$H(X) = E(I(X)) = \sum_{i=1}^{n} p(x_i)\,I(x_i) = -\sum_{i=1}^{n} p(x_i)\log_b p(x_i)$$

with $b = 2$ in the code.
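For example (the counts are invented for illustration), a node with 9 examples of class 1 and 5 of class 0 has entropy

$$H = -\frac{9}{14}\log_2\frac{9}{14} - \frac{5}{14}\log_2\frac{5}{14} \approx 0.940$$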

```python
def calc_dataset_entropy(dataset, classifier):
    ones = one_count(dataset.examples, dataset.attributes, classifier)
    total_examples = len(dataset.examples)
    entropy = 0
    # cast to float so the division is not truncated under Python 2
    p = float(ones) / total_examples
    if (p != 0):
        entropy += p * math.log(p, 2)
    p = float(total_examples - ones) / total_examples
    if (p != 0):
        entropy += p * math.log(p, 2)
    entropy = -entropy
    return entropy
```

2.6 The calc_gain(dataset, entropy, val, attr_index) function

This function computes the information gain of a split, defined as

$$g(D, A) = H(D) - H(D \mid A)$$

Its parameters have the following meanings:
(1) dataset: the examples at the current node.
(2) entropy: the entropy of the dataset before the split.
(3) val: the value to split at.
(4) attr_index: the index of the split attribute.
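As a worked example with invented numbers: if $H(D) = 0.940$ and a candidate split sends 8 of 14 examples into a subset with entropy 0.811 and the other 6 into a subset with entropy 1.0, the gain is

$$g(D,A) = 0.940 - \left(\frac{8}{14}\times 0.811 + \frac{6}{14}\times 1.0\right) \approx 0.048$$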

```python
def calc_gain(dataset, entropy, val, attr_index):
    classifier = dataset.attributes[attr_index]
    attr_entropy = 0
    total_examples = len(dataset.examples)
    gain_upper_dataset = data(classifier)
    gain_lower_dataset = data(classifier)
    gain_upper_dataset.attributes = dataset.attributes
    gain_lower_dataset.attributes = dataset.attributes
    gain_upper_dataset.attr_types = dataset.attr_types
    gain_lower_dataset.attr_types = dataset.attr_types
    for example in dataset.examples:
        if (example[attr_index] >= val):
            gain_upper_dataset.examples.append(example)
        elif (example[attr_index] < val):
            gain_lower_dataset.examples.append(example)
    if (len(gain_upper_dataset.examples) == 0 or len(gain_lower_dataset.examples) == 0):
        # splitting didn't actually split (we tried to split on the max or min of the attribute's range)
        return -1
    # weighted average of the children's entropies (float cast avoids integer division under Python 2)
    attr_entropy += calc_dataset_entropy(gain_upper_dataset, classifier) * len(
        gain_upper_dataset.examples) / float(total_examples)
    attr_entropy += calc_dataset_entropy(gain_lower_dataset, classifier) * len(
        gain_lower_dataset.examples) / float(total_examples)
    return entropy - attr_entropy
```

2.7 The one_count(instances, attributes, classifier) function

Counts and returns the number of instances that belong to class 1; attributes holds the feature names and classifier is the name of the label column.

```python
def one_count(instances, attributes, classifier):
    count = 0
    # find the index of the classifier; default to the last column
    class_index = len(attributes) - 1
    for a in range(len(attributes)):
        if attributes[a] == classifier:
            class_index = a
            break  # the original loop kept iterating and could overwrite a found index
    for i in instances:
        if i[class_index] == "1":
            count += 1
    return count
```

2.8 The prune_tree(root, node, dataset, best_score) function

Its parameters are:
(1) root: the root of the trained decision tree
(2) node: the node currently being considered
(3) dataset: the validation set
(4) best_score: the best accuracy achieved so far on the validation set.
prune_tree performs reduced-error pruning on the generated tree: for each leaf it tentatively turns the leaf's parent into a leaf, re-validates the tree, and keeps the pruned tree if its accuracy on the validation set does not drop; otherwise the original subtree is restored.

```python
def prune_tree(root, node, dataset, best_score):
    # if node is a leaf
    if (node.is_leaf == True):
        # get its classification
        classification = node.classification
        # re-validate the tree with the node's parent turned into a leaf
        # carrying this classification
        node.parent.is_leaf = True
        node.parent.classification = node.classification
        if (node.height < 20):
            new_score = validate_tree(root, dataset)
        else:
            new_score = 0
        # if the pruned tree scores at least as well, keep the pruning
        if (new_score >= best_score):
            return new_score
        else:
            node.parent.is_leaf = False
            node.parent.classification = None
            return best_score
    # if it's not a leaf
    else:
        # prune the upper subtree
        new_score = prune_tree(root, node.upper_child, dataset, best_score)
        # if this node has become a leaf, return
        if (node.is_leaf == True):
            return new_score
        # prune the lower subtree
        new_score = prune_tree(root, node.lower_child, dataset, new_score)
        return new_score
```

2.9 The validate_tree(node, dataset) function

Runs every example in the dataset through validate_example, sums the correct predictions, and returns the classification accuracy.

```python
def validate_tree(node, dataset):
    total = len(dataset.examples)
    correct = 0
    for example in dataset.examples:
        # validate example
        correct += validate_example(node, example)
    # float cast avoids integer division under Python 2
    return correct / float(total)
```

2.10 The validate_example(node, example) function

Classifies a single example by walking it down the decision tree, then compares the predicted class with the actual one: returns 1 if they match and 0 otherwise.

```python
def validate_example(node, example):
    if (node.is_leaf == True):
        projected = node.classification
        actual = int(example[-1])
        if (projected == actual):
            return 1
        else:
            return 0
    value = example[node.attr_split_index]
    if (value >= node.attr_split_value):
        return validate_example(node.upper_child, example)
    else:
        return validate_example(node.lower_child, example)
```

2.11 The test_example(example, node, class_index) function

Purpose: predict which class a single data record belongs to.
Parameters:
example: one data record.
node: the current node on the record's path down the tree.
class_index: the index of the class label.

```python
def test_example(example, node, class_index):
    if (node.is_leaf == True):
        return node.classification
    else:
        if (example[node.attr_split_index] >= node.attr_split_value):
            return test_example(example, node.upper_child, class_index)
        else:
            return test_example(example, node.lower_child, class_index)
```

2.12 The print_tree(node) function

Prints the structure of the decision tree, indenting each node by its height.

```python
def print_tree(node):
    if (node.is_leaf == True):
        for x in range(node.height):
            print "\t",
        print "Classification: " + str(node.classification)
        return
    for x in range(node.height):
        print "\t",
    print "Split index: " + str(node.attr_split)
    for x in range(node.height):
        print "\t",
    print "Split value: " + str(node.attr_split_value)
    print_tree(node.upper_child)
    print_tree(node.lower_child)
```

2.13 The print_disjunctive(node, dataset, dnf_string) function

Prints the decision tree in disjunctive normal form (DNF): each root-to-leaf path that ends in class 1 becomes one clause. A hypothetical sample of the output follows the code.

```python
def print_disjunctive(node, dataset, dnf_string):
    if (node.parent == None):
        dnf_string = "( "
    if (node.is_leaf == True):
        if (node.classification == 1):
            dnf_string = dnf_string[:-3]
            dnf_string += ") ^ "
            print dnf_string,
        else:
            return
    else:
        upper = dnf_string + str(dataset.attributes[node.attr_split_index]) + " >= " + str(
            node.attr_split_value) + " V "
        print_disjunctive(node.upper_child, dataset, upper)
        lower = dnf_string + str(dataset.attributes[node.attr_split_index]) + " < " + str(
            node.attr_split_value) + " V "
        print_disjunctive(node.lower_child, dataset, lower)
        return
```
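With invented attribute names x1 and x2, the output looks roughly like the fragment below: each parenthesized clause is one root-to-leaf path ending in class 1, the conditions along a path are joined by the literal "V", and clauses are joined by "^":

```
( x1 >= 5.0 V x2 < 3.0 ) ^ ( x1 < 5.0 V x2 >= 7.5 ) ^
```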

3 Running a Decision Tree Example

[Figure: console output of a sample run]
The figure above shows the result of a run. The code below is executed with the following arguments: ../data/btrain.csv -v ../data/bvalidate.csv -p -t ../data/btest.csv

```python
import sys
import ast
import csv
import copy
from com.DecisionTree.DecisionTree import *

##################################################
# main function: organize data and execute functions based on input
# need to account for missing data
##################################################
def main():
    # get the command-line arguments
    args = str(sys.argv)
    args = ast.literal_eval(args)
    print args
    # check the number of arguments
    if (len(args) < 2):
        print "You have input less than the minimum number of arguments. Go back and read README.txt and do it right next time!"
    # check the file type of the training data
    elif (args[1][-4:] != ".csv"):
        print "Your training file (second argument) must be a .csv!"
    else:
        datafile = args[1]
        # instantiate a dataset object
        dataset = data("")
        # determine the datatypes file
        if ("-d" in args):
            datatypes = args[args.index("-d") + 1]
        else:
            datatypes = '../data/datatypes.csv'
        # read data from datafile
        read_data(dataset, datafile, datatypes)
        arg3 = args[2]
        # choose the attribute to use as the classifier
        if (arg3 in dataset.attributes):
            classifier = arg3
        else:
            classifier = dataset.attributes[-1]
        # dataset.classifier = 'Winner'
        dataset.classifier = classifier
        # find the index of the classifier; by default it is the last attribute
        for a in range(len(dataset.attributes)):
            if dataset.attributes[a] == dataset.classifier:
                dataset.class_index = a
            else:
                dataset.class_index = range(len(dataset.attributes))[-1]
        unprocessed = copy.deepcopy(dataset)
        # preprocess the data
        preprocess2(dataset)
        print "Computing tree..."
        root = compute_tree(dataset, None, classifier)
        if ("-s" in args):
            print_disjunctive(root, dataset, "")
            print "\n"
        if ("-v" in args):
            datavalidate = args[args.index("-v") + 1]
            print "Validating tree..."
            validateset = data(classifier)
            read_data(validateset, datavalidate, datatypes)
            for a in range(len(dataset.attributes)):
                if validateset.attributes[a] == validateset.classifier:
                    validateset.class_index = a
                else:
                    validateset.class_index = range(len(validateset.attributes))[-1]
            preprocess2(validateset)
            best_score = validate_tree(root, validateset)
            all_ex_score = copy.deepcopy(best_score)
            print "Initial (pre-pruning) validation set score: " + str(100 * best_score) + "%"
        if ("-p" in args):
            if ("-v" not in args):
                print "Error: You must validate if you want to prune"
            else:
                post_prune_accuracy = 100 * prune_tree(root, root, validateset, best_score)
                print "Post-pruning score on validation set: " + str(post_prune_accuracy) + "%"
        if ("-t" in args):
            datatest = args[args.index("-t") + 1]
            testset = data(classifier)
            read_data(testset, datatest, datatypes)
            for a in range(len(dataset.attributes)):
                if testset.attributes[a] == testset.classifier:
                    testset.class_index = a
                else:
                    testset.class_index = range(len(testset.attributes))[-1]
            print "Testing model on " + str(datatest)
            # overwrite the test labels with placeholder values before prediction
            for example in testset.examples:
                example[testset.class_index] = '0'
            testset.examples[0][testset.class_index] = '1'
            testset.examples[1][testset.class_index] = '1'
            testset.examples[2][testset.class_index] = '?'
            preprocess2(testset)
            b = open('results.csv', 'w')
            a = csv.writer(b)
            # classify every test example with the trained tree
            for example in testset.examples:
                example[testset.class_index] = test_example(example, root, testset.class_index)
            saveset = testset
            saveset.examples = [saveset.attributes] + saveset.examples
            a.writerows(saveset.examples)
            b.close()
            print "Testing complete. Results outputted to results.csv"

if __name__ == '__main__':
    main()
```