CART-分类回归树
2017-02-07 16:00
274 查看
CART简述
cart,分类和回归树算法。cart既可以用来构建分类决策树,也可以用来构建回归树、模型树。
用树对数据建模,把叶子节点简单设定为常数值,构成回归树。如果把叶子节点设定为分段线性函数,即构成模型树。
cart创建分类决策树使用当前数据集中具有最小Gini信息增益的特征作为结点划分决策树。详述可见决策树一节的描述。
回归树,与分类决策树类似,但叶子节点数据类型不是离散型,而是连续型。cart用于回归时,根据叶子是具体值还是连续的机器学习模型又可以分为回归树和模型树。但是无论是回归树还是模型树,其适用场景都是:标签值是连续分布的,但又是可以划分群落的,群落之间是有比较鲜明的区别的,即每个群落内部是相似的连续分布,群落之间分布确是不同的。所以回归树和模型树既算回归,也称得上分类。
cart使用二元切分法来处理连续变量。所以可以固定树的节点,每个节点由4个固定属性:待切分的feature、待切分的feature value、右子树、左子树。
创建树函数createTree():
注:createTree考虑类别>=3时候的代码可参考博客:http://blog.csdn.net/wzmsltw/article/details/51057311
中createTree函数。
*找到最佳的待切分feature、value:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树执行createTree方法
在左子树执行createTree方法*
def createTree(dataSet, leafType, errType, cond=(1,4)): ''' 创建回归树/模型树。 :param dataSet: :param leafType: :param errType: :param cond: 预剪枝条件 :return: ''' feature, value = chooseBestSplit(dataSet, leafType, errType, cond) if feature == None: return value retTree = {} retTree['spInd'] = feature retTree['spVal'] = value lSet,rSet = binSplitDataSet(dataSet, feature, value) retTree['left'] = createTree(lSet, leafType, errType, cond) retTree['right'] = createTree(rSet, leafType, errType, cond) return retTree def binSplitDataSet(dataSet, feature, value): ''' 根据属性feature的特定value划分数据集 :param dataSet: :param feature: :param value: :return: ''' mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:] mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:] return mat0,mat1
对数据的复杂关系建模,我们已决定借用树结构帮助切分数据,那么如何实现数据的切分呢?怎么才知道是否已经切分充分了呢? 这些问题的答案取决于叶节点的建模方式。 回归树假设叶节点是常数值,这种策略认为数据中的复杂关系可以用树结构来概括。 为成功构建以分段常数为叶节点的树,需要度量出数据的一致性。决策树使用树进行分类,会在给定节点时计算数据的混乱度。那么如何计算连续性数值的混乱度?事实上,在数据集上计算混乱度很简单,使用平方误差的总值即总方差。总方差可以通过方差乘以数据中样本点的个数得到。
选择最佳分裂属性的伪代码:
对每个特征:
对每个特征值:
将数据集切分成2份
计算切分的误差
如果当前误差小于最小误差,将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
chooseBestSplit的代码:
def chooseBestSplit(dataSet, leafType, errType, cond=(1,4)): ''' 选择最佳待切分feature以及value :param dataSet: :param leafType:叶子节点的构建方法 :param errType: 总均方差计算方法 :param cond: :return: ''' tolS = cond[0] tolN = cond[1] # dataSet中都属于同一类别,直接返回 if len(set(dataSet[:,-1].T.tolist()[0])) == 1: return None, leafType(dataSet) m,n = shape(dataSet) S = errType(dataSet) bestS = inf; bestIndex = 0; bestValue = 0 for featIndex in range(n-1): for splitVal in set(dataSet[:,featIndex].T.A[0]): mat0,mat1 = binSplitDataSet(dataSet, featIndex, splitVal) if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue; newS = errType(mat0) + errType(mat1) if newS < bestS: bestIndex = featIndex bestValue = splitVal bestS = newS if (S - bestS) < tolS: return None, leafType(dataSet) return bestIndex, bestValue
对于回归树,leafType和errType的方法为:
#取均值 def regLeaf(dataSet): return mean(dataSet[:, -1]) #总方差和 def regErr(dataSet): return var(dataSet[:, -1]) * shape(dataSet)[0]
对于回归树,叶子节点要是分段线性函数,为了找到最佳切分,对于给定的数据集,应该先用线性模型来进行拟合,然后计算真实的目标值与模型预测值间的差值,然后求这些差值的平方和即得到计算误差。
leafType和errType的方法为:
def linearSolve(dataSet): ''' 对dataSet进行线性拟合 :param dataSet: :return: ''' m,n = shape(dataSet) X = mat(ones((m,n))); Y = mat(ones((m,1))) X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:, -1] xTx = X.T * X if linalg.det(xTx) == 0.0: raise NameError('This matrix is singular, cannot do inverse,\n\ try increasing the second value of ops') ws = xTx.I * X.T * Y return ws,X,Y def modelLeaf(dataSet): ws,X,Y = linearSolve(dataSet) return ws def modelErr(dataSet): ws, X, Y = linearSolve(dataSet) yHat = X * ws return sum(power(Y - yHat, 2))
剪枝
剪枝有预剪枝和后剪枝。预剪枝是根据一些原则及早的停止树增长,如树的深度达到用户所要的深度、节点中样本个数少于用户指定个数、不纯度指标下降的最大幅度小于用户指定的幅度等。
后剪枝则是通过在完全生长的树上剪去分枝实现的,通过删除节点的分支来剪去树节点,可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。
chooseBestSplit函数中cond=(1,4)即为预剪枝的条件:
第一个元素表示划分前后方差和之差的阈值,第二个是基于某个特征和value划分后的数据集中样本个数阈值。
后剪枝策略
REP(错误率降低剪枝)该方法考虑将树上每个节点作为修剪的候选对象,决定是否修剪这个节点的步骤:
1、删除以此节点为根的子树
2、使其成为叶子节点
3、赋予该结点关联的训练数据的最常见分类
4、当修剪后的树对于验证集合的性能不会比原来的树差时,才真正删除该结点。
prune伪代码:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算当前两个叶子节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,将叶子节点合并。
代码如下:
def isTree(obj): return (type(obj).__name__ == 'dict') def getMean(tree): if isTree(tree['right']): tree['right'] = getMean(tree['right']) if isTree(tree['left']): tree['left'] = getMean(tree['left']) return (tree['left'] + tree['right'])/2.0 def prune(tree,testData): ''' 降低错误率剪枝 :param tree: :param testData: :return: ''' if shape(testData)[0] == 0: return getMean(tree) lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) if (isTree(tree['left'])): tree['left'] = prune(tree['left'], lSet) if (isTree(tree['right'])): tree['right'] = prune(tree['right'], rSet) # 如果当前节点的left和right节点都是叶子节点 if (not isTree(tree['right'])) and (not isTree(tree['left'])): errorNoMerge = sum(power(lSet[:,-1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2)) treeMean = (tree['left'] + tree['right'])/2.0 errorMerge = sum(power(testData[:, -1] - treeMean, 2)) # 如果合并后的误差比不合并的误差小,则返回合并后的叶子节点 if errorMerge < errorNoMerge: print("merging") return treeMean else: return tree else: return tree
PEP(悲观错误剪枝)
c4.5悲观剪枝代码可参考:
http://blog.csdn.net/o1101574955/article/details/50371499
CCP(代价复杂度剪枝)
PEP和CCP剪枝参考以下博客的剪枝部分介绍:
http://blog.csdn.net/u011067360/article/details/24871801
http://www.cnblogs.com/yonghao/p/5064996.html
回归树和模型树构造完成后,如何使用树模型进行数据预测呢?
思路如下:
从树根节点开始,根据属性feature和feature value判断沿左子树还是右子树向下走。直到叶子节点,如果是回归树直接叶子节点的值即为预测值,如果是回归树,使用叶子节点的线性模型计算得到预测值。
代码如下:
# 回归树预测值函数 def regTreeEval(model, inData): return float(model) # 模型树预测值函数 def modelTreeEval(model, inData): n = shape(inData)[1] X = mat(ones((1, n+1))) X[:,1:n+1] = inData return float(X * model) def treeForecast(tree, inData, modelEval=regTreeEval): ''' 树预测函数 :param tree: :param inData: :param modelEval: :return: ''' if not isTree(tree): return modelEval(tree, inData) if inData[tree['spInd']] > tree['spVal']: if isTree(tree['left']): return treeForecast(tree['left'], inData, modelEval) else: return modelEval(tree['left'], inData) else: if isTree(tree['right']): return treeForecast(tree['right'], inData, modelEval) else: return modelEval(tree['right'], inData) # 预测值得一个工具函数 def createForecast(tree, testData, modelEval=regTreeEval): m = len(testData) yHat = mat(zeros((m,1))) for i in range(m): yHat[i,0] = treeForecast(tree, mat(testData[i]), modelEval) return yHat
参考:
http://blog.csdn.net/wzmsltw/article/details/51057311
完整代码见github:
https://github.com/zhanggw/algorithm/tree/master/machine-learning/CART/cart.py
相关文章推荐
- 浅析时钟向量算法
- 书评:《算法之美( Algorithms to Live By )》
- 动易2006序列号破解算法公布
- C#递归算法之分而治之策略
- Ruby实现的矩阵连乘算法
- C#插入法排序算法实例分析
- C#算法之大牛生小牛的问题高效解决方法
- C#算法函数:获取一个字符串中的最大长度的数字
- 超大数据量存储常用数据库分表分库算法总结
- C#数据结构与算法揭秘二
- C#冒泡法排序算法实例分析
- 算法练习之从String.indexOf的模拟实现开始
- C#算法之关于大牛生小牛的问题
- C#实现的算24点游戏算法实例分析
- 经典排序算法之冒泡排序(Bubble sort)代码
- Android数据加密之异或加密算法的实现方法
- c语言实现的带通配符匹配算法
- 浅析STL中的常用算法
- 算法之排列算法与组合算法详解
- C++实现一维向量旋转算法