数据可视化matplotlib(03) 绘制决策树

来源:互联网 发布:大学网络怎么办理 编辑:程序博客网 时间:2024/06/09 06:45

简介

决策树的主要优点是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势。本文将使用matplotlib来绘制树形图,并讲解具体的代码实现。

决策树图实例

为了便于理解,我们先来看看实际的决策树的图长个什么样子。下图所示的流程图就是一个决策树,正方形代表判断模块(decision block), 椭圆形代表终止模块(terminal block),表示已经得出结论,可以终止运行。

代码实现

这里我们给出了相关的代码实现,后面会对一些重要的实现进行详解。

# /usr/bin/python# -*- coding: UTF-8 -*-'''Created on 2017年11月16日@author: bob'''import matplotlib.pyplot as plt# pylint: disable=redefined-outer-name# 定义文本框和箭头格式decision_node = dict(boxstyle="sawtooth", fc="0.8")leaf_node = dict(boxstyle="round4", fc="0.8")arrow_args = dict(arrowstyle="<-")def retrieve_tree(i):    list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},                     {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}                    ]    return list_of_trees[i]def get_num_leafs(mytree):    '''    获取叶子节点数    '''    num_leafs = 0    first_str = mytree.keys()[0]    second_dict = mytree[first_str]        for key in second_dict.keys():        if type(second_dict[key]).__name__ == 'dict':            num_leafs += get_num_leafs(second_dict[key])        else:            num_leafs += 1                return num_leafsdef get_tree_depth(mytree):    '''    获取树的深度    '''    max_depth = 0    first_str = mytree.keys()[0]    second_dict = mytree[first_str]        for key in second_dict.keys():        # 如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用        # get_tree_depth()函数        if type(second_dict[key]).__name__ == 'dict':            this_depth = 1 + get_tree_depth(second_dict[key])        else:            this_depth = 1                    if this_depth > max_depth:            max_depth = this_depth                return max_depthdef plot_node(ax, node_txt, center_ptr, parent_ptr, node_type):    '''        绘制带箭头的注解    '''    ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction',                xytext=center_ptr, textcoords='axes fraction',                va="center", ha="center", bbox=node_type, arrowprops=arrow_args)def plot_mid_text(ax, center_ptr, parent_ptr, txt):    '''    在父子节点间填充文本信息    '''    x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0]    y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1]    ax.text(x_mid, y_mid, txt)def plot_tree(ax, mytree, parent_ptr, node_txt):    '''    绘制决策树    '''    # 计算宽度    num_leafs = get_num_leafs(mytree)        first_str = mytree.keys()[0]    center_ptr = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_width, plot_tree.y_off)        #绘制特征值,并计算父节点和子节点的中心位置,添加标签信息    plot_mid_text(ax, center_ptr, parent_ptr, node_txt)    plot_node(ax, first_str, center_ptr, parent_ptr, decision_node)        second_dict = mytree[first_str]    #采用的自顶向下的绘图,需要依次递减Y坐标    plot_tree.y_off -= 1.0 / plot_tree.total_depth        #遍历子节点,如果是叶子节点,则绘制叶子节点,否则,递归调用plot_tree()    for key in second_dict.keys():        if type(second_dict[key]).__name__ == "dict":            plot_tree(ax, second_dict[key], center_ptr, str(key))        else:            plot_tree.x_off += 1.0 / plot_tree.total_width            plot_mid_text(ax, (plot_tree.x_off, plot_tree.y_off), center_ptr, str(key))            plot_node(ax, second_dict[key], (plot_tree.x_off, plot_tree.y_off), center_ptr, leaf_node)        #在绘制完所有子节点之后,需要增加Y的偏移    plot_tree.y_off += 1.0 / plot_tree.total_depthdef create_plot(in_tree):    fig = plt.figure(1, facecolor="white")    fig.clf()        ax_props = dict(xticks=[], yticks=[])    ax = plt.subplot(111, frameon=False, **ax_props)    plot_tree.total_width = float(get_num_leafs(in_tree))    plot_tree.total_depth = float(get_tree_depth(in_tree))    plot_tree.x_off = -0.5 / plot_tree.total_width    plot_tree.y_off = 1.0    plot_tree(ax, in_tree, (0.5, 1.0), "")#     plot_node(ax, "a decision node", (0.5, 0.1), (0.1, 0.5), decision_node)#     plot_node(ax, "a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node)    plt.show()if __name__ == '__main__':#     create_plot()    mytree = retrieve_tree(1)    mytree['no surfacing'][3] = "maybe"    create_plot(mytree)

实现解析

1. 关于注解
matplotlib提供了一个注解工具annotation, 可一在数据图形上添加文本注解。注解通用用于解释数据的内容。工具内嵌支持带箭头的划线工具,可以在恰当的地方指向数据位置,并在此处添加描述信息,解释数据内容。
使用text()会将文本放置在轴域的任意位置。 文本的一个常见用例是标注绘图的某些特征。

def plot_mid_text(ax, center_ptr, parent_ptr, txt):    '''    在父子节点间填充文本信息    '''    x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0]    y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1]    ax.text(x_mid, y_mid, txt)
plot_mid_text()函数实现了在父子节点间绘制文本信息的功能,这个函数中,需要计算父子节点中心位置的坐标,并调用text()函数来进行绘制。
annotate()方法提供辅助函数,使标注变得容易。 在标注中,有两个要考虑的点:由参数xy表示的标注位置和xytext的文本位置。 这两个参数都是(x, y)元组。

# 定义文本框和箭头格式decision_node = dict(boxstyle="sawtooth", fc="0.8")leaf_node = dict(boxstyle="round4", fc="0.8")arrow_args = dict(arrowstyle="<-")def plot_node(ax, node_txt, center_ptr, parent_ptr, node_type):    '''    绘制带箭头的注解    '''    ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction',                xytext=center_ptr, textcoords='axes fraction',                va="center", ha="center", bbox=node_type, arrowprops=arrow_args)
在该示例中,xy(箭头尖端)和xytext位置(文本位置)都以数据坐标为单位。 有多种可以选择的其他坐标系 - 你可以使用xycoords和textcoords以及下列字符串之一(默认为data)指定xy和xytext的坐标系。

| 参数              | 坐标系                             | -----------------------------------------------------------| 'figure points'   | 距离图形左下角的点数量            | | 'figure pixels'   | 距离图形左下角的像素数量           | | 'figure fraction' | 0,0 是图形左下角,1,1 是右上角     | | 'axes points'     | 距离轴域左下角的点数量             | | 'axes pixels'     | 距离轴域左下角的像素数量           |   | 'axes fraction'   | 0,0 是轴域左下角,1,1 是右上角     | | 'data'            | 使用轴域数据坐标系                 |
你可以通过在可选关键字参数arrowprops中提供箭头属性字典来绘制从文本到注释点的箭头。+
|arrowprops键    |描述                                             |---------------------------------------------------------------------|width          |箭头宽度,以点为单位                              ||frac           |箭头头部所占据的比例                              ||headwidth      |箭头的底部的宽度,以点为单位                      ||shrink         |移动提示,并使其离注释点和文本一些距离            ||**kwargs       |matplotlib.patches.Polygon的任何键,例如facecolor |
bbox关键字参数,并且在提供时,在文本周围绘制一个框。根据决策节点还是叶子节点绘制不同的形状。
2. 构造注解树
在绘制注解树时,需要考虑如何放置所有的树节点。我们需要知道有多少个叶节点,以便正确的确定x轴的长度;海需要知道树有多少层,以便正确的确定y轴的高度。在这里,我们使用字典,来存储树节点的信息。

def retrieve_tree(i):    list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},                     {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}                    ]    return list_of_trees[i]def get_num_leafs(mytree):    '''    获取叶子节点数    '''    num_leafs = 0    first_str = mytree.keys()[0]    second_dict = mytree[first_str]        for key in second_dict.keys():        if type(second_dict[key]).__name__ == 'dict':            num_leafs += get_num_leafs(second_dict[key])        else:            num_leafs += 1                return num_leafs
我们来进一步了解如何在Python中存储树的信息。retrieve_tree()里给出了2个实例,参考这两个实例,来看一下get_num_leafs()的实现。从第一个节点出发,可以遍历整棵树的所有子节点。使用type()函数,可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用get_num_leafs()函数。get_num_leafs()遍历整棵树,累计叶子节点个数,并返回该数值。get_tree_depth()函数的实现机制与get_num_leafs()类似。使用retrieve_tree()可以获取树的实例,来测试这个函数的运行是否正确。
plot_tree()函数实现实际树的绘制,同get_num_leafs()函数一样,使用了递归的方式来进行树的各个节点的绘制工作。函数里使用了全局变量,plot_tree.total_width记录树的宽度,plot_tree.total_depth记录树的深度。使用了这两个全局变量来计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。另外两个全局变量plot_tree.x_off和plot_tree.y_off追踪已经绘制的节点位置,以及下一个节点的恰当位置。
create_plot()函数,创建绘图区,计算树形图的全局尺寸,并调用递归函数plot_tree()。

参考资料

1. http://matplotlib.org/api/pyplot_api.html

2. 机器学习实战

原创粉丝点击