数据可视化matplotlib(03) 绘制决策树
2017-11-17 09:33
537 查看
简介
决策树的主要优点是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势。本文将使用matplotlib来绘制树形图,并讲解具体的代码实现。决策树图实例
为了便于理解,我们先来看看实际的决策树的图长个什么样子。下图所示的流程图就是一个决策树,正方形代表判断模块(decision block), 椭圆形代表终止模块(terminal block),表示已经得出结论,可以终止运行。代码实现
这里我们给出了相关的代码实现,后面会对一些重要的实现进行详解。# /usr/bin/python # -*- coding: UTF-8 -*- ''' Created on 2017年11月16日 @author: bob ''' import matplotlib.pyplot as plt # pylint: disable=redefined-outer-name # 定义文本框和箭头格式 decision_node = dict(boxstyle="sawtooth", fc="0.8") leaf_node = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") def retrieve_tree(i): list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}} ] return list_of_trees[i] def get_num_leafs(mytree): ''' 获取叶子节点数 ''' num_leafs = 0 first_str = mytree.keys()[0] second_dict = mytree[first_str] for key in second_dict.keys(): if type(second_dict[key]).__name__ == 'dict': num_leafs += get_num_leafs(second_dict[key]) else: num_leafs += 1 return num_leafs def get_tree_depth(mytree): ''' 获取树的深度 ''' max_depth = 0 d8e0 first_str = mytree.keys()[0] second_dict = mytree[first_str] for key in second_dict.keys(): # 如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用 # get_tree_depth()函数 if type(second_dict[key]).__name__ == 'dict': this_depth = 1 + get_tree_depth(second_dict[key]) else: this_depth = 1 if this_depth > max_depth: max_depth = this_depth return max_depth def plot_node(ax, node_txt, center_ptr, parent_ptr, node_type): ''' 绘制带箭头的注解 ''' ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction', xytext=center_ptr, textcoords='axes fraction', va="center", ha="center", bbox=node_type, arrowprops=arrow_args) def plot_mid_text(ax, center_ptr, parent_ptr, txt): ''' 在父子节点间填充文本信息 ''' x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0] y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1] ax.text(x_mid, y_mid, txt) def plot_tree(ax, mytree, parent_ptr, node_txt): ''' 绘制决策树 ''' # 计算宽度 num_leafs = get_num_leafs(mytree) first_str = mytree.keys()[0] center_ptr = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_width, plot_tree.y_off) #绘制特征值,并计算父节点和子节点的中心位置,添加标签信息 plot_mid_text(ax, center_ptr, parent_ptr, node_txt) plot_node(ax, first_str, center_ptr, parent_ptr, decision_node) second_dict = mytree[first_str] #采用的自顶向下的绘图,需要依次递减Y坐标 plot_tree.y_off -= 1.0 / plot_tree.total_depth #遍历子节点,如果是叶子节点,则绘制叶子节点,否则,递归调用plot_tree() for key in second_dict.keys(): if type(second_dict[key]).__name__ == "dict": plot_tree(ax, second_dict[key], center_ptr, str(key)) else: plot_tree.x_off += 1.0 / plot_tree.total_width plot_mid_text(ax, (plot_tree.x_off, plot_tree.y_off), center_ptr, str(key)) plot_node(ax, second_dict[key], (plot_tree.x_off, plot_tree.y_off), center_ptr, leaf_node) #在绘制完所有子节点之后,需要增加Y的偏移 plot_tree.y_off += 1.0 / plot_tree.total_depth def create_plot(in_tree): fig = plt.figure(1, facecolor="white") fig.clf() ax_props = dict(xticks=[], yticks=[]) ax = plt.subplot(111, frameon=False, **ax_props) plot_tree.total_width = float(get_num_leafs(in_tree)) plot_tree.total_depth = float(get_tree_depth(in_tree)) plot_tree.x_off = -0.5 / plot_tree.total_width plot_tree.y_off = 1.0 plot_tree(ax, in_tree, (0.5, 1.0), "") # plot_node(ax, "a decision node", (0.5, 0.1), (0.1, 0.5), decision_node) # plot_node(ax, "a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node) plt.show() if __name__ == '__main__': # create_plot() mytree = retrieve_tree(1) mytree['no surfacing'][3] = "maybe" create_plot(mytree)
实现解析
1. 关于注解matplotlib提供了一个注解工具annotation, 可一在数据图形上添加文本注解。注解通用用于解释数据的内容。工具内嵌支持带箭头的划线工具,可以在恰当的地方指向数据位置,并在此处添加描述信息,解释数据内容。
使用text()会将文本放置在轴域的任意位置。 文本的一个常见用例是标注绘图的某些特征。
def plot_mid_text(ax, center_ptr, parent_ptr, txt): ''' 在父子节点间填充文本信息 ''' x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0] y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1] ax.text(x_mid, y_mid, txt)plot_mid_text()函数实现了在父子节点间绘制文本信息的功能,这个函数中,需要计算父子节点中心位置的坐标,并调用text()函数来进行绘制。
annotate()方法提供辅助函数,使标注变得容易。 在标注中,有两个要考虑的点:由参数xy表示的标注位置和xytext的文本位置。 这两个参数都是(x, y)元组。
# 定义文本框和箭头格式 decision_node = dict(boxstyle="sawtooth", fc="0.8") leaf_node = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") def plot_node(ax, node_txt, center_ptr, parent_ptr, node_type): ''' 绘制带箭头的注解 ''' ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction', xytext=center_ptr, textcoords='axes fraction', va="center", ha="center", bbox=node_type, arrowprops=arrow_args)在该示例中,xy(箭头尖端)和xytext位置(文本位置)都以数据坐标为单位。 有多种可以选择的其他坐标系 - 你可以使用xycoords和textcoords以及下列字符串之一(默认为data)指定xy和xytext的坐标系。
| 参数 | 坐标系 | ----------------------------------------------------------- | 'figure points' | 距离图形左下角的点数量 | | 'figure pixels' | 距离图形左下角的像素数量 | | 'figure fraction' | 0,0 是图形左下角,1,1 是右上角 | | 'axes points' | 距离轴域左下角的点数量 | | 'axes pixels' | 距离轴域左下角的像素数量 | | 'axes fraction' | 0,0 是轴域左下角,1,1 是右上角 | | 'data' | 使用轴域数据坐标系 |你可以通过在可选关键字参数arrowprops中提供箭头属性字典来绘制从文本到注释点的箭头。+
|arrowprops键 |描述 | --------------------------------------------------------------------- |width |箭头宽度,以点为单位 | |frac |箭头头部所占据的比例 | |headwidth |箭头的底部的宽度,以点为单位 | |shrink |移动提示,并使其离注释点和文本一些距离 | |**kwargs |matplotlib.patches.Polygon的任何键,例如facecolor |bbox关键字参数,并且在提供时,在文本周围绘制一个框。根据决策节点还是叶子节点绘制不同的形状。
2. 构造注解树
在绘制注解树时,需要考虑如何放置所有的树节点。我们需要知道有多少个叶节点,以便正确的确定x轴的长度;海需要知道树有多少层,以便正确的确定y轴的高度。在这里,我们使用字典,来存储树节点的信息。
def retrieve_tree(i): list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}, {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}} ] return list_of_trees[i] def get_num_leafs(mytree): ''' 获取叶子节点数 ''' num_leafs = 0 first_str = mytree.keys()[0] second_dict = mytree[first_str] for key in second_dict.keys(): if type(second_dict[key]).__name__ == 'dict': num_leafs += get_num_leafs(second_dict[key]) else: num_leafs += 1 return num_leafs我们来进一步了解如何在Python中存储树的信息。retrieve_tree()里给出了2个实例,参考这两个实例,来看一下get_num_leafs()的实现。从第一个节点出发,可以遍历整棵树的所有子节点。使用type()函数,可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用get_num_leafs()函数。get_num_leafs()遍历整棵树,累计叶子节点个数,并返回该数值。get_tree_depth()函数的实现机制与get_num_leafs()类似。使用retrieve_tree()可以获取树的实例,来测试这个函数的运行是否正确。
plot_tree()函数实现实际树的绘制,同get_num_leafs()函数一样,使用了递归的方式来进行树的各个节点的绘制工作。函数里使用了全局变量,plot_tree.total_width记录树的宽度,plot_tree.total_depth记录树的深度。使用了这两个全局变量来计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。另外两个全局变量plot_tree.x_off和plot_tree.y_off追踪已经绘制的节点位置,以及下一个节点的恰当位置。
create_plot()函数,创建绘图区,计算树形图的全局尺寸,并调用递归函数plot_tree()。
参考资料
1. http://matplotlib.org/api/pyplot_api.html2. 机器学习实战
相关文章推荐
- Python进阶(三十八)-数据可视化の利用matplotlib 进行折线图,直方图和饼图的绘制
- 【数据可视化】Daft:(Python)基于matplotlib绘制精美概率图模型
- 决策树03——使用matplotlib绘制树形图并测试算法
- matplotlib模块数据可视化-绘制散列图
- matplotlib模块数据可视化-绘制柱状图
- python 数据可视化 matplotlib学习一:绘制简单的折线图
- 教程 | 如何优雅而高效地使用Matplotlib实现数据可视化
- 数据可视化matplotlib的应用(二)
- python—matplotlib数据可视化实例注解系列-----设置标注字体样式(matplotlib颜色库)
- 西瓜书 习题4.3 编程实现信息熵决策树、绘制决策树、解决matplotlib中文乱码问题
- python——数据可视化:matplotlib,seaborn,pandas
- 数据可视化matplotlib的应用
- Python数据可视化matplotlib(二)—— 子图功能
- 动态可视化 数据可视化之魅D3,Processing,pandas数据分析,科学计算包Numpy,可视化包Matplotlib,Matlab语言可视化的工作,Matlab没有指针和引用是个大问题
- 用matplotlib实现数据可视化之线形图(函数)
- Python进阶(三十九)-数据可视化の使用matplotlib进行绘图分析数据
- 【Matplotlib】数据可视化实例分析
- Python数据可视化:Matplotlib 直方图、箱线图、条形图、热图、折线图、散点图。。。
- matplotlib使用自己的数据绘制k线图
- matplotlib模块数据可视化-饼状图及补充图