Spark 1.2.0 MLlib source code --- Decision Trees, part 01
2015-01-26 19:50
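Decision trees come in two kinds, classification trees and regression trees, used for classification models and regression models respectively.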
First, a look at how decision trees are used in Spark. Training a classification model looks like this:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file (`sc` is the SparkContext, e.g. the one provided by spark-shell).
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val numClasses = 2
val categoricalFeaturesInfo = Map[Int, Int]() // an empty map means every feature is continuous
val impurity = "gini" // use the Gini index as the node impurity measure
val maxDepth = 5      // maximum depth of the tree
val maxBins = 32      // maximum number of bins per feature when splitting (mostly relevant for continuous features)

// Classification model
val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
  impurity, maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelAndPreds = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)
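To make the "gini" impurity setting concrete, here is a minimal sketch in plain Scala, independent of Spark. The object and method names (`GiniSketch`, `gini`, `gain`) are made up for illustration and are not MLlib's API; the sketch only assumes the textbook definition of Gini impurity, 1 minus the sum of squared class proportions.

```scala
object GiniSketch {
  /** Gini impurity of a node from its per-class sample counts: 1 - sum_k p_k^2. */
  def gini(classCounts: Array[Double]): Double = {
    val total = classCounts.sum
    if (total == 0.0) 0.0
    else 1.0 - classCounts.map { c => val p = c / total; p * p }.sum
  }

  /** Impurity decrease of a split; each child is weighted by its share of the parent's samples. */
  def gain(parent: Array[Double], left: Array[Double], right: Array[Double]): Double = {
    val n = parent.sum
    gini(parent) - (left.sum / n) * gini(left) - (right.sum / n) * gini(right)
  }
}
```

A node holding 5 samples of each of two classes has impurity 0.5, and a split that separates the two classes perfectly recovers all of it as gain, e.g. `GiniSketch.gain(Array(5.0, 5.0), Array(5.0, 0.0), Array(0.0, 5.0))`.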
Training a regression model looks like this:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Load and parse the data file (`sc` is the SparkContext, e.g. the one provided by spark-shell).
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
val splits = data.randomSplit(Array(0.7, 0.3))
val (trainingData, testData) = (splits(0), splits(1))

// Train a DecisionTree model.
// Empty categoricalFeaturesInfo indicates all features are continuous.
val categoricalFeaturesInfo = Map[Int, Int]()
val impurity = "variance" // use variance as the node impurity measure
val maxDepth = 5
val maxBins = 32

// Regression model
val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity,
  maxDepth, maxBins)

// Evaluate model on test instances and compute test error
val labelsAndPredictions = testData.map { point =>
  val prediction = model.predict(point.features)
  (point.label, prediction)
}
val testMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)
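For the regression case, the "variance" impurity and the test MSE computed above can be sketched the same way in plain Scala. The names `VarianceSketch`, `variance`, and `meanSquaredError` are hypothetical helpers for illustration, not MLlib's implementation; the sketch assumes the usual population variance (divide by n).

```scala
object VarianceSketch {
  /** Variance of the labels in a node: mean squared deviation from the label mean. */
  def variance(labels: Seq[Double]): Double = {
    if (labels.isEmpty) 0.0
    else {
      val n = labels.size.toDouble
      val mean = labels.sum / n
      labels.map(y => (y - mean) * (y - mean)).sum / n
    }
  }

  /** Mean squared error over (label, prediction) pairs, as in the test step above. */
  def meanSquaredError(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (v, p) => (v - p) * (v - p) }.sum / pairs.size
}
```

A node is pure for regression purposes when all its labels are equal, in which case `variance` returns 0.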
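Spark's decision tree implementation involves a fair amount of code, so the important parts are covered across several posts. This post follows only the main code path.
(Following the classification path.) Tracing into DecisionTree.trainClassifier():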
def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    impurity: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {
  val impurityType = Impurities.fromString(impurity)
  train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
    categoricalFeaturesInfo)
}
Tracing further:

def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClasses: Int,
    maxBins: Int,
    // Strategy for computing candidate split values; the default is Sort, i.e. sort the feature values
    quantileCalculationStrategy: QuantileStrategy,
    categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
  val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
    quantileCalculationStrategy, categoricalFeaturesInfo)
  new DecisionTree(strategy).run(input)
}

Tracing into the run() method:
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
  // Note: random seed will not be used since numTrees = 1.
  // Under the hood this reuses the random forest implementation, with a single tree
  // and every feature considered at each node that is split.
  val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
  val rfModel = rf.run(input)
  rfModel.trees(0)
}

Next, the implementation of rf.run():
def run(input: RDD[LabeledPoint]): RandomForestModel = {

  val timer = new TimeTracker()

  timer.start("total")

  timer.start("init")

  val retaggedInput = input.retag(classOf[LabeledPoint])
  // Build the decision tree metadata (number of splits and bins per feature, and so on);
  // each bin holds a range of values of one feature.
  val metadata =
    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  logDebug("algo = " + strategy.algo)
  logDebug("numTrees = " + numTrees)
  logDebug("seed = " + seed)
  logDebug("maxBins = " + metadata.maxBins)
  logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
  logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)

  // Find the splits and the corresponding bins (interval between the splits) using a sample
  // of the input data.
  timer.start("findSplitsBins")
  // Obtain the split positions and the bin information.
  val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
  timer.stop("findSplitsBins")
  logDebug("numBins: feature: number of bins")
  logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
      s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
    }.mkString("\n"))

  // Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  // Convert to the tree-specific RDD type; at this point every sample has already been
  // assigned to its bin according to the split thresholds.
  val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

  val (subsample, withReplacement) = {
    // TODO: Have a stricter check for RF in the strategy
    val isRandomForest = numTrees > 1
    if (isRandomForest) {
      // Random forest, i.e. more than one tree: true means sampling with replacement.
      (1.0, true)
    } else {
      (strategy.subsamplingRate, false)
    }
  }

  // Wrap the points once more; for a random forest, each tree is trained on its own sample of the data.
  val baggedInput
    = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
      .persist(StorageLevel.MEMORY_AND_DISK)

  // depth of the decision tree
  val maxDepth = strategy.maxDepth
  // The tree depth must not exceed 30.
  require(maxDepth <= 30,
    s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

  // Max memory usage for aggregates
  // TODO: Calculate memory usage more precisely.
  // Maximum memory available to the RDD aggregation step.
  val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
  logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
  val maxMemoryPerNode = {
    // Is only a subset of the features used at each node?
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
      Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
        .take(metadata.numFeaturesPerNode).map(_._2))
    } else {
      None
    }
    // Memory needed by the aggregation for a single node.
    RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
  }
  require(maxMemoryPerNode <= maxMemoryUsage,
    s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
    " which is too small for the given features." +
    s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

  timer.stop("init")

  /*
   * The main idea here is to perform group-wise training of the decision tree nodes thus
   * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
   * Each data sample is handled by a particular node (or it reaches a leaf and is not used
   * in lower levels).
   */

  // Create an RDD of node Id cache.
  // At first, all the rows belong to the root nodes (node Id == 1).
  // Whether to use the node Id cache. Node Ids start at 1: 1 is the root of a tree,
  // its left child is 2, its right child is 3, and so on.
  val nodeIdCache = if (strategy.useNodeIdCache) {
    Some(NodeIdCache.init(
      data = baggedInput,
      numTrees = numTrees,
      checkpointDir = strategy.checkpointDir,
      checkpointInterval = strategy.checkpointInterval,
      initVal = 1))
  } else {
    None
  }

  // FIFO queue of nodes to train: (treeIndex, node)
  val nodeQueue = new mutable.Queue[(Int, Node)]()

  val rng = new scala.util.Random()
  rng.setSeed(seed)

  // Allocate and queue root nodes.
  // Create the root node of every tree.
  val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
  // Enqueue (tree index, root node of that tree); tree indices start at 0, root node indices at 1.
  Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

  while (nodeQueue.nonEmpty) {
    // Collect some nodes to split, and choose features for each node (if subsampling).
    // Each group of nodes may come from one or multiple trees, and at multiple levels.
    // Pick the nodes of each tree that are to be split in this round.
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
    // Sanity check (should never occur):
    assert(nodesForGroup.size > 0,
      s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

    // Choose node splits, and enqueue new nodes as needed.
    timer.start("findBestSplits")
    // Choose the best split for each node in the group.
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    timer.stop("findBestSplits")
  }

  baggedInput.unpersist()

  timer.stop("total")

  logInfo("Internal timing for DecisionTree:")
  logInfo(s"$timer")

  // Delete any remaining checkpoints used for node Id cache.
  if (nodeIdCache.nonEmpty) {
    nodeIdCache.get.deleteAllCheckpoints()
  }

  val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
  new RandomForestModel(strategy.algo, trees)
}
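The node numbering mentioned in the comments above (root = 1, left child = 2, right child = 3, and so on) is the usual heap-style layout for a binary tree. The sketch below illustrates it in plain Scala; `NodeIndexSketch` and its methods are hypothetical helpers written for this post, not the MLlib classes. It assumes that the children of node i are 2i and 2i + 1, which matches the 1, 2, 3 numbering. Under that assumption, the largest index at depth 30 is exactly Int.MaxValue, which is one plausible reading of the maxDepth <= 30 check in run().

```scala
object NodeIndexSketch {
  /** Index of the left child of node i (root has index 1). */
  def leftChild(i: Int): Int = i << 1

  /** Index of the right child of node i. */
  def rightChild(i: Int): Int = (i << 1) + 1

  /** Index of the parent of node i; only meaningful for i >= 2. */
  def parent(i: Int): Int = i >> 1

  /** Depth of node i, with the root at depth 0. */
  def level(i: Int): Int = {
    require(i >= 1, "node indices start at 1")
    31 - Integer.numberOfLeadingZeros(i)
  }

  /** Largest node index that can occur in a tree of the given depth. */
  def maxIndexAtDepth(depth: Int): Long = (1L << (depth + 1)) - 1
}
```

For example, node 5 is the right child of node 2 and sits at depth 2, and `NodeIndexSketch.maxIndexAtDepth(30)` equals `Int.MaxValue`, so a deeper tree would overflow an Int node index.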
*********** The End ***********