
Implementing a Decision Tree with Continuous Attributes -- Based on the Watermelon 3.0 Dataset

2017-07-09 21:55
This post presents the continuous-attribute decision tree that I implemented while studying decision tree algorithms.

Language: Python; Dataset: Zhou Zhihua's watermelon dataset 3.0

Most of the code is the same as the discrete-attribute decision tree in the previous post; the main differences are listed below:
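For a continuous attribute a, the code uses the standard bi-partition strategy from the book: sort the n distinct values a^1 < a^2 < ... < a^n observed in D, take the midpoints (a^i + a^{i+1}) / 2 as candidate split points, and choose the point t that maximizes

Gain(D, a, t) = Ent(D) - |D_t^-|/|D| * Ent(D_t^-) - |D_t^+|/|D| * Ent(D_t^+)

where D_t^- and D_t^+ are the samples whose a-value falls below, respectively at or above, t. This is exactly what InfoGainContous below computes.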

# Maximum information gain for a continuous attribute (column k)
import numpy as np  # InfoEntCalc, InfoGain and majorCnt are defined in the previous post

def InfoGainContous(DatSet, Label, k):
    DatSetk = DatSet[:, k].astype(float)  # the column may arrive as strings
    nk = len(DatSetk)
    uniqueDatSetk = list(set(DatSetk))  # a set cannot be indexed, so convert to a list first
    uniquesortDatSetk = np.sort(uniqueDatSetk)
    n = len(uniquesortDatSetk)
    # candidate split points: midpoints of adjacent sorted unique values
    selectPoint = []
    for index in range(n - 1):
        selectPoint.append((uniquesortDatSetk[index] + uniquesortDatSetk[index + 1]) / 2.0)
    maxinfoEnt = 0.0
    bestPoint = -1
    for index in range(n - 1):
        Label0 = []  # labels of the samples below the split point
        Label1 = []  # labels of the samples at or above the split point
        for datindex in range(nk):
            if DatSetk[datindex] < selectPoint[index]:
                Label0.append(Label[datindex])
            else:
                Label1.append(Label[datindex])
        # weighted entropy after the split, then the information gain
        sumEnt = len(Label0) / (len(Label) * 1.0) * InfoEntCalc(Label0) \
               + len(Label1) / (len(Label) * 1.0) * InfoEntCalc(Label1)
        infoEnt = InfoEntCalc(Label) - sumEnt
        if infoEnt > maxinfoEnt:
            maxinfoEnt = infoEnt
            bestPoint = selectPoint[index]  # best split point found so far
    return maxinfoEnt, bestPoint
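
InfoGainContous relies on the entropy helper InfoEntCalc, which was defined in the previous post and is not repeated here. A minimal sketch consistent with how it is called (a list of class labels in, entropy in bits out), together with the majority-vote helper majorCnt used further below, could look like this:

import math
from collections import Counter

# Sketch of the entropy helper assumed above (the real one is in the previous post):
# given a list of class labels, return the Shannon entropy in bits.
def InfoEntCalc(Label):
    n = len(Label)
    if n == 0:
        return 0.0  # treat the entropy of an empty subset as 0
    ent = 0.0
    for count in Counter(Label).values():
        p = count / (n * 1.0)
        ent -= p * math.log(p, 2)
    return ent

# Sketch of the majority-vote helper used by TreeGenerate below.
def majorCnt(Label):
    return Counter(Label).most_common(1)[0][0]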

# Pick the attribute with the largest information gain.
# Convention in this post: the continuous attributes (density, sugar content)
# carry numeric names in Table, so float(name) succeeds exactly for them.
def MaxGain(DatSet, Label, Table):
    maxGain = -1
    bestFeature = -1
    bestPoint = -1
    for tab in Table:
        featureNum = list(Table).index(tab)
        try:
            float(tab)
        except ValueError:
            Gain = InfoGain(DatSet, Label, featureNum)  # discrete attribute
            Point = -1
        else:
            Gain, Point = InfoGainContous(DatSet, Label, featureNum)  # continuous attribute
        if Gain > maxGain:
            bestFeature = featureNum
            maxGain = Gain
            bestPoint = Point
    return bestFeature, bestPoint
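
The discrete-attribute helper InfoGain also comes from the previous post. For reference, a minimal sketch matching the call InfoGain(DatSet, Label, featureNum) above:

# Sketch of the discrete-attribute gain helper from the previous post:
# information gain of partitioning the samples by the values of column k.
def InfoGain(DatSet, Label, k):
    n = len(Label)
    sumEnt = 0.0
    for value in set(DatSet[:, k]):
        subLabel = [Label[i] for i in range(n) if DatSet[i, k] == value]
        sumEnt += len(subLabel) / (n * 1.0) * InfoEntCalc(subLabel)
    return InfoEntCalc(Label) - sumEnt

Note that MaxGain tells continuous from discrete attributes purely by whether the attribute's name in Table parses as a number, so the density and sugar-content columns must be given numeric names for this dispatch to work.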

# Build the basic decision tree recursively
def TreeGenerate(Dat, DatOri, Table):  # inputs are numpy arrays; the last column of Dat is the class label
    DatSet = Dat[:, :-1]   # feature columns
    Label = Dat[:, -1]     # class labels
    m, n = np.shape(DatSet)
    # Case 1: all samples in the subset belong to the same class
    if list(Label).count(Label[0]) == m:
        return Label[0]
    # Case 2: only one attribute column is left -- return the majority class
    # (recursing further would leave an empty attribute set)
    if n == 1:
        return majorCnt(Label)
    bestFeature, bestPoint = MaxGain(DatSet, Label, Table)  # index of the chosen attribute
    bestFeatureTable = Table[bestFeature]
    del Table[bestFeature]
    Tree = {bestFeatureTable: {}}
    try:
        int(bestFeatureTable)  # a numeric name marks density / sugar content (continuous)
    except ValueError:
        # Discrete attribute: one branch per value seen in the ORIGINAL data,
        # so that values missing from the current subset still get a branch
        for value in set(DatOri[:, bestFeature]):
            subDatSetR = Dat[Dat[:, bestFeature] == value]  # rows with bestFeature == value
            subDatSet = np.concatenate((subDatSetR[:, :bestFeature], subDatSetR[:, bestFeature + 1:]), axis=1)  # drop the used attribute
            subDatOri = np.concatenate((DatOri[:, :bestFeature], DatOri[:, bestFeature + 1:]), axis=1)
            subTable = Table[:]
            subm, subn = np.shape(subDatSet)
            if subm == 0:
                # empty subset: label the branch with the majority class of the parent
                Tree[bestFeatureTable][value] = majorCnt(Label)
            else:
                Tree[bestFeatureTable][value] = TreeGenerate(subDatSet, subDatOri, subTable)
    else:
        # Continuous attribute: binary split at bestPoint ('<' branch and '>=' branch)
        col = Dat[:, bestFeature].astype(float)
        for strval, mask in (('<' + str(bestPoint), col < bestPoint),
                             ('>=' + str(bestPoint), col >= bestPoint)):
            subDatSetR = Dat[mask]
            subDatSet = np.concatenate((subDatSetR[:, :bestFeature], subDatSetR[:, bestFeature + 1:]), axis=1)  # drop the used attribute
            subDatOri = np.concatenate((DatOri[:, :bestFeature], DatOri[:, bestFeature + 1:]), axis=1)
            subTable = Table[:]
            subm, subn = np.shape(subDatSet)
            if subm == 0:
                Tree[bestFeatureTable][strval] = majorCnt(Label)
            else:
                Tree[bestFeatureTable][strval] = TreeGenerate(subDatSet, subDatOri, subTable)

    return Tree
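
For completeness, here is a hypothetical driver showing how the pieces fit together. The column layout and the numeric names '6' and '7' for density and sugar content are assumptions following the conventions above; the three rows are the first samples of the watermelon 3.0 dataset, shown for illustration only:

# Hypothetical driver (layout assumed from the conventions above)
Table = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感', '6', '7']  # '6', '7': density, sugar content
Dat = np.array([
    ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.697, 0.460, '是'],
    ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 0.774, 0.376, '是'],
    ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.634, 0.264, '是'],
    # ... the remaining 14 rows of watermelon 3.0 go here ...
], dtype=object)  # object dtype keeps the floats as floats next to the strings
Tree = TreeGenerate(Dat, Dat.copy(), Table[:])  # pass copies: TreeGenerate mutates Table
print(Tree)
# With only these three same-class rows the tree degenerates to '是';
# feed in all 17 rows to get a real nested dict.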