
Spark 1.2.0 MLlib Source Code --- Decision Trees - 01

2015-01-26 19:50
Decision trees come in two kinds: classification trees and regression trees, used to build classification models and regression models respectively.

First, let's look at a usage example in Spark. The code is as follows:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]()  // an empty Map means all features are treated as continuous
val impurity = "gini"  // use the Gini index as the node impurity measure
val maxDepth = 5  // maximum depth of the tree
val maxBins = 32  // maximum number of bins per feature when searching for splits (mainly relevant for continuous features)

// Train the classification model.
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)
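
Instead of computing the error rate by hand, the (label, prediction) pairs can also be fed to MLlib's MulticlassMetrics helper. A minimal sketch; note that MulticlassMetrics expects (prediction, label) ordering, hence the swap:

import org.apache.spark.mllib.evaluation.MulticlassMetrics

val metrics = new MulticlassMetrics(labelAndPreds.map(_.swap))  // swap to (prediction, label)
println("Accuracy = " + metrics.precision)  // micro-averaged precision, i.e. overall accuracy
println("Confusion matrix:\n" + metrics.confusionMatrix)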


To build a regression model, the code is as follows:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file.
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
//  Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance"  // use variance as the node impurity measure
val maxDepth = 5
val maxBins = 32

// Train the regression model.
val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)
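
The hand-computed MSE can likewise be cross-checked with RegressionMetrics, which is new in Spark 1.2. A small sketch; its constructor expects (prediction, observation) pairs, hence the swap:

import org.apache.spark.mllib.evaluation.RegressionMetrics

val metrics = new RegressionMetrics(labelsAndPredictions.map(_.swap))  // (prediction, label) order
println("MSE  = " + metrics.meanSquaredError)
println("RMSE = " + metrics.rootMeanSquaredError)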


Spark's decision tree implementation spans quite a lot of code, so the important parts will be covered over several posts. This one follows only the main code path.

(Starting from the classification example above.) Tracing into DecisionTree.trainClassifier(), the code is as follows:

def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    impurity: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {
  val impurityType = Impurities.fromString(impurity)
  train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
    categoricalFeaturesInfo)
}
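
Impurities.fromString() simply maps the impurity name onto the corresponding singleton object. Paraphrased from the MLlib source, it is essentially:

def fromString(name: String): Impurity = name match {
  case "gini" => Gini
  case "entropy" => Entropy
  case "variance" => Variance
  case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
}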


Continuing to trace into train():

def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClasses: Int,
    maxBins: Int,
    quantileCalculationStrategy: QuantileStrategy,  // defaults to the Sort strategy: feature values are sorted to compute split candidates
    categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
  val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
    quantileCalculationStrategy, categoricalFeaturesInfo)
  new DecisionTree(strategy).run(input)
}
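
So trainClassifier() is only a thin wrapper that assembles a Strategy and delegates. The same training run can be expressed directly against the Strategy-based API; a rough sketch of the equivalent call:

import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

// Equivalent to the trainClassifier() call in the first example
val strategy = new Strategy(Algo.Classification, Gini, maxDepth = 5,
  numClasses = 2, maxBins = 32, categoricalFeaturesInfo = Map[Int, Int]())
val model = DecisionTree.train(trainingData, strategy)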

Tracing into the run() method:

def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
  // Note: random seed will not be used since numTrees = 1.
  // Under the hood this reuses the RandomForest implementation: a single tree
  // by default, with all features considered at every split.
  val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
  val rfModel = rf.run(input)
  rfModel.trees(0)
}
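
Since a plain decision tree is literally a one-tree random forest trained with all features per node, the same model could equivalently be produced through the forest API. A sketch for illustration:

import org.apache.spark.mllib.tree.RandomForest

// A 1-tree "forest" using all features per node is equivalent to a decision tree
val forestModel = RandomForest.trainClassifier(trainingData, numClasses,
  categoricalFeaturesInfo, numTrees = 1, featureSubsetStrategy = "all",
  impurity = "gini", maxDepth = 5, maxBins = 32, seed = 0)
val treeModel = forestModel.trees(0)  // the single tree inside the forest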
Next, look at the implementation of rf.run():

def run(input: RDD[LabeledPoint]): RandomForestModel = {

  val timer = new TimeTracker()

  timer.start("total")

  timer.start("init")

  val retaggedInput = input.retag(classOf[LabeledPoint])
  // Build the decision tree metadata (split candidates, number of bins per
  // feature, etc.); each bin covers a range of a feature's values.
  val metadata =
    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  logDebug("algo = " + strategy.algo)
  logDebug("numTrees = " + numTrees)
  logDebug("seed = " + seed)
  logDebug("maxBins = " + metadata.maxBins)
  logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
  logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)

  // Find the splits and the corresponding bins (interval between the splits) using a sample
  // of the input data.
  timer.start("findSplitsBins")
  val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)  // compute split candidates and bins
  timer.stop("findSplitsBins")
  logDebug("numBins: feature: number of bins")
  logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
  }.mkString("\n"))

  // Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  // Every sample's feature values are mapped to the bins defined by the splits.
  val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

  val (subsample, withReplacement) = {
    // TODO: Have a stricter check for RF in the strategy
    val isRandomForest = numTrees > 1
    if (isRandomForest) {  // random forest: multiple trees
      (1.0, true)  // true means sampling with replacement
    } else {
      (strategy.subsamplingRate, false)
    }
  }
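
  // Note: with numTrees == 1, strategy.subsamplingRate defaults to 1.0 and
  // sampling is without replacement, so the bagged input below is effectively
  // the full dataset with unit weights.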

  // Wrap each point as a BaggedPoint; for a random forest, every tree gets its
  // own (re)sample of the data.
  val baggedInput
    = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
      .persist(StorageLevel.MEMORY_AND_DISK)

  // depth of the decision tree
  val maxDepth = strategy.maxDepth
  require(maxDepth <= 30,  // the tree depth must not exceed 30
    s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

  // Max memory usage for aggregates
  // TODO: Calculate memory usage more precisely.
  val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L  // memory budget for the aggregation step
  logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
  val maxMemoryPerNode = {
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {  // is a feature subset used per node?
      // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
      Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
        .take(metadata.numFeaturesPerNode).map(_._2))
    } else {
      None
    }
    RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L  // memory needed per node during aggregation
  }
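
  // Rough sizing, per the MLlib source: aggregateSizeForNode() returns the
  // number of Double statistics kept per node (the sum of numBins over the
  // node's features, times statsSize: numClasses for classification, 3 for
  // regression), hence the factor of 8 bytes per Double.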
  require(maxMemoryPerNode <= maxMemoryUsage,
    s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
    " which is too small for the given features." +
    s"  Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

  timer.stop("init")

  /*
   * The main idea here is to perform group-wise training of the decision tree nodes thus
   * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
   * Each data sample is handled by a particular node (or it reaches a leaf and is not used
   * in lower levels).
   */

  // Create an RDD of node Id cache.
  // At first, all the rows belong to the root nodes (node Id == 1).
  // Node IDs start at 1 (the root of each tree); the left child of node i is
  // 2*i and the right child is 2*i + 1, and so on down the tree.
  val nodeIdCache = if (strategy.useNodeIdCache) {
    Some(NodeIdCache.init(
      data = baggedInput,
      numTrees = numTrees,
      checkpointDir = strategy.checkpointDir,
      checkpointInterval = strategy.checkpointInterval,
      initVal = 1))
  } else {
    None
  }

  // FIFO queue of nodes to train: (treeIndex, node)
  val nodeQueue = new mutable.Queue[(Int, Node)]()

  val rng = new scala.util.Random()
  rng.setSeed(seed)

  // Allocate and queue root nodes.
  val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))  // create the root node of each tree
  // Enqueue (tree index, root node) pairs; tree indices start at 0, node IDs at 1.
  Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

  while (nodeQueue.nonEmpty) {
    // Collect some nodes to split, and choose features for each node (if subsampling).
    // Each group of nodes may come from one or multiple trees, and at multiple levels.
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)  // pick the nodes to split for each tree
    // Sanity check (should never occur):
    assert(nodesForGroup.size > 0,
      s"RandomForest selected empty nodesForGroup.  Error for unknown reason.")

    // Choose node splits, and enqueue new nodes as needed.
    timer.start("findBestSplits")
    // Find the best split for every node in the group.
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    timer.stop("findBestSplits")
  }
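
  // The loop ends once every queued node has become a leaf: findBestSplits()
  // enqueues a node's children only while the depth limit has not been hit
  // and splitting still yields a valid information gain.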

  baggedInput.unpersist()

  timer.stop("total")

  logInfo("Internal timing for DecisionTree:")
  logInfo(s"$timer")

  // Delete any remaining checkpoints used for node Id cache.
  if (nodeIdCache.nonEmpty) {
    nodeIdCache.get.deleteAllCheckpoints()
  }

  // Wrap each root node in a DecisionTreeModel and return them as a forest.
  val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
  new RandomForestModel(strategy.algo, trees)
}
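
The node numbering used above (root = 1, children 2 and 3, and so on) is the classic implicit binary-heap layout. MLlib's Node companion object computes child indices essentially like this:

// Paraphrased from Node.leftChildIndex / Node.rightChildIndex in MLlib:
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1         // 2 * i
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1  // 2 * i + 1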


*********** The End ***********