
Machine Learning Notes 12: Classification and Regression Trees (CART)

2017-03-22 16:14
Updated: 2017-11-18

Simplified the wording to make it more accessible.









Ⅲ. Implementation

The implementation uses the datasets from Machine Learning in Action; the code itself has been rewritten from scratch according to my own understanding.

Data loading module: data.py

import numpy as np

def loadData(filename):
    dataSet = np.loadtxt(fname=filename, dtype=np.float32)
    return dataSet


NumPy's built-in function for reading text files is all that is needed here; it is quick and convenient, so I won't go into detail.
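For reference, a minimal sketch of how the loader is meant to be used, assuming (as in the Machine Learning in Action data files) a whitespace-separated text file with one sample per line, feature columns first and the target value in the last column:

import data

# ex00.txt: one sample per line, feature column(s) first, target value last
dataMat = data.loadData("../data/ex00.txt")
print(dataMat.shape)    # (number of samples, number of columns)
print(dataMat[:3])      # first three rows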

CART core module: CART.py

import numpy as np

# Split dataSet into two subsets by comparing one feature against a threshold value.
def splitDataSet(dataset, featureIndex, value):
    subDataSet0 = dataset[dataset[:, featureIndex] <= value, :]
    subDataSet1 = dataset[dataset[:, featureIndex] > value, :]
    return subDataSet0, subDataSet1

# Compute the regression error of a dataset: the sum of squared deviations of the
# target (last column) from its mean, i.e. variance * number of samples.
def getError(dataSet):
    error = np.var(dataSet[:, -1]) * dataSet.shape[0]
    return error

# Choose the best feature index and split value in dataSet.
def chooseBestSplit(dataSet, leastErrorDescent, leastNumOfSplit):
    rows, cols = np.shape(dataSet)

    # error of the whole dataSet before splitting
    Error = getError(dataSet)

    # initialize the values we want to find
    bestError = np.inf
    bestFeatureIndex = 0
    bestValue = 0

    # search: every feature index ...
    for featureIndex in range(cols - 1):
        # ... and every value this feature takes in dataSet
        for value in set(dataSet[:, featureIndex]):
            subDataSet0, subDataSet1 = splitDataSet(dataSet, featureIndex, value)

            # skip splits that leave too few samples on either side
            if (subDataSet0.shape[0] < leastNumOfSplit) or (subDataSet1.shape[0] < leastNumOfSplit):
                continue

            # total error after this candidate split
            tempError = getError(subDataSet0) + getError(subDataSet1)
            if tempError < bestError:
                bestError = tempError
                bestFeatureIndex = featureIndex
                bestValue = value

    # if the best split barely reduces the error, return a leaf (mean of the targets)
    if Error - bestError < leastErrorDescent:
        return None, np.mean(dataSet[:, -1])

    mat0, mat1 = splitDataSet(dataSet, bestFeatureIndex, bestValue)
    if (mat0.shape[0] < leastNumOfSplit) or (mat1.shape[0] < leastNumOfSplit):
        return None, np.mean(dataSet[:, -1])

    return bestFeatureIndex, bestValue

# Build the tree recursively.
def buildTree(dataSet, leastErrorDescent=1, leastNumOfSplit=4):
    bestFeatureIndex, bestValue = chooseBestSplit(dataSet, leastErrorDescent, leastNumOfSplit)

    # recursion terminates at a leaf: bestValue is then the leaf's constant prediction
    if bestFeatureIndex is None:
        return bestValue

    Tree = {}
    Tree["featureIndex"] = bestFeatureIndex
    Tree["value"] = bestValue

    # split the data and build the two subtrees recursively
    leftSet, rightSet = splitDataSet(dataSet, bestFeatureIndex, bestValue)
    Tree["left"] = buildTree(leftSet, leastErrorDescent, leastNumOfSplit)
    Tree["right"] = buildTree(rightSet, leastErrorDescent, leastNumOfSplit)

    return Tree

# An internal node is stored as a dict; a leaf is a plain number.
def isTree(tree):
    return type(tree).__name__ == 'dict'

# Walk down the tree with sample x until a leaf value is reached.
# Note: use <= here so that prediction routes samples the same way splitDataSet does.
def predict(tree, x):
    if x[tree["featureIndex"]] <= tree["value"]:
        if isTree(tree["left"]):
            return predict(tree["left"], x)
        else:
            return tree["left"]
    else:
        if isTree(tree["right"]):
            return predict(tree["right"], x)
        else:
            return tree["right"]


Now let's go through these functions one by one.

splitDataSet(dataset, featureIndex, value)

As explained in the theory part, splitting the dataset only requires two things: a feature and a threshold value.

This function takes a feature index and a threshold and splits the dataset into two parts. The figure from the example in the theory part illustrates exactly what it does.
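As a quick check, here is a minimal sketch of its behavior on a made-up four-sample dataset:

import numpy as np
import CART

# Illustrative dataset: feature in column 0, target in column 1.
toy = np.array([[0.2, 1.0],
                [0.4, 1.2],
                [0.6, 3.0],
                [0.8, 3.1]])

left, right = CART.splitDataSet(toy, featureIndex=0, value=0.4)
print(left)    # rows whose feature is <= 0.4
print(right)   # rows whose feature is  > 0.4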

getError(dataSet)

This function computes the error over a dataset. Calling it an error is really calling it a variance: in the formula from the theory part, the constant c can be replaced by the mean of the targets, so the error is exactly the sum of squared deviations of the targets from their mean, i.e. the variance of the targets times the number of samples.
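A quick sanity check, with made-up target values, that the function really is the sum of squared deviations from the mean:

import numpy as np
import CART

toy = np.array([[0.0, 1.0],
                [0.0, 2.0],
                [0.0, 3.0],
                [0.0, 6.0]])   # targets: 1, 2, 3, 6

c = np.mean(toy[:, -1])                # optimal constant c = 3.0
sse = np.sum((toy[:, -1] - c) ** 2)    # sum of squared errors = 14.0
print(sse, CART.getError(toy))         # both are 14.0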

chooseBestSplit(dataSet, leastErrorDescent, leastNumOfSplit)

As the name suggests, this function looks for the best split.

The parameter leastErrorDescent is the minimum required error reduction: if the error reduction achieved by the best candidate split is smaller than this value, the function gives up on splitting and returns (None, mean of the targets), which buildTree then stores as a leaf. leastNumOfSplit is the minimum number of samples a split must leave on each side: candidate splits that leave fewer samples than this on either side are skipped, and if no acceptable split exists the function likewise returns a leaf value.

The function iterates over every feature of the dataset and every value that feature takes, and returns the best feature index and split point it finds; see the sketch below for both stopping criteria in action.
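A minimal sketch, using a made-up eight-sample dataset, of how the two stopping criteria come into play:

import numpy as np
import CART

# Made-up data: the target jumps from about 0 to about 5 once the feature passes 0.5.
toy = np.array([[0.1, 0.00], [0.2, 0.10], [0.3, -0.10], [0.4, 0.05],
                [0.6, 5.00], [0.7, 5.10], [0.8, 4.90], [0.9, 5.05]])

# With a permissive minimum split size, the obvious split is found:
# feature 0 at value 0.4.
print(CART.chooseBestSplit(toy, leastErrorDescent=1, leastNumOfSplit=2))

# With leastNumOfSplit=5, no candidate split leaves 5 samples on both sides,
# so the function gives up and returns None plus the mean of the targets.
print(CART.chooseBestSplit(toy, leastErrorDescent=1, leastNumOfSplit=5))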

Test script: run.py

import numpy as np
import data
import CART

dataMat1=data.loadData("../data/ex00.txt")
dataMat2=data.loadData("../data/ex0.txt")

'''
print(dataMat.shape)
print(np.shape(dataMat))
e=CART.getError(dataMat)
print(e)
print(CART.getError(mat0))
print(CART.getError(mat1))

mat0,mat1=CART.splitDataSet(dataMat,0,0.5)
print(mat0)
print(mat1)
print(mat0.shape)
'''

#bestIndex,bestValue=CART.chooseBestSplit(dataMat)
#print(bestIndex,bestValue)

#tree1
tree1=CART.buildTree(dataMat1)
print(tree1)

#tree2
tree2=CART.buildTree(dataMat2)
print(tree2)

x=[1.0,0.559009]
print(CART.predict(tree2,x))


This is the driver code used to test the CART regression tree.
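Since the tree built on ex00.txt is a piecewise-constant function of a single feature, it is also easy to visualize. Here is a minimal sketch, assuming ex00.txt holds one feature column plus one target column and that the tree actually splits:

import numpy as np
import matplotlib.pyplot as plt
import data
import CART

dataMat = data.loadData("../data/ex00.txt")
tree = CART.buildTree(dataMat)

# Evaluate the tree on a dense grid to show the piecewise-constant fit.
xs = np.linspace(dataMat[:, 0].min(), dataMat[:, 0].max(), 200)
ys = [CART.predict(tree, [x]) for x in xs]

plt.scatter(dataMat[:, 0], dataMat[:, 1], s=10, label="samples")
plt.plot(xs, ys, color="red", label="CART prediction")
plt.legend()
plt.show()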



Result:
