分类算法:ID3与C4.5及CART
2016-10-28 13:23
525 查看
原理
ID3算法的介绍网上有很多,它是通过选择能获得最大信息增益的属性来构建决策树。C4.5是通过选择能获得最大信息增益率的属性来构建决策树。
CART用于观察值和输出值都是连续的值的情况,它可以通过选择则最优划分点来做分类;也可以通过将最优划分点改成线性函数(使每次划分时,点均匀分布在函数两侧)来做预测。
要理解信息熵,先要理解熵。熵:当能量均匀分布在物体中时,熵最高。当能量不均匀时,熵最小。带入信息-熵中讲,就是混杂的信息越多,即越混乱说明信息量很大,熵也很大;当信息被提纯后,越来越有序,则信息量很小(因为都是同类),熵也很小。
ID3中最大信息增益就是熵增最大,即使信息摆放更有序的方向。然而使信息摆放更有序的方向往往是选会择有最多取值的那个属性,比如性别和年龄而言更趋向于选有(青,中,老)三个取值的年龄属性。因此在C4.5算法中引入了信息增益率,即用ID3信息增益/信息在当前属性上的不确定度。下面会贴出公式,一看就懂。
信息熵的计算公式
I(s1,s2,...,sm)=−∑i=1mpilog2piID3中信息增益的计算公式
Gain(A)=I(s1,s2,...,sm)−E(A)E(A)=∑j=1vs1j+s2j+s3j+...+smjI(s1j,s2j,s3j,...,smj)
即(按属性A的取值分类前的信息熵)-(按属性A的取值分类后的各子集的信息熵的加权平均值)
C4.5中信息增益率的计算公式
GainRatio(S,A)=Gain(S,A)SplitInfo(S,A)SplitInfo(S,A)=−∑i=1c|si||S|log2(|si||S|)
SplitInfo(S,A)就是用来衡量属性A内部的混乱度,c是属性A的取值个数,si是A按属性取值划分后的一个子集,||是求集合大小。
实例
导入训练/测试数据,并作数据集的预处理构建决策树
loop1
计算每个属性信息增益
返回最优特征
划分子集
删除最优特征属性
遍历每个子集
loop1
持久化树模型
从文件导入树模型
决策
代码
#MyID3.py #-*-coding:utf-8-*- from numpy import * import math import copy import cPickle as pickle class ID3DTree(object): def __init__(self): #construct function self.tree={} #built tree self.dataSet=[] #the dataSet self.labels=[] #the classes def loadDataSet(self,path,labels): #import data function recordlist=[] fp=open(path,"rb") content=fp.read() fp.close() rowlist=content.splitlines() recordlist=[row.split("\t") for row in rowlist if row.strip()] self.dataSet=recordlist self.labels=labels def train(self): #run Decision Tree function labels=copy.deepcopy(self.labels) self.tree=self.buildTree(self.dataSet,labels) def buildTree(self,dataSet,labels): #building Decision Tree,the most important function cateList=[data[-1] for data in dataSet] #遍历每行取最后一列,默认抽取最后一列特征用来做类别的判断 if cateList.count(cateList[0])==len(cateList): #若子集类别只有一种,则返回该类别 return cateList[0] if len(dataSet[0])==1: #若子集只有一列,还需判断该列是否纯净,若不纯净则返回该子集中占比最大的类别 return self.maxCate(cateList) #alogrithm core bestFeat=self.getBestFeat(dataSet) #获取数据集的最优特征列的下标 bestFeatLabel=labels[bestFeat] #根据下标在labels里找到其对应的name tree={bestFeatLabel:{}} #bestFeatLabel作为根节点 {'root':{0:'leaf node',1:{'level2':{0:'leaf node',1:'leaf node'}},2:{'level2':{0:'leaf node',1:'leaf node'}}}} del(labels[bestFeat]) #抽取最优特征列向量 uniqueVals=set([data[bestFeat] for data in dataSet]) #去掉该特征列里的重复值 for value in uniqueVals: subLabels=labels[:] #用删除上层特征列的特征集做子集的特征列集合 splitDataSet=self.splitDataSet(dataSet,bestFeat,value) subTree=self.buildTree(splitDataSet,subLabels) #递归构建子树 tree[bestFeatLabel][value]=subTree #(回溯二层)tree['年龄']['青']=(回溯一层)tree['学生']['是']=(递归到底)tree['买'] 深度优先遍历 return tree def maxCate(self,cateList): #当最后只剩一列特征列但类别仍不纯净时 items=dict([(cateList.count(i),i) for i in cateList]) return items[max(items.keys())] def getBestFeat(self,dataSet): #计算特征向量维度,其中最后一列用于类别标签(买,不买),因此要减去 numFeatures=len(dataSet[0])-1 #用于决策的特征数目 baseEntropy=self.computeEntropy(dataSet) #计算未细分时,当前层数据集的熵 bestInfoGain=0.0 #初始化信息熵增益 bestFeature=-1 #初始化最优特征列 #遍历各特征列,计算信息熵,并计算信息熵增益 for i in xrange(numFeatures): uniqueVals=set([data[i] for data in dataSet]) #保存特征列的值有几种 newEntropy=0.0 #记录按特征列划分后的子数据集的信息熵 for value in uniqueVals: subDataSet=self.splitDataSet(dataSet,i,value) prob=len(subDataSet)/float(len(dataSet)) #计算特征列里个value占当前dataSet的比率 newEntropy+=prob*self.computeEntropy(subDataSet) #遍历计算特征为i,值为value的子集的熵,最后加权平均 infoGain=baseEntropy-newEntropy #保存按特征列i划分时的信息增益 if(infoGain>bestInfoGain): #记录下最优特征列 bestInfoGain=infoGain bestFeature=i return bestFeature def getBestFeat4_5(self,dataSet): #c4.5算法,按信息增益率选取特征,避免信息增益选择特征时偏向于选取特征值个数较多的情况(特征值越多更利于混乱的减少) Num_Feats=len(dataSet[0][:-1]) #取dataSet第一行,下标从第一个到最后一个(不包括)的子数组求长度 totality=len(dataSet) BaseEntropy=self.computeEntropy(dataSet) #计算当前dataSet的熵 for f in xrange(Num_Feats): featList=[feat[f] for feat in dataSet] #遍历得到一列特征 def computeEntropy(self,dataSet): #计算信息混乱度,越混乱越高 datalen=float(len(dataSet)) cateList=[data[-1] for data in dataSet] #提取要分类的特征列(买,不买) items=dict([(i,cateList.count(i)) for i in cateList]) infoEntropy=0.0 for key in items: prob=float(items[key])/datalen infoEntropy-=prob*math.log(prob,2) return infoEntropy def splitDataSet(self,dataSet,axis,value): #dataSet数据集,axis特征列下标,value特征列的取值之一 rtnList=[] for featVec in dataSet: if(featVec[axis]==value): rFeatVec=featVec[:axis] #list操作,装入0~axis-1间的元素 rFeatVec.extend(featVec[axis+1:]) #list操作,装入axis+1,length-1间的元素 rtnList.append(rFeatVec) return rtnList #序列化函数 def storeTree(self,inputTree,filename): fw=open(filename,'w') pickle.dump(inputTree,fw) fw.close() def grabTree(self,filename): fr=open(filename) return pickle.load(fr) #决策函数 def predict(self,inputTree,featLabels,testVec): root=inputTree.keys()[0] #得到树的根节点对应的特征名字 secondDict=inputTree[root] #得到树根对应各个取值下的子树或节点(节点即最终分类) featIndex=featLabels.index(root) #得到特征名对应的label下标 key=testVec[featIndex] #得到测试集在该特征的取值 valueOfFeat=secondDict[key] #根据key值取得子树或最终类别 if isinstance(valueOfFeat,dict): #判断valueOfFeat是不是字典类,即使不是子树 classLabel=self.predict(valueOfFeat,featLabels,testVec) #递归分类,将子树作为树传下去 else: classLabel=valueOfFeat return classLabel
#treePlotter.py # -*- coding: utf-8 -*- ''' Created on 2015年7月27日 @author: pcithhb ''' import matplotlib.pyplot as plt decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") #获取叶节点的数目 def getNumLeafs(myTree): numLeafs = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点 numLeafs += getNumLeafs(secondDict[key]) else: numLeafs +=1 return numLeafs #获取树的层数 def getTreeDepth(myTree): maxDepth = 0 firstStr = list(myTree.keys())[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点 thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth #绘制节点 def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction', va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) #绘制连接线 def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) #绘制树结构 def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) #this determines the x width of this tree depth = getTreeDepth(myTree) firstStr = list(myTree.keys())[0] #the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #it's a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #创建决策树图形 def createPlot(inTree): fig = plt.figure(1, facecolor='white') fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), '') plt.show()
#MyID3_Main.py #-*-coding:utf-8 -*- from numpy import * from MyID3 import * import treePlotter as tp dtree=ID3DTree() dtree.loadDataSet("C:\Users\MCGG\Documents\python\dataset.dat",["age","revenue","student","credit"]) dtree.train() tp.createPlot(dtree.tree) print dtree.tree
相关文章推荐
- 分类算法-----决策树(包括ID3,C4.5)
- ID3决策树与C4.5决策树分类算法简述
- 决策树分类算法:ID3 & C4.5 & CART
- 机器学习自学之路-SVM 算法选择:三种算法优缺点比较(ID3、C4.5、CART)
- 基于决策树系列算法(ID3, C4.5, CART, Random Forest, GBDT)的分类和回归探讨
- 分类算法(5) ---- 决策树(ID3,C4.5,CTAR)
- ID3、C4.5、CART、RandomForest的原理
- 数据挖掘经典算法--CART算法分类和回归树
- 决策树ID3分类算法的C++实现
- 机器学习经典算法详解及Python实现--CART分类决策树、回归树和模型树
- 机器学习算法-决策树生成算法ID3和C4.5
- ID3、C4.5、CART三种决策树的区别
- 机器学习技法-决策树和CART分类回归树构建算法
- 决策树分类算法之ID3
- 经典决策树算法:ID3、C4.5和CART
- Cart文本分类算法原理和例子
- R_针对churn数据用id3、cart、C4.5和C5.0创建决策树模型进行判断哪种模型更合适
- 数据挖掘十大经典算法学习之C4.5决策树分类算法及信息熵相关
- 《统计学习方法》读书笔记-----决策树:ID3,C4.5生成算法和剪枝
- pyhon实现决策树(ID3)算法进行数据的分类预测