您的位置:首页 > 编程语言

CART 回归树代码实现

2017-05-27 09:53 239 查看
from numpy import *
def loadData(fileName):
    """Load a tab-separated numeric text file into a NumPy matrix.

    Args:
        fileName: Path of a text file in which every non-blank line holds
            tab-separated fields that ``float`` can parse.

    Returns:
        A ``numpy.matrix`` with one row per non-blank line of the file.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field cannot be converted to ``float``.
    """
    retMat = []
    # The context manager closes the file even if a field fails to parse.
    with open(fileName) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing empty line) carry no sample.
                continue
            retMat.append([float(field) for field in stripped.split('\t')])
    # ``asmatrix`` is what the ``mat`` alias referred to; the alias itself was
    # removed in NumPy 2.0, while ``asmatrix`` exists in old and new releases.
    return asmatrix(retMat)
# Tree node record
class treeNode:
    """One node of the regression tree.

    Attributes:
        feat: Value given as ``feat``; ``createTree`` in this file passes the
            split feature index for internal nodes and ``None`` for leaves.
        value: Value given as ``value`` (split threshold or leaf value, as
            supplied by the caller).
        lc: Left child node, or ``None``.
        rc: Right child node, or ``None``.
    """

    def __init__(self, feat, value, leftChild=None, rightChild=None):
        self.feat, self.value = feat, value
        self.lc, self.rc = leftChild, rightChild

# Tree node; the tree is printed later using a preorder traversal


# Build the regression tree: lay out the overall framework first
def splitData(dataSet, feat, val):
    """Partition the rows of ``dataSet`` on a single feature threshold.

    Args:
        dataSet: 2-D sample data, one sample per row (a ``numpy.matrix`` or
            anything ``asmatrix`` accepts).
        feat: Column index of the feature to compare.
        val: Threshold, either a scalar or a 1x1 matrix.

    Returns:
        A tuple ``(mat0, mat1)`` of matrices: ``mat0`` holds the rows whose
        ``feat`` column is strictly greater than ``val``; ``mat1`` holds all
        other rows. Row order and all columns are preserved, and a side with
        no rows has shape ``(0, n_columns)``.
    """
    # Operate on the argument itself (not a module-level variable) so that
    # recursive calls split the subset they were actually given.
    dataSet = asmatrix(dataSet)
    # Flatten the (rows, 1) comparison result into a 1-D row mask.
    above = asarray(dataSet[:, feat] > val).ravel()
    return dataSet[above], dataSet[~above]

def createTree(dataSet, op=(1, 4)):
    """Recursively build a regression tree for ``dataSet``.

    Args:
        dataSet: Sample matrix handed to ``chooseBestSplit`` and ``splitData``.
        op: Stopping tolerances forwarded unchanged to ``chooseBestSplit``;
            the default matches that function's own default.

    Returns:
        A ``treeNode``. When ``chooseBestSplit`` reports no split
        (``feat is None``) the node is a leaf ``treeNode(None, val)``;
        otherwise it stores the chosen feature and value and has subtrees
        built from the two halves returned by ``splitData``.
    """
    # Find the best split feature and split point.
    feat, val = chooseBestSplit(dataSet, op)
    if feat is None:
        return treeNode(None, val)

    leftMat, rightMat = splitData(dataSet, feat, val)
    node = treeNode(feat, val)
    node.lc = createTree(leftMat, op)
    node.rc = createTree(rightMat, op)
    return node

def calcValue(dataSet):
    """Return the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return mean(targets)
def calcError(dataSet):
    """Return the variance of the last column times the number of rows.

    This equals the total squared deviation of the last column from its mean.
    """
    row_count = shape(dataSet)[0]
    spread = var(dataSet[:, -1])
    return spread * row_count

# Tree nodes should store the mean value


def chooseBestSplit(dataSet, op=(1, 4)):
    """Pick the feature/threshold pair that most reduces the total error.

    Args:
        dataSet: Sample matrix; the last column is the target and every
            other column is a candidate feature.
        op: ``(tolS, tolN)``. ``tolS`` is the minimum error reduction a split
            must achieve; ``tolN`` is the minimum number of rows required on
            each side of a split.

    Returns:
        ``(featIndex, splitValue)`` for the best split, with ``splitValue`` a
        plain float, or ``(None, leafValue)`` when no split should be made
        (all targets equal, no candidate satisfies ``tolN``, or the error
        reduction is below ``tolS``). ``leafValue`` is ``calcValue(dataSet)``.
    """
    tolS = op[0]
    tolN = op[1]
    n = shape(dataSet)[1]

    # Nothing to split if every target value is already identical.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, calcValue(dataSet)

    S = calcError(dataSet)
    bestS = inf
    bestIndex = None
    bestValue = None
    for featIndex in range(n - 1):
        # Scalar, de-duplicated candidates in first-occurrence order: avoids
        # re-evaluating repeated values and keeps the threshold a float
        # instead of a 1x1 matrix, without changing tie-breaking order.
        candidates = dict.fromkeys(dataSet[:, featIndex].T.tolist()[0])
        for splitVal in candidates:
            mat0, mat1 = splitData(dataSet, featIndex, splitVal)
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = calcError(mat0) + calcError(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS

    # ``bestIndex is None`` means no candidate met the tolN size limit.
    # Any accepted candidate already satisfied tolN inside the loop, so the
    # size check does not need to be repeated on the winning split.
    if bestIndex is None or S - bestS < tolS:
        return None, calcValue(dataSet)
    return bestIndex, bestValue

#core code


# Script entry: load the tab-separated samples in 'ex00.txt' and fit a
# regression tree to them; both results stay bound at module level.
data=loadData('ex00.txt')
node=createTree(data)
# Preorder traversal of the tree
def tree(node):
    """Print the subtree rooted at ``node`` in preorder.

    Each visited node is printed as ``feat : value`` on its own line, the
    node first, then its ``lc`` subtree, then its ``rc`` subtree.

    Args:
        node: An object with ``feat``, ``value``, ``lc`` and ``rc``
            attributes, or ``None`` for an empty subtree (prints nothing).
    """
    # Identity test: ``None`` is a singleton, and ``is`` cannot be affected
    # by a node type that customises equality.
    if node is None:
        return
    print(node.feat, ":", node.value)
    tree(node.lc)
    tree(node.rc)
# Print the fitted tree, one "feat : value" line per node.
tree(node)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: