机器学习值决策树算法(下)-ID3实现

来源:互联网 发布:芒果tv有mac版吗 编辑:程序博客网 时间:2024/06/11 07:09

  • 决策树展示
  • 决策树分类
  • 决策树存储
  • 总结
  • 参考

1.决策树展示

决策树展示
上一篇文章,我们介绍了决策树的基本概念,信息基本概念以及如何通过选择最优分类生成决策树,本篇文章首先介绍如何根据决策树通过matplotlib库来实现决策树图的展现。
matplotlib介绍
matplotlib是一个相对来说比较庞大的库,常用于机器学习中,用于直观的展示数据,例如本次的决策树展示,本章中介绍如何用该库创建树枝节点,以及如何展示文本信息。
from matplotlib import pyplot as pltclass NodePlot(object):    def __init__(self):        self.ax = None        self.decision_node = {'boxstyle': 'sawtooth', 'fc': '0.8'} # 决策节点类型        self.leaf_node = {'boxstyle': 'round4', 'fc': '0.8'} # 叶子节点类型        self.arrow_args = {'arrowstyle': '<-'} # 箭头类型        self.fig = None    def plot_node(self, node_text, center_pt, parent_pt, node_type):        self.ax.annotate(node_text, xy=parent_pt, xycoords='axes fraction',                         xytext=center_pt, textcoords='axes fraction',                         va='center', ha='center', bbox=node_type, arrowprops=self.arrow_args)class TreePlot(NodePlot):    def create_tree(self, tree):        self.fig = plt.figure(1, facecolor='white') # 生成画板        self.fig.clf()         self.ax = plt.subplot(111, frameon=False)         self.max_width = float(self.get_leaf_nums(tree)) # 计算树的宽度        self.max_depth = float(self.get_tree_depth(tree)) # 计算树的高度        self.x_off = 0.0        self.y_off = 1.0        self.draw_tree(tree, None, '') # 给定初始点,开始绘画决策树        plt.show()    def get_tree_depth(self, tree):        max_depth = 0        first_key = tree.keys()[0]        second_dict = tree[first_key]        for item in second_dict: # 遍历决策树            if isinstance(second_dict[item], dict):                depth = 1 + self.get_tree_depth(second_dict[item])            else:                depth = 1            max_depth = max(max_depth, depth) # 取各个分支节点下的最大高度        return max_depth    def get_leaf_nums(self, tree):        max_leafs = 0        if isinstance(tree, dict):            for item in tree:                max_leafs += self.get_leaf_nums(tree[item])        else:            max_leafs += 1 # 每个叶子节点,宽度加1        return max_leafs    def plot_middle(self, cntr_pt, parent_pt, text):        text_x = (parent_pt[0] - cntr_pt[0]) / 2 + cntr_pt[0]        text_y = (parent_pt[1] - cntr_pt[1]) / 2 + cntr_pt[1]        self.ax.text(text_x, text_y, text)    def draw_tree(self, tree, parent_node, node_text):        width = float(self.get_leaf_nums(tree))        first_key = tree.keys()[0]        cntr_pt = (self.x_off + (1 + width) / 2 / self.max_width, self.y_off)        if parent_node is None:            self.plot_node(first_key, cntr_pt, cntr_pt, self.decision_node)        else:            self.plot_middle(cntr_pt, parent_node, node_text)            self.plot_node(first_key, cntr_pt, parent_node, self.decision_node)        self.y_off -= 1 / self.max_depth        second_item = tree[first_key]        for item in second_item.keys():            if isinstance(second_item[item], dict):                self.draw_tree(second_item[item], cntr_pt, str(item))            else:                self.x_off += 1 / self.max_width                self.plot_node(second_item[item], (self.x_off, self.y_off), cntr_pt, self.leaf_node)                self.plot_middle((self.x_off, self.y_off), cntr_pt, item)        self.y_off += 1 / self.max_depth>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]>>> label = ['x', 'y', 'z']

生成结果如下:
图1-决策树生成图

2.决策树分类

决策树分类
根据决策树分类的一般流程如下:
Created with Raphaël 2.1.0创建数据集根据数据集创建决策树将数据带入决策树数据分类

: 分类代码如下:

def classify(input_tree, labels, test_vec):    first_str = input_tree.keys()[0]    second_dict = input_tree[first_str]    label_index = labels.index(first_str)    class_label = ''    for key in second_dict.keys():        if key == test_vec[label_index]:            if isinstance(second_dict[key], dict):                class_label = classify(second_dict[key], labels, test_vec)            else:                class_label = second_dict[key]    return class_label>>> data = [[1, 2, 'yes'], [2, 3, 'unknown'], [4, 15, 'no'], [12, 2, 'no'], [2, 3, 'yes'], [2, 4, 'unknown']]>>> label = ['x', 'y', 'z']

3.决策树存储

决策树存储
当数据集很大时,每次生成决策树耗费的时间过长,而且训练数据集不常变动时,可以通过pickle模块存储决策树。:
存储代码如下:
import pickledef save_tree(tree):    try:        pickle.dump(tree, 'C:\\tree')        return True    except Exception, e:        return Nonedef get_tree(f):    try:        return pickle.load(f)    except Exception, e:        return None

4.总结

  • 本文介绍了如何通过matplotlib展示决策树,如何使用决策树进行数据分类,如何存储决策树。
  • 下一篇文章会介绍贝叶斯算法。
  • 看过很多大牛的博客,都很钟爱决策树算法,但决策树算法也有坑,例如数据过于杂乱时,会导致树一直生成,树过于冗余,后面介绍如何截断决策树。

5.参考

  • [机器学习实战]
0 0