您的位置:首页 > 编程语言 > PHP开发

数据可视化matplotlib(03) 绘制决策树

2017-11-17 09:33 537 查看

简介

决策树的主要优点是直观易于理解,如果不能将其直观的显示出来,就无法发挥其优势。本文将使用matplotlib来绘制树形图,并讲解具体的代码实现。

决策树图实例

为了便于理解,我们先来看看实际的决策树的图长个什么样子。下图所示的流程图就是一个决策树,正方形代表判断模块(decision block), 椭圆形代表终止模块(terminal block),表示已经得出结论,可以终止运行。



代码实现

这里我们给出了相关的代码实现,后面会对一些重要的实现进行详解。

# /usr/bin/python
# -*- coding: UTF-8 -*-
'''
Created on 2017年11月16日

@author: bob
'''

import matplotlib.pyplot as plt

# pylint: disable=redefined-outer-name

# 定义文本框和箭头格式
decision_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def retrieve_tree(i):
list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return list_of_trees[i]

def get_num_leafs(mytree):
'''
获取叶子节点数
'''
num_leafs = 0
first_str = mytree.keys()[0]
second_dict = mytree[first_str]

for key in second_dict.keys():
if type(second_dict[key]).__name__ == 'dict':
num_leafs += get_num_leafs(second_dict[key])
else:
num_leafs += 1

return num_leafs

def get_tree_depth(mytree):
'''
获取树的深度
'''
max_depth = 0
d8e0

first_str = mytree.keys()[0]
second_dict = mytree[first_str]

for key in second_dict.keys():
# 如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用
# get_tree_depth()函数
if type(second_dict[key]).__name__ == 'dict':
this_depth = 1 + get_tree_depth(second_dict[key])
else:
this_depth = 1

if this_depth > max_depth:
max_depth = this_depth

return max_depth

def plot_node(ax, node_txt, center_ptr, parent_ptr, node_type):
'''
绘制带箭头的注解
'''
ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction',
xytext=center_ptr, textcoords='axes fraction',
va="center", ha="center", bbox=node_type, arrowprops=arrow_args)

def plot_mid_text(ax, center_ptr, parent_ptr, txt):
'''
在父子节点间填充文本信息
'''
x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0]
y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1]

ax.text(x_mid, y_mid, txt)

def plot_tree(ax, mytree, parent_ptr, node_txt):
'''
绘制决策树
'''
# 计算宽度
num_leafs = get_num_leafs(mytree)

first_str = mytree.keys()[0]
center_ptr = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_width, plot_tree.y_off)

#绘制特征值,并计算父节点和子节点的中心位置,添加标签信息
plot_mid_text(ax, center_ptr, parent_ptr, node_txt)
plot_node(ax, first_str, center_ptr, parent_ptr, decision_node)

second_dict = mytree[first_str]
#采用的自顶向下的绘图,需要依次递减Y坐标
plot_tree.y_off -= 1.0 / plot_tree.total_depth

#遍历子节点,如果是叶子节点,则绘制叶子节点,否则,递归调用plot_tree()
for key in second_dict.keys():
if type(second_dict[key]).__name__ == "dict":
plot_tree(ax, second_dict[key], center_ptr, str(key))
else:
plot_tree.x_off += 1.0 / plot_tree.total_width
plot_mid_text(ax, (plot_tree.x_off, plot_tree.y_off), center_ptr, str(key))
plot_node(ax, second_dict[key], (plot_tree.x_off, plot_tree.y_off), center_ptr, leaf_node)

#在绘制完所有子节点之后,需要增加Y的偏移
plot_tree.y_off += 1.0 / plot_tree.total_depth

def create_plot(in_tree):
fig = plt.figure(1, facecolor="white")
fig.clf()

ax_props = dict(xticks=[], yticks=[])
ax = plt.subplot(111, frameon=False, **ax_props)
plot_tree.total_width = float(get_num_leafs(in_tree))
plot_tree.total_depth = float(get_tree_depth(in_tree))
plot_tree.x_off = -0.5 / plot_tree.total_width
plot_tree.y_off = 1.0
plot_tree(ax, in_tree, (0.5, 1.0), "")
#     plot_node(ax, "a decision node", (0.5, 0.1), (0.1, 0.5), decision_node)
#     plot_node(ax, "a leaf node", (0.8, 0.1), (0.3, 0.8), leaf_node)
plt.show()

if __name__ == '__main__':
#     create_plot()
mytree = retrieve_tree(1)
mytree['no surfacing'][3] = "maybe"
create_plot(mytree)


实现解析

1. 关于注解

matplotlib提供了一个注解工具annotation, 可一在数据图形上添加文本注解。注解通用用于解释数据的内容。工具内嵌支持带箭头的划线工具,可以在恰当的地方指向数据位置,并在此处添加描述信息,解释数据内容。

使用text()会将文本放置在轴域的任意位置。 文本的一个常见用例是标注绘图的某些特征。

def plot_mid_text(ax, center_ptr, parent_ptr, txt):
'''
在父子节点间填充文本信息
'''
x_mid = (parent_ptr[0] - center_ptr[0]) / 2.0 + center_ptr[0]
y_mid = (parent_ptr[1] - center_ptr[1]) / 2.0 + center_ptr[1]

ax.text(x_mid, y_mid, txt)
plot_mid_text()函数实现了在父子节点间绘制文本信息的功能,这个函数中,需要计算父子节点中心位置的坐标,并调用text()函数来进行绘制。

annotate()方法提供辅助函数,使标注变得容易。 在标注中,有两个要考虑的点:由参数xy表示的标注位置和xytext的文本位置。 这两个参数都是(x, y)元组。

# 定义文本框和箭头格式
decision_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plot_node(ax, node_txt, center_ptr, parent_ptr, node_type):
'''
绘制带箭头的注解
'''
ax.annotate(node_txt, xy=parent_ptr, xycoords='axes fraction',
xytext=center_ptr, textcoords='axes fraction',
va="center", ha="center", bbox=node_type, arrowprops=arrow_args)
在该示例中,xy(箭头尖端)和xytext位置(文本位置)都以数据坐标为单位。 有多种可以选择的其他坐标系 - 你可以使用xycoords和textcoords以及下列字符串之一(默认为data)指定xy和xytext的坐标系。

| 参数              | 坐标系                             |
-----------------------------------------------------------
| 'figure points'   | 距离图形左下角的点数量           	 |
| 'figure pixels'   | 距离图形左下角的像素数量           |
| 'figure fraction' | 0,0 是图形左下角,1,1 是右上角     |
| 'axes points'     | 距离轴域左下角的点数量             |
| 'axes pixels'     | 距离轴域左下角的像素数量           |
| 'axes fraction'   | 0,0 是轴域左下角,1,1 是右上角     |
| 'data'            | 使用轴域数据坐标系                 |
你可以通过在可选关键字参数arrowprops中提供箭头属性字典来绘制从文本到注释点的箭头。+
|arrowprops键    |描述                                             |
---------------------------------------------------------------------
|width          |箭头宽度,以点为单位                              |
|frac           |箭头头部所占据的比例                              |
|headwidth      |箭头的底部的宽度,以点为单位                      |
|shrink         |移动提示,并使其离注释点和文本一些距离            |
|**kwargs       |matplotlib.patches.Polygon的任何键,例如facecolor |
bbox关键字参数,并且在提供时,在文本周围绘制一个框。根据决策节点还是叶子节点绘制不同的形状。

2. 构造注解树

在绘制注解树时,需要考虑如何放置所有的树节点。我们需要知道有多少个叶节点,以便正确的确定x轴的长度;海需要知道树有多少层,以便正确的确定y轴的高度。在这里,我们使用字典,来存储树节点的信息。

def retrieve_tree(i):
list_of_trees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
]
return list_of_trees[i]
def get_num_leafs(mytree):
'''
获取叶子节点数
'''
num_leafs = 0
first_str = mytree.keys()[0]
second_dict = mytree[first_str]

for key in second_dict.keys():
if type(second_dict[key]).__name__ == 'dict':
num_leafs += get_num_leafs(second_dict[key])
else:
num_leafs += 1

return num_leafs
我们来进一步了解如何在Python中存储树的信息。retrieve_tree()里给出了2个实例,参考这两个实例,来看一下get_num_leafs()的实现。从第一个节点出发,可以遍历整棵树的所有子节点。使用type()函数,可以判断子节点是否为字典类型。如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用get_num_leafs()函数。get_num_leafs()遍历整棵树,累计叶子节点个数,并返回该数值。get_tree_depth()函数的实现机制与get_num_leafs()类似。使用retrieve_tree()可以获取树的实例,来测试这个函数的运行是否正确。

plot_tree()函数实现实际树的绘制,同get_num_leafs()函数一样,使用了递归的方式来进行树的各个节点的绘制工作。函数里使用了全局变量,plot_tree.total_width记录树的宽度,plot_tree.total_depth记录树的深度。使用了这两个全局变量来计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。另外两个全局变量plot_tree.x_off和plot_tree.y_off追踪已经绘制的节点位置,以及下一个节点的恰当位置。

create_plot()函数,创建绘图区,计算树形图的全局尺寸,并调用递归函数plot_tree()。

参考资料

1. http://matplotlib.org/api/pyplot_api.html
2. 机器学习实战
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐