您的位置:首页 > 编程语言 > Python开发

《MachineLearningInAction》之绘制决策树

2017-03-25 18:51 459 查看
《MachineLearningInAction》(Peter Harrington)中的代码有点小问题,我重写了全书所有代码,分享于此。

Block Ⅰ

import matplotlib.pyplot as plt #用于调用绘图
import matplotlib #用于调用rcParams属性,设置绘图窗口风格


Block Ⅱ

定义全局变量
decisionNode #决策节点边框样式
leafNode #叶子节点边框样式
arrow_args #箭头样式


Block Ⅲ
定义函数

retrieveTree #为简化问题及做函数测试,手动生成大小不一的Tree
getNumLeafs #获取Tree的叶子数
getTreeDepth #获取Tree的深度,即从根节点到最深叶子节点路径上decisionNode的层数
plotNode #绘制节点,通过nodeType参数区分decisionNode及leafNode
plotMidText #annotate每一个dict的key
plotTree #递归绘制决策树


Block Ⅳ

测试代码

当__name__ == '__main__'时,即作为主模块调用时执行

效果图



关键思路:

1、递归生成整棵树,代码测试时从只有一个decisionNode开始(通过调用retrieveTree(0)获得)。

2、通过参数传递plot axis来在同一个轴上绘图。Peter通过在实时调用时给plotTree函数增加axis属性达到同样效果,稍显复杂。

3、整张图绘制在(0,0),(1,1)围成的矩形区域内,第一个decisionNode中心位于(0.5,1),通过decisionNode及LeafNode的数目控制纵向及横向间距,此二者皆为定值。这一点不知道是否与Peter的思路一致,因为他的代码太令我眼花缭乱,没看,我全部重写的。

4、plotTree先绘制decisionNode及其与parentNode之间的箭头,再依次绘制其各个分支;特别地,当中心坐标与父节点坐标相等时(即根节点),系统函数不绘制箭头。

treePlotter源码如下:

# -*- coding: utf-8 -*-
"""
treePlotter.py
~~~~~~~~~~

A module with functions to plot decision tree.

Created on Thu Mar 23 17:26:57 2017

Run on Python 3.6

@author: Luo Shaozhuo

refer to 'MachineLearninginAction'

"""
#==============================================================================
# import
#==============================================================================
import matplotlib.pyplot as plt
import matplotlib

#==============================================================================
# Global variables
#==============================================================================
# bbox properties for decision-node boxes (passed to annotate via plotNode)
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# bbox properties for leaf-node boxes
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# arrowprops shared by every arrow drawn in plotNode
arrow_args = {"arrowstyle": "<-"}

#==============================================================================
# functions
#==============================================================================
def retrieveTree(i=0):
    """
    return one of the predefined sample trees
    ~~~~~~~~~~
    i: index of the sample tree: 0, 1 or 2 (a larger index gives a deeper tree)
    ~~~~~~~~~~
    dictTree: nested dict of the form {feature: {feature value: subtree or label}}
    """
    # one decision node
    dictShallow = {'no surfacing': {0: 'no', 1: 'yes'}}
    # two decision nodes
    dictMedium = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: 'no', 1: 'yes'}},
        },
    }
    # three decision nodes
    dictDeep = {
        'no surfacing': {
            0: 'no',
            1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}},
        },
    }
    return [dictShallow, dictMedium, dictDeep][i]

def getNumLeafs(dictTree):
    """
    return the number of leafs
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree
    ~~~~~~~~~~
    nNumLeaf: number of leafs (0 for an empty dictionary)
    """
    nNumLeaf = 0
    for branch in dictTree.values():
        # isinstance rather than type() ==, so that dict subclasses such as
        # OrderedDict are descended into instead of being counted as one leaf
        if isinstance(branch, dict):
            nNumLeaf += getNumLeafs(branch)
        else:
            nNumLeaf += 1
    return nNumLeaf

def getTreeDepth(dictTree):
    """
    return the tree depth
    ~~~~~~~~~~
    dictTree: a dictionary depicting a decision tree, {feature: {value: subtree or label}}
    ~~~~~~~~~~
    nMaxDepth: tree depth, i.e. the number of decision-node levels on the
               longest path from the root to a leaf (0 for an empty dictionary)
    """
    if not dictTree:
        # an empty tree has no decision node at all
        return 0
    nMaxDepth = 0
    # only the first key is used: it is the feature tested at this node
    strFeature = list(dictTree)[0]
    dictTrunk = dictTree[strFeature]
    for branch in dictTrunk.values():
        nCurDepth = 1
        # isinstance rather than type() ==, so that dict subclasses such as
        # OrderedDict are recognised as subtrees
        if isinstance(branch, dict):
            nCurDepth += getTreeDepth(branch)
        nMaxDepth = max(nMaxDepth, nCurDepth)
    return nMaxDepth

def plotNode(pltAxis,strNodeTxt, tplCntrPt, tplPrntPt, nodeType):
    """
    draw one node box with an arrow between it and its parent.
    ~~~~~~~~~~
    pltAxis: plot axis
    strNodeTxt: text shown inside the node box
    tplCntrPt: center coordinates of the box (axes fraction)
    tplPrntPt: coordinates of the parent end of the arrow (axes fraction)
    nodeType: bbox properties of the box, leafNode or decisionNode
    ~~~~~~~~~~
    N/A
    """
    # both the box position and the arrow end point use the same coordinate system
    strCoordSys = 'axes fraction'
    pltAxis.annotate(
        strNodeTxt,
        xy=tplPrntPt, xytext=tplCntrPt,
        xycoords=strCoordSys, textcoords=strCoordSys,
        ha="center", va="center",
        bbox=nodeType, arrowprops=arrow_args,
    )

def plotMidText(pltAxis, cntrPt, parentPt, txtString):
    """
    add feature value in the middle of arrow
    ~~~~~~~~~~
    pltAxis: axis used for plotting
    cntrPt: (x, y) of the child node, in axes-fraction coordinates
    parentPt: (x, y) of the parent node, in axes-fraction coordinates
    txtString: label drawn at the midpoint between the two nodes
    ~~~~~~~~~~
    N/A
    """
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    # plotNode places nodes in 'axes fraction' coordinates, whereas text()
    # defaults to data coordinates; use the axes transform so the label stays
    # on the arrow even if the data limits are not exactly 0..1
    pltAxis.text(xMid, yMid, txtString, transform=pltAxis.transAxes)

def plotTree(dictTree, pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt, strNodeTxt):
    """
    plot tree recursively
    ~~~~~~~~~~
    dictTree: branches of the current decision node, {feature value: subtree or label}
    pltAxis: axis used for plotting
    fTrunkLen: vertical (y) distance between a node and its children
    fBrchLen: horizontal (x) distance between adjacent children of one node
    tplCntrPt: coordinates of the current decision node
    tplPrntPt: coordinates of the parent of the current decision node
    strNodeTxt: text in the box of the current decision node
    ~~~~~~~~~~
    N/A
    """
    # plot the current decision node together with the arrow to its parent
    plotNode(pltAxis, strNodeTxt, tplCntrPt, tplPrntPt, decisionNode)
    nNumKey = len(dictTree)
    if nNumKey == 0:
        # no branches to draw; also avoids dividing by zero below
        return
    # mean of the indices 0..nNumKey-1, used to center the children under this node
    fMean = (nNumKey - 1) / 2
    # NOTE: children are spaced by a fixed fBrchLen regardless of how wide
    # their own subtrees are, so sibling subtrees of wide trees can overlap
    for i, (key, branch) in enumerate(dictTree.items()):
        tplChildPt = (tplCntrPt[0] + (i - fMean) * fBrchLen,
                      tplCntrPt[1] - fTrunkLen)
        plotMidText(pltAxis, tplChildPt, tplCntrPt, key)
        if isinstance(branch, dict):
            # subtree: its single key is the feature name shown in the child box
            strChildTxt = list(branch)[0]
            plotTree(branch[strChildTxt], pltAxis, fTrunkLen, fBrchLen,
                     tplChildPt, tplCntrPt, strChildTxt)
        else:
            # leaf: the value itself is the label
            plotNode(pltAxis, branch, tplChildPt, tplCntrPt, leafNode)

if __name__ == '__main__':
    # demo: draw the deepest predefined sample tree
    dictTree = retrieveTree(2)
    # must be set before the figure is created to take effect
    matplotlib.rcParams['toolbar'] = 'none'
    pltAxis = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # horizontal / vertical spacing derived from the tree size
    fBrchLen = 1/getNumLeafs(dictTree)
    fTrunkLen = 1/getTreeDepth(dictTree)
    # the root sits at the top center; parent == center for the root
    tplCntrPt = (0.5, 1)
    tplPrntPt = tplCntrPt
    strNodeTxt = list(dictTree.keys())[0]
    plotTree(dictTree[strNodeTxt], pltAxis, fTrunkLen, fBrchLen, tplCntrPt, tplPrntPt, strNodeTxt)
    # without show() a plain `python treePlotter.py` run exits before the
    # figure is ever displayed
    plt.show()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息