决策树03——使用matplotlib绘制树形图并测试算法
2017-06-06 22:56
861 查看
在决策树02——决策树的构建中,我们将已经进行分类的数据存储在字典中,然而字典的表示形式非常不直观,也不容易理解,所以我们将字典中的信息绘制成树形图。
以下将使用Matplotlib的注解功能绘制树形图,它可以对文字着色,并提供多种形状以供选择,而且我们还可以反转箭头,将它指向文本框而不是数据点。
新建名为treeplotter.py的新文件,将输入下面的程序代码:
注意:以上程序运行时会出现中文变成小方框的现象,将以下几行代码添加到文件的开始处。
在命令行输入:
这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的输煤和树的层数。将下面的两个函数添加到treePlotter.py文件中。
函数retrieveTree()输出预先存储的树信息,将 下面代码添加到文件treePlotter.py中:
在命令行中输入:
将下面代码添加到treePlotter.py中,注意前面已经定义了createPlot(),此时我们需要更新前面的代码。
在命令行输入:
注释:
1.
在这行代码中,首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:
其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,
plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW.
2.
这行代码中是需要的,当分支最后一个不是字典的时候,字典循环完需要返回上一层继续进行函数
例如:
3.
在这行代码中,对于plotTree函数参数赋值为(0.5, 1.0),因为开始的根节点并不用划线,因此父节点和当前节点的位置需要重合,利用2中的确定当前节点的位置便为(0.5, 1.0)
总结:利用这样的逐渐增加x的坐标,以及逐渐降低y的坐标能能够很好的将树的叶子节点数和深度考虑进去,因此图的逻辑比例就很好的确定了,这样不用去关心输出图形的大小,一旦图形发生变化,函数会重新绘制,但是假如利用像素为单位来绘制图形,这样缩放图形就比较有难度了
在命令行输入:
在命令行中输入:
Matplotlib注解功能
Matplotlib提供一个注解工具annotations,它可以在数据图形上添加文本注释。以下将使用Matplotlib的注解功能绘制树形图,它可以对文字着色,并提供多种形状以供选择,而且我们还可以反转箭头,将它指向文本框而不是数据点。
新建名为treeplotter.py的新文件,将输入下面的程序代码:
# -*-coding=utf-8 -*- #使用文本朱姐绘制树节点 import matplotlib.pyplot as plt #定义文本框和箭头格式 #定义决策树决策结果的属性(决策节点or叶节点),用字典来定义 #下面的字典定义也可以写作 decisionNode = {boxstyle:’sawtooth‘,fc=’0.8‘} decisionNode = dict(boxstyle = "sawtooth", fc = "0.8") #决策节点,boxstyle为文本框类型,sawtooth是锯齿形,fc是边框内填充的颜色 leafNode = dict(boxstyle = "round",fc="0.8") #叶节点,定义决策树的叶子结点的描述属性 arrow_args = dict(arrowstyle = "<-") #箭头格式 #绘制带箭头的注释 def plotNode(nodeTxt,centerPt,parentPt,nodeType): #nodeTxt是显示的文本,centerPt是文本的中心点,parentPt是箭头的起点坐标,nodeType是一个字典 注解的形状 createPlot.ax1.annotate(nodeTxt,xy = parentPt, xycoords = 'axes fraction', #xy为箭头的起始坐标,0,0 is lower left of axes and 1,1 is upper right xytext = centerPt,textcoords = 'axes fraction', #xytext为注解内容的坐标 va = "center",ha = "center",bbox = nodeType,arrowprops = arrow_args) #bbox注解文本框的形状,arrowprops是指箭头的形状 def createPlot(): fig = plt.figure(1,facecolor='white') #类似于matlab的figure,定义一个画布,其背景为白色 fig.clf() #把画布清空 createPlot.ax1 = plt.subplot(111,frameon=False) # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图, plotNode(U'决策节点',(0.5,0.1),(0.1,0.5), decisionNode) plotNode(U'叶节点',(0.8,0.1),(0.3,0.8), leafNode) plt.show()
注意:以上程序运行时会出现中文变成小方框的现象,将以下几行代码添加到文件的开始处。
from pylab import * mpl.rcParams['font.sans-serif'] = ['SimHei'] #指定默认字体 mpl.rcParams['axes.unicode_minus'] = False
在命令行输入:
In[70]: import treePlotter Backend TkAgg is interactive backend. Turning interactive mode on. In[71]: treePlotter.createPlot()
构造注解树
我们虽然有x, y坐标,但是如何放置所有的树节点却是个问题。我们必须知道有多少个叶节点,以便可以正确确定x轴的长度,我们还需要知道树有多少层,以便可以正确的确定y轴的高度。这里我们定义两个新函数getNumLeafs()和getTreeDepth(),来获取叶节点的输煤和树的层数。将下面的两个函数添加到treePlotter.py文件中。
#获取叶节点的数目和树的层次 def getNumLeafs(myTree): numLeaf = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ =='dict': #测试节点的数据类型是否为字典 ,type(secondDict[key]) ==dict 也是可以的 numLeaf += getNumLeafs(secondDict[key]) else: numLeaf += 1 return numLeaf def getTreeDepth(myTree): maxDepth = 0 firstStr = myTree.keys()[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth :maxDepth = thisDepth return maxDepth
函数retrieveTree()输出预先存储的树信息,将 下面代码添加到文件treePlotter.py中:
def retrieveTree(i): listOfTrees = [{'no surfacing':{0:'0',1:{'flippers':{0:'no',1:'yes'}}}}, {'no surfacing': {0: '0', 1: {'flippers': {0: {'head':{0:'no',1:'yes'}}, 1: 'no'}}}} ] return listOfTrees[i]
在命令行中输入:
In[2]: import treePlotter Backend TkAgg is interactive backend. Turning interactive mode on. In[3]: treePlotter.retrieveTree(0) Out[3]: {'no surfacing': {0: '0', 1: {'flippers': {0: 'no', 1: 'yes'}}}} In[4]: myTree = treePlotter.retrieveTree(0) In[5]: treePlotter.getNumLeafs(myTree) Out[5]: 3 In[6]: treePlotter.getTreeDepth(myTree) Out[6]: 2
将下面代码添加到treePlotter.py中,注意前面已经定义了createPlot(),此时我们需要更新前面的代码。
#plotTree函数 #在父子节点间填充文本信息 def plotMidText(cntrPt,parentPt,txtString): xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid,yMid,txtString) #自顶向下作图,绘制图形的x轴有效范围是0.0~1.0, y轴有效范围也是0.0~1.0 def plotTree(myTree,parentPt,nodeTxt): numLeafs = getNumLeafs(myTree) #secondDict[key]的叶节点的数量 depth = getTreeDepth(myTree) #secondDict[key]的树深度 print 'numLeafs,depth:',numLeafs,',',depth firstStr = myTree.keys()[0] # 全局变量plotTree.totalW 存储树的宽度,全局变量PlotTree.totalD 存储树的深度,使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置。 cntrPt = (plotTree.xOff +(1.0 + float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff) #注释1 #标记子节点属性 plotMidText(cntrPt,parentPt,nodeTxt) #这一次循环中的cntrPt(即上式)为cbtrPt,parentPt为上一轮计算出的cntrPt plotNode(firstStr,cntrPt,parentPt,decisionNode) #因还没画到叶节点,所以这里画的是决策节点,即此时筛选secondDict[key]还是字典 secondDict = myTree[firstStr] #计算下一轮要用的y plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #下面的循环中要使用的y for key in secondDict.keys(): if type(secondDict[key]).__name__ == 'dict': plotTree(secondDict[key],cntrPt,str(key)) else: plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode) plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #注释2 def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[],yticks=[]) #创建一个型为{'xticks': [], 'yticks': []}的字典 createPlot.ax1 = plt.subplot(111,frameon=False,**axprops) plotTree.totalW = float(getNumLeafs(inTree)) #树的宽度 plotTree.totalD = float(getTreeDepth(inTree)) #输的深度 plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #定义节点位置的初始值 plotTree(inTree,(0.5,1.0),'') #(0.5,1.0)为初始化parentPt的值,注释3 plt.show()
在命令行输入:
In[35]: reload(treePlotter) Out[35]: <module 'treePlotter' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/treePlotter.py'> In[36]: myTree = treePlotter.retrieveTree(0) In[37]: myTree Out[37]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}} In[38]: treePlotter.createPlot(myTree) numLeafs,depth: 3 , 2 numLeafs,depth: 2 , 1
注释:
1.
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
在这行代码中,首先由于整个画布根据叶子节点数和深度进行平均切分,并且x轴的总长度为1,即如同下图:
其中方形为非叶子节点的位置,@是叶子节点的位置,因此每份即上图的一个表格的长度应该为1/plotTree.totalW,但是叶子节点的位置应该为@所在位置,则在开始的时候plotTree.xOff的赋值为-0.5/plotTree.totalW,即意为开始x位置为第一个表格左边的半个表格距离位置,这样作的好处为:在以后确定@位置时候可以直接加整数倍的1/plotTree.totalW,
plotTree.xOff即为最近绘制的一个叶子节点的x坐标,在确定当前节点位置时每次只需确定当前节点有几个叶子节点,因此其叶子节点所占的总距离就确定了即为float(numLeafs)/plotTree.totalW*1(因为总长度为1),因此当前节点的位置即为其所有叶子节点所占距离的中间即一半为float(numLeafs)/2.0/plotTree.totalW*1,但是由于开始plotTree.xOff赋值并非从0开始,而是左移了半个表格,因此还需加上半个表格距离即为1/2/plotTree.totalW*1,则加起来便为(1.0 + float(numLeafs))/2.0/plotTree.totalW*1,因此偏移量确定,则x位置变为plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW.
2.
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
这行代码中是需要的,当分支最后一个不是字典的时候,字典循环完需要返回上一层继续进行函数
例如:
In[40]: myTree['no surfacing'][3] = 'maybe' In[41]: myTree Out[41]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}, 3: 'maybe'}} In[42]: treePlotter.createPlot(myTree) numLeafs,depth: 4 , 2 numLeafs,depth: 2 , 1
3.
plotTree(inTree,(0.5,1.0),'')
在这行代码中,对于plotTree函数参数赋值为(0.5, 1.0),因为开始的根节点并不用划线,因此父节点和当前节点的位置需要重合,利用2中的确定当前节点的位置便为(0.5, 1.0)
总结:利用这样的逐渐增加x的坐标,以及逐渐降低y的坐标能能够很好的将树的叶子节点数和深度考虑进去,因此图的逻辑比例就很好的确定了,这样不用去关心输出图形的大小,一旦图形发生变化,函数会重新绘制,但是假如利用像素为单位来绘制图形,这样缩放图形就比较有难度了
测试和存储分类器
程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点,最后将测试数据定义为叶子节点所属的类型。#使用决策树的分类算法 def classify(inputTree,featLabels,testVec): #testVec即为需要分类的数据 firstStr = inputTree.keys()[0] secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) #将标签字符串转换为索引 print featIndex for key in secondDict.keys(): if testVec[featIndex] == key: if type(secondDict[key]).__name__ == 'dict': classLabel = classify(secondDict[key],featLabels,testVec) else: classLabel = secondDict[key] return classLabel
在命令行输入:
In[19]: reload(trees) Out[19]: <module 'trees' from '/home/vickyleexy/PycharmProjects/Classification of contact lenses/trees.py'> In[20]: myDat,labels = trees.createDataSet() In[21]: myDat Out[21]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] In[22]: labels Out[22]: ['no surfacing', 'flippers'] In[23]: myTree = trees.createTree(myDat,labels) 最好的特征,最好的信息增益: 0 , 0.419973094022 最好的特征,最好的信息增益: 0 , 0.918295834054 In[24]: myDat Out[24]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] In[25]: labels Out[25]: ['flippers'] In[26]: myDat,labels = trees.createDataSet() In[27]: trees.classify(myTree,labels,[1,1]) 0 1 Out[27]: 'yes' In[28]: trees.classify(myTree,labels,[1,0]) 0 1 Out[28]: 'no'
决策树的存储
为了节省时间,最好能够在每次执行分类时调用已经构造好的决策树,使用Python的pickle模块可以在磁盘上保存对象,并在需要的时候读取出来。#使用pickle模块存储决策树 def storeTree(inputTree,filename): import pickle fw = open(filename,'w') pickle.dump(inputTree,fw) fw.close() def grabTree(filename): import pickle fr = open(filename) return pickle.load(fr)
在命令行中输入:
In[29]: trees.storeTree(myTree,'classifierStorage.txt') In[30]: trees.grabTree('classifierStorage.txt') Out[30]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
相关文章推荐
- 第三章 决策树 3.2 使用 Matplotlib 注解绘制树形图
- 《机器学习实战》第三章 3.2在python 中使用matplotlib注解绘制树形图
- 《机器学习实战》第三章 3.2 在Python中使用Matplotlib注解绘制树形图
- 机器学习实战-使用matplotlib绘制决策树
- 数据可视化matplotlib(03) 绘制决策树
- 《机器学习实战》——在python中使用Matplotlib注解绘制树形图
- Python:使用matplotlib绘制图表
- 使用matplotlib绘制图表
- 在Linux下使用Python的matplotlib绘制数据图的教程
- 在Linux下使用Python的matplotlib绘制数据图的教程
- Python:使用matplotlib绘制图表
- 使用Matplotlib绘制正余弦函数、抛物线
- 如何使用matplotlib绘制一个函数的图像
- pyqt中使用matplotlib绘制动态曲线
- pyqt中使用matplotlib绘制动态曲线 – pythonic
- lozi混沌映射吸引子,使用python的matplotlib绘制,可以放大和缩小
- python实战二:使用CSV数据绘制带数据标志的折线图(matplotlib)
- Python:使用matplotlib绘制图表
- 【原】使用Tkinter绘制GUI并结合Matplotlib实现交互式绘图
- python使用matplotlib绘制xy坐标轴图