用python实现c4.5算法,并进行悲观剪枝
2015-12-21 13:43
579 查看
# -*- coding: utf-8 -*-
"""C4.5 decision tree with pessimistic pruning (bottom-up and top-down).

Ported from Python 2 to Python 3 (print(), dict views, binary pickle).
Third-party imports are optional so the pure-algorithm helpers stay
usable when xlrd / xlwt / sklearn / matplotlib are not installed.
"""
import math
import operator

# Optional third-party dependencies: only the data loaders / plotting
# need them, so a missing package must not break the core algorithm.
try:
    import xlrd                      # Excel reader used by createDataSet()
except ImportError:
    xlrd = None
try:
    import xlwt                      # kept from the original file (unused here)
except ImportError:
    xlwt = None
try:
    from sklearn import datasets     # iris loader used by createDataSet_iris()
except ImportError:
    datasets = None
try:
    import matplotlib.pyplot as plt  # tree plotting
except ImportError:
    plt = None


def calcShannonEnt(dataSet):
    """Return the Shannon entropy of the class labels (last column).

    dataSet: list of rows; row[-1] is the class label.
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)  # log base 2
    return shannonEnt


def loaddata():
    """Load a space-separated feature file and its label-name file.

    Returns (dataSet, labels); each line of train_feature.txt loses its
    trailing two characters (original convention, e.g. '\\r\\n').
    """
    dataSet = []
    with open("train_feature.txt") as f:
        for line in f.readlines():
            dataSet.append(line[:-2].split(' '))
    with open('400words.txt') as f:
        labels = f.read().split(' ')[:-1]
    return dataSet, labels


def createDataSet_iris():
    """Build (dataSet, labels) from sklearn's iris dataset.

    Each row is the four features followed by the integer target.
    """
    if datasets is None:
        raise ImportError("sklearn is required for createDataSet_iris()")
    iris = datasets.load_iris()
    dataSet = [list(row) for row in iris.data]
    for index, target in enumerate(iris.target):
        dataSet[index].append(target)
    labels = ['a', 'b', 'c', 'd']
    return dataSet, labels


def createDataSet():
    """Read dataset3.xlsx and binarise every score: '1' if >= 75 else '0'.

    Returns (dataSet, features): one row of '0'/'1' strings per record
    (last column doubles as the class), and the course names from the
    header row with its first and last cells dropped.
    """
    if xlrd is None:
        raise ImportError("xlrd is required for createDataSet()")
    data = xlrd.open_workbook("dataset3.xlsx")
    from_sheet = data.sheet_by_index(0)
    id_array = from_sheet.row_values(0)
    del id_array[0]
    del id_array[-1]
    features = id_array
    # One row per data record; the original hard-coded 217 here.
    dataSet = [[] for _ in range(from_sheet.nrows - 1)]
    for row_idx in range(1, from_sheet.nrows):
        temp = from_sheet.row_values(row_idx)
        for col_idx in range(1, from_sheet.ncols):
            # BUG FIX: the original compared the (numeric) cell to the
            # *string* '75'; in Python 2 numbers always order below
            # strings, so every cell became '0'. Compare numerically.
            if float(temp[col_idx]) >= 75:
                dataSet[row_idx - 1].append('1')
            else:
                dataSet[row_idx - 1].append('0')
    return dataSet, features


def splitDataSet(dataSet, axis, value):
    """Return the rows whose column `axis` equals `value`, column removed."""
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest gain ratio (C4.5).

    Returns -1 when no feature improves on a zero gain ratio.
    """
    numFeatures = len(dataSet[0]) - 1   # last column is the class
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGainrate = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featureSet = set(example[i] for example in dataSet)
        newEntropy = 0.0
        splitinfo = 0.0
        for value in featureSet:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
            splitinfo -= prob * math.log(prob, 2)
        if not splitinfo:
            # Single-valued feature => split info is 0; keep the original
            # author's small-entropy placeholder to avoid dividing by zero.
            splitinfo = -0.99 * math.log(0.99, 2) - 0.01 * math.log(0.01, 2)
        infoGain = baseEntropy - newEntropy
        infoGainrate = float(infoGain) / float(splitinfo)
        if infoGainrate > bestInfoGainrate:
            bestInfoGainrate = infoGainrate
            bestFeature = i
    return bestFeature


def createTree(dataSet, labels):
    """Recursively build the decision tree as nested dicts.

    A leaf is a bare class label; an internal node is
    {feature_name: {feature_value: subtree_or_leaf, ...}}.
    NOTE: mutates `labels` (the chosen feature name is deleted) — pass a
    copy if the caller still needs the full list.
    """
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]            # pure node -> leaf
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)  # features exhausted -> majority vote
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]   # copy so sibling branches stay independent
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


def majorityCnt(classList):
    """Return the most common class label (majority vote)."""
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def classify(inputTree, featLabels, testVec):
    """Classify `testVec` by walking the tree.

    Returns the predicted label, or None when the tree has no branch for
    the observed feature value (the original raised UnboundLocalError).
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = None
    for key in secondDict:
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


def storeTree(inputTree, filename):
    """Pickle the tree to disk (binary mode is required on Python 3)."""
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load a tree previously saved by storeTree()."""
    import pickle
    with open(filename, 'rb') as fr:  # original leaked this handle
        return pickle.load(fr)


# ---------------------------------------------------------------------
# matplotlib tree plotting
# ---------------------------------------------------------------------
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one annotated node with an arrow from its parent point."""
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


def createPlot():
    """Demo: draw one decision node and one leaf node.

    NOTE: shadowed by the later createPlot(inTree) definition; kept only
    because the original file defined both.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False)
    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()


def getNumLeafs(myTree):
    """Count the leaves of the nested-dict tree."""
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """Return the number of split levels along the deepest path."""
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """Label the edge between a node and its parent with the branch value."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center",
                        rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively lay out and draw the (sub)tree below parentPt.

    Uses function attributes (totalW/totalD/xOff/yOff) set by
    createPlot(inTree) as drawing cursors.
    """
    numLeafs = getNumLeafs(myTree)   # x-width of this subtree
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))  # recurse
        else:
            # Leaf: advance the x cursor and draw it.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Plot the whole decision tree (shadows the demo createPlot above)."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # no ticks
    createPlot.ax1.set_title("c4.5\n", size=24)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


# ---------------------------------------------------------------------
# Pessimistic pruning
# ---------------------------------------------------------------------
def getCount(inputTree, dataSet, featLabels, count):
    """Collect per-leaf statistics of the tree on `dataSet`.

    Appends one [right, wrong, leaf_class] triple to `count` for every
    leaf, where right/wrong are the numbers of rows the leaf classifies
    correctly/incorrectly. `count` is mutated in place.
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict:
        rightcount = 0
        wrongcount = 0
        tempfeatLabels = featLabels[:]
        subDataSet = splitDataSet(dataSet, featIndex, key)
        tempfeatLabels.remove(firstStr)
        if isinstance(secondDict[key], dict):
            getCount(secondDict[key], subDataSet, tempfeatLabels, count)
        else:
            for eachdata in subDataSet:
                if str(eachdata[-1]) == str(secondDict[key]):
                    rightcount += 1
                else:
                    wrongcount += 1
            count.append([rightcount, wrongcount, secondDict[key]])


def cutBranch_downtoup(inputTree, dataSet, featLabels, count):
    """Pessimistic pruning, bottom-up variant (continuity correction 0.5).

    For each internal child: gather its leaf statistics, compare the
    subtree's corrected error rate against the error of collapsing it to
    a single leaf, and collapse when the new error lies within one
    standard error of the old. The surviving leaf statistics are pushed
    up through `count` (mutated in place).
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in list(secondDict.keys()):   # values may be replaced below
        if isinstance(secondDict[key], dict):
            tempcount = []   # leaf statistics of this subtree
            tempfeatLabels = featLabels[:]
            subDataSet = splitDataSet(dataSet, featIndex, key)
            tempfeatLabels.remove(firstStr)
            getCount(secondDict[key], subDataSet, tempfeatLabels, tempcount)
            # Subtree error rate with 0.5 correction per leaf.
            tempnum = 0.0
            wrongnum = 0.0
            for var in tempcount:
                tempnum += var[0] + var[1]
                wrongnum += var[1]
            old = float(wrongnum + 0.5 * len(tempcount)) / float(tempnum)
            # NOTE(review): this is the standard error of the error
            # *count* (sqrt(N*p*(1-p))) while old/new are *rates*; kept
            # as in the original — confirm against the PEP formula.
            standwrong = math.sqrt(tempnum * old * (1 - old))
            # Error rate if the subtree were collapsed to one leaf.
            new = float(wrongnum + 0.5) / float(tempnum)
            if old - standwrong <= new <= old + standwrong:
                # Prune: the class of the leaf with the lowest corrected
                # error rate becomes the class of the new leaf.
                wrongtemp = 1.0
                newtype = -1
                for var in tempcount:
                    rate = float(var[1] + 0.5) / float(var[0] + var[1])
                    if rate < wrongtemp:
                        wrongtemp = rate
                        newtype = var[-1]
                secondDict[key] = str(newtype)
                # Only cleared when pruning happened; otherwise the old
                # leaf statistics must survive and propagate upward.
                tempcount = []
            for var in tempcount:
                count.append(var)
    # Finally account for this node's direct (possibly newly created)
    # leaves.
    for key in secondDict:
        if isinstance(secondDict[key], dict):
            continue
        rightcount = 0
        wrongcount = 0
        subDataSet = splitDataSet(dataSet, featIndex, key)
        for eachdata in subDataSet:
            if str(eachdata[-1]) == str(secondDict[key]):
                rightcount += 1
            else:
                wrongcount += 1
        # Last element is the leaf's class.
        count.append([rightcount, wrongcount, secondDict[key]])


def cutBranch_uptodown(inputTree, dataSet, featLabels):
    """Pessimistic pruning, top-down variant (continuity correction 0.5).

    NOTE(review): as in the original, this only examines the direct
    children of the root node — it does not recurse into subtrees that
    survive the pruning test.
    """
    firstStr = next(iter(inputTree))
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in list(secondDict.keys()):
        if isinstance(secondDict[key], dict):
            tempfeatLabels = featLabels[:]
            subDataSet = splitDataSet(dataSet, featIndex, key)
            tempfeatLabels.remove(firstStr)
            tempcount = []
            getCount(secondDict[key], subDataSet, tempfeatLabels, tempcount)
            print(tempcount)
            # Same pruning test as cutBranch_downtoup (see notes there).
            tempnum = 0.0
            wrongnum = 0.0
            for var in tempcount:
                tempnum += var[0] + var[1]
                wrongnum += var[1]
            old = float(wrongnum + 0.5 * len(tempcount)) / float(tempnum)
            standwrong = math.sqrt(tempnum * old * (1 - old))
            new = float(wrongnum + 0.5) / float(tempnum)
            if old - standwrong <= new <= old + standwrong:
                # Collapse to the class of the best (lowest-error) leaf.
                wrongtemp = 1.0
                newtype = -1
                for var in tempcount:
                    rate = float(var[1] + 0.5) / float(var[0] + var[1])
                    if rate < wrongtemp:
                        wrongtemp = rate
                        newtype = var[-1]
                secondDict[key] = str(newtype)


if __name__ == '__main__':
    # dataset, features = createDataSet()
    dataset, features = createDataSet_iris()
    # dataset, features = loaddata()
    print(features)
    # Real copy: plain assignment would alias the list, and createTree
    # mutates its labels argument.
    features2 = features[:]
    tree = createTree(dataset, features)
    print(tree)
    # print(classify(tree, features2, [0, 1, 1, 1, 0]))
    createPlot(tree)
    count = []
    # getCount(tree, dataset, features2, count)
    # cutBranch_uptodown(tree, dataset, features2)
    cutBranch_downtoup(tree, dataset, features2, count)
    createPlot(tree)   # plot again to see the pruned tree
在http://blog.csdn.net/Lu597203933/article/details/38024239这篇文章的基础上写的,感谢博主~
实现了两种悲观剪枝方法,一种自底向上,另一种自顶向下。不知道自底向上的方法会不会有问题,因为,如果发生剪枝,设置新的叶子结点后,再向上剪枝,就会失去原决策树的叶子结点的分类信息。
相关文章推荐
- Python获取当地时间---转
- 零基础入门学习Python(16):对象(1)
- Web Scraping with Python 学习笔记7
- python 时间戳 datetime string 转换
- python学习---3
- 【详解】Python带历史数据写Excel表
- 查找算法—斐波那契查找算法分析与实现(Python)
- 利用Python写的xls转sqlite3的导表工具
- Operations on basic server (by Python)
- 【Python】Python读取文件夹下的所有文件
- python中文编码
- python 代码中加入日志系统
- Python --标准库 存储对象 (pickle包,cPickle包)
- Python Requests和urllib两种方式下载网页到本地
- 微信公众帐号自定义菜单创建的python脚本
- 测试python代码在博客的显示
- Python -- 标准库 文件管理 (部分os包,shutil包)
- Python Network Programming(5)---套接字超时时间
- Python Bayes
- 近期用到的python知识列表