您的位置:首页 > 其它

分类算法:ID3与C4.5及CART

2016-10-28 13:23 525 查看

原理

ID3算法的介绍网上有很多,它是通过选择能获得最大信息增益的属性来构建决策树。

C4.5是通过选择能获得最大信息增益率的属性来构建决策树。

CART用于观察值和输出值都是连续的值的情况,它可以通过选择则最优划分点来做分类;也可以通过将最优划分点改成线性函数(使每次划分时,点均匀分布在函数两侧)来做预测

要理解信息熵,先要理解熵。熵:当能量均匀分布在物体中时,熵最高。当能量不均匀时,熵最小。带入信息-熵中讲,就是混杂的信息越多,即越混乱说明信息量很大,熵也很大;当信息被提纯后,越来越有序,则信息量很小(因为都是同类),熵也很小。

ID3中最大信息增益就是熵增最大,即使信息摆放更有序的方向。然而使信息摆放更有序的方向往往是选会择有最多取值的那个属性,比如性别和年龄而言更趋向于选有(青,中,老)三个取值的年龄属性。因此在C4.5算法中引入了信息增益率,即用ID3信息增益/信息在当前属性上的不确定度。下面会贴出公式,一看就懂。

信息熵的计算公式

I(s1,s2,...,sm)=−∑i=1mpilog2pi

ID3中信息增益的计算公式

Gain(A)=I(s1,s2,...,sm)−E(A)

E(A)=∑j=1vs1j+s2j+s3j+...+smjI(s1j,s2j,s3j,...,smj)

即(按属性A的取值分类前的信息熵)-(按属性A的取值分类后的各子集的信息熵的加权平均值)

C4.5中信息增益率的计算公式

GainRatio(S,A)=Gain(S,A)SplitInfo(S,A)

SplitInfo(S,A)=−∑i=1c|si||S|log2(|si||S|)

SplitInfo(S,A)就是用来衡量属性A内部的混乱度,c是属性A的取值个数,si是A按属性取值划分后的一个子集,||是求集合大小。

实例

导入训练/测试数据,并作数据集的预处理

构建决策树

loop1

计算每个属性信息增益

返回最优特征

划分子集

删除最优特征属性

遍历每个子集

loop1

持久化树模型

从文件导入树模型

决策

代码

#MyID3.py
#-*-coding:utf-8-*-
from numpy import *
import math
import copy
import cPickle as pickle
class ID3DTree(object):
def __init__(self): #construct function
self.tree={}    #built tree
self.dataSet=[] #the dataSet
self.labels=[]  #the classes

def loadDataSet(self,path,labels):  #import data function
recordlist=[]
fp=open(path,"rb")
content=fp.read()
fp.close()
rowlist=content.splitlines()
recordlist=[row.split("\t") for row in rowlist if row.strip()]
self.dataSet=recordlist
self.labels=labels

def train(self):    #run Decision Tree function
labels=copy.deepcopy(self.labels)
self.tree=self.buildTree(self.dataSet,labels)

def buildTree(self,dataSet,labels): #building Decision Tree,the most important function
cateList=[data[-1] for data in dataSet] #遍历每行取最后一列,默认抽取最后一列特征用来做类别的判断
if cateList.count(cateList[0])==len(cateList):  #若子集类别只有一种,则返回该类别
return cateList[0]
if len(dataSet[0])==1:  #若子集只有一列,还需判断该列是否纯净,若不纯净则返回该子集中占比最大的类别
return self.maxCate(cateList)
#alogrithm core
bestFeat=self.getBestFeat(dataSet)  #获取数据集的最优特征列的下标
bestFeatLabel=labels[bestFeat]      #根据下标在labels里找到其对应的name
tree={bestFeatLabel:{}} #bestFeatLabel作为根节点 {'root':{0:'leaf node',1:{'level2':{0:'leaf node',1:'leaf node'}},2:{'level2':{0:'leaf node',1:'leaf node'}}}}
del(labels[bestFeat])
#抽取最优特征列向量
uniqueVals=set([data[bestFeat] for data in dataSet])    #去掉该特征列里的重复值
for value in uniqueVals:
subLabels=labels[:] #用删除上层特征列的特征集做子集的特征列集合
splitDataSet=self.splitDataSet(dataSet,bestFeat,value)
subTree=self.buildTree(splitDataSet,subLabels)  #递归构建子树
tree[bestFeatLabel][value]=subTree  #(回溯二层)tree['年龄']['青']=(回溯一层)tree['学生']['是']=(递归到底)tree['买'] 深度优先遍历
return tree

def maxCate(self,cateList): #当最后只剩一列特征列但类别仍不纯净时
items=dict([(cateList.count(i),i) for i in cateList])
return items[max(items.keys())]

def getBestFeat(self,dataSet):
#计算特征向量维度,其中最后一列用于类别标签(买,不买),因此要减去
numFeatures=len(dataSet[0])-1   #用于决策的特征数目
baseEntropy=self.computeEntropy(dataSet)    #计算未细分时,当前层数据集的熵
bestInfoGain=0.0    #初始化信息熵增益
bestFeature=-1      #初始化最优特征列
#遍历各特征列,计算信息熵,并计算信息熵增益
for i in xrange(numFeatures):
uniqueVals=set([data[i] for data in dataSet])   #保存特征列的值有几种
newEntropy=0.0  #记录按特征列划分后的子数据集的信息熵
for value in uniqueVals:
subDataSet=self.splitDataSet(dataSet,i,value)
prob=len(subDataSet)/float(len(dataSet))    #计算特征列里个value占当前dataSet的比率
newEntropy+=prob*self.computeEntropy(subDataSet)    #遍历计算特征为i,值为value的子集的熵,最后加权平均
infoGain=baseEntropy-newEntropy #保存按特征列i划分时的信息增益
if(infoGain>bestInfoGain):  #记录下最优特征列
bestInfoGain=infoGain
bestFeature=i
return bestFeature

def getBestFeat4_5(self,dataSet): #c4.5算法,按信息增益率选取特征,避免信息增益选择特征时偏向于选取特征值个数较多的情况(特征值越多更利于混乱的减少)
Num_Feats=len(dataSet[0][:-1])  #取dataSet第一行,下标从第一个到最后一个(不包括)的子数组求长度
totality=len(dataSet)
BaseEntropy=self.computeEntropy(dataSet)    #计算当前dataSet的熵
for f in xrange(Num_Feats):
featList=[feat[f] for feat in dataSet]  #遍历得到一列特征

def computeEntropy(self,dataSet):   #计算信息混乱度,越混乱越高
datalen=float(len(dataSet))
cateList=[data[-1] for data in dataSet] #提取要分类的特征列(买,不买)
items=dict([(i,cateList.count(i)) for i in cateList])
infoEntropy=0.0
for key in items:
prob=float(items[key])/datalen
infoEntropy-=prob*math.log(prob,2)
return infoEntropy

def splitDataSet(self,dataSet,axis,value):  #dataSet数据集,axis特征列下标,value特征列的取值之一
rtnList=[]
for featVec in dataSet:
if(featVec[axis]==value):
rFeatVec=featVec[:axis] #list操作,装入0~axis-1间的元素
rFeatVec.extend(featVec[axis+1:])   #list操作,装入axis+1,length-1间的元素
rtnList.append(rFeatVec)
return rtnList

#序列化函数
def storeTree(self,inputTree,filename):
fw=open(filename,'w')
pickle.dump(inputTree,fw)
fw.close()

def grabTree(self,filename):
fr=open(filename)
return pickle.load(fr)

#决策函数
def predict(self,inputTree,featLabels,testVec):
root=inputTree.keys()[0]    #得到树的根节点对应的特征名字
secondDict=inputTree[root]  #得到树根对应各个取值下的子树或节点(节点即最终分类)
featIndex=featLabels.index(root)    #得到特征名对应的label下标
key=testVec[featIndex]  #得到测试集在该特征的取值
valueOfFeat=secondDict[key] #根据key值取得子树或最终类别
if isinstance(valueOfFeat,dict):    #判断valueOfFeat是不是字典类,即使不是子树
classLabel=self.predict(valueOfFeat,featLabels,testVec) #递归分类,将子树作为树传下去
else: classLabel=valueOfFeat
return classLabel


#treePlotter.py
# -*- coding: utf-8 -*-
'''
Created on 2015年7月27日

@author: pcithhb
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

#获取叶节点的数目
def getNumLeafs(myTree):
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点
numLeafs += getNumLeafs(secondDict[key])
else:   numLeafs +=1
return numLeafs

#获取树的层数
def getTreeDepth(myTree):
maxDepth = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#测试节点的数据是否为字典,以此判断是否为叶节点
thisDepth = 1 + getTreeDepth(secondDict[key])
else:   thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth

#绘制节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )

#绘制连接线
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

#绘制树结构
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]     #the text label for this node should be this
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
plotTree(secondDict[key],cntrPt,str(key))        #recursion
else:   #it's a leaf node print the leaf node
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD

#创建决策树图形
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
#createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), '')
plt.show()


#MyID3_Main.py
#-*-coding:utf-8 -*-
from numpy import *
from MyID3 import *
import treePlotter as tp
dtree=ID3DTree()
dtree.loadDataSet("C:\Users\MCGG\Documents\python\dataset.dat",["age","revenue","student","credit"])
dtree.train()
tp.createPlot(dtree.tree)
print dtree.tree
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  算法 预测