Spark 1.2.0 MLlib Source Code --- Decision Trees - 03
2015-01-27 21:19
This part focuses on how, while each node of the tree is being split, the relevant data are aggregated so that node impurity and information gain can be computed later, which ultimately determines the order of the splits.
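As a quick refresher on the quantities this machinery ultimately feeds (these are the standard definitions, not Spark code): for a node whose samples have class proportions $p_1, \dots, p_K$, the entropy impurity and the information gain of a candidate split are

$$
I_{\text{entropy}} = -\sum_{k=1}^{K} p_k \log_2 p_k,
\qquad
\mathrm{IG} = I(\text{parent}) - \frac{N_{\text{left}}}{N}\, I(\text{left}) - \frac{N_{\text{right}}}{N}\, I(\text{right}),
$$

where $N$, $N_{\text{left}}$, $N_{\text{right}}$ are the (weighted) sample counts of the parent and the two children. The per-bin class counts collected below are exactly what is needed to evaluate these quantities for every candidate split.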
We start from DecisionTree.findBestSplits(). The method is long; following its execution order, the relevant code is:
```scala
val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
  // case 1: node IDs for the instances are cached
  input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
} else {
  // case 2: node IDs are not cached
  input.mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      // one aggregator per node, holding the statistics collected for that node
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode) // create the node's aggregator
    }

    // iterator all instances in current partition and update aggregate stats
    // accumulate the (node, feature, bin) statistics into the aggregators
    points.foreach(binSeqOp(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
}
```

On the first invocation there is only the root node (node ID 1); after that, the nodes processed here grow level by level down the tree until the leaves are reached (or the maximum tree depth is hit).
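The per-partition (nodeIndex, aggregator) pairs produced above are then merged across partitions via reduceByKey and DTStatsAggregator.merge. Below is a minimal, self-contained sketch of that combine step, with a hypothetical SimpleAggregator standing in for DTStatsAggregator: merging is just element-wise addition of the flat statistics arrays.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // pair-RDD implicits in Spark 1.2

// Hypothetical stand-in for DTStatsAggregator: just a flat array of bin statistics.
class SimpleAggregator(val allStats: Array[Double]) extends Serializable {
  // Merging two aggregators is element-wise addition of their stats arrays.
  def merge(other: SimpleAggregator): SimpleAggregator = {
    var i = 0
    while (i < allStats.length) {
      allStats(i) += other.allStats(i)
      i += 1
    }
    this
  }
}

object MergeSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("merge-sketch").setMaster("local[2]"))
    // Two partitions, each emitting (nodeIndex, aggregator) pairs, as in findBestSplits.
    val partitionAggregates = sc.parallelize(Seq(
      (0, new SimpleAggregator(Array(1.0, 2.0))),
      (0, new SimpleAggregator(Array(3.0, 4.0))),
      (1, new SimpleAggregator(Array(5.0, 6.0)))
    ), 2)
    // Same shape as the real code: reduceByKey((a, b) => a.merge(b))
    val merged = partitionAggregates.reduceByKey((a, b) => a.merge(b)).collectAsMap()
    println(merged(0).allStats.mkString(",")) // 4.0,6.0
    sc.stop()
  }
}
```

After this merge, each node has a single aggregator holding the statistics of all its samples, ready for the best-split search.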
Now look at the key step, points.foreach(binSeqOp(nodeStatsAggregators, _)), which does the per-partition aggregation. The code is:
```scala
def binSeqOp(
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
  treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    // determine which node this sample currently falls into
    val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
      bins, metadata.unorderedFeatures)
    // using the node index found above, add the sample's statistics
    // to that node's aggregator
    nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
  }
  agg
}
```
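predictNodeIndex walks a sample down the current tree using only its binned feature values. Here is a simplified sketch of the idea, with a hypothetical SimpleNode type (the real code uses Node, Split, and Bin, and also handles categorical and unordered features):

```scala
// Hypothetical, simplified node type; the real code uses
// org.apache.spark.mllib.tree.model.Node.
case class SimpleNode(
    id: Int,
    splitFeature: Int = -1,      // feature this node splits on, -1 if not yet split
    splitBinThreshold: Int = -1, // bins <= threshold go left (ordered case)
    left: Option[SimpleNode] = None,
    right: Option[SimpleNode] = None)

// Walk the tree using only the pre-computed bin indices of the sample.
def predictNodeIndexSketch(node: SimpleNode, binnedFeatures: Array[Int]): Int = {
  if (node.left.isEmpty || node.right.isEmpty) {
    node.id // leaf, or a node not yet split: the sample is aggregated here
  } else if (binnedFeatures(node.splitFeature) <= node.splitBinThreshold) {
    predictNodeIndexSketch(node.left.get, binnedFeatures)
  } else {
    predictNodeIndexSketch(node.right.get, binnedFeatures)
  }
}
```

For a tree that has only a root, every sample maps to node ID 1, which matches the level-by-level growth described above.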
Next, step into the nodeBinSeqOp() method:
```scala
def nodeBinSeqOp(
    treeIndex: Int,
    nodeInfo: RandomForest.NodeIndexInfo,
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Unit = {
  if (nodeInfo != null) {
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    val featuresForNode = nodeInfo.featureSubset
    val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    if (metadata.unorderedFeatures.isEmpty) {
      // all feature values are ordered
      orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
    } else {
      // some feature values are unordered
      mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
        instanceWeight, featuresForNode)
    }
  }
}
```
We only follow the ordered branch here:
```scala
private def orderedBinSeqOp(
    agg: DTStatsAggregator,
    treePoint: TreePoint,
    instanceWeight: Double,
    featuresForNode: Option[Array[Int]]): Unit = {
  val label = treePoint.label

  // Iterate over features.
  if (featuresForNode.nonEmpty) {
    // Use subsampled features: this node only considers a subset of the features
    var featureIndexIdx = 0
    while (featureIndexIdx < featuresForNode.get.size) {
      val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
      agg.update(featureIndexIdx, binIndex, label, instanceWeight) // update the aggregator's statistics
      featureIndexIdx += 1
    }
  } else {
    // Use all features
    val numFeatures = agg.metadata.numFeatures
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      val binIndex = treePoint.binnedFeatures(featureIndex)
      agg.update(featureIndex, binIndex, label, instanceWeight) // update the aggregator's statistics
      featureIndex += 1
    }
  }
}
```
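Note that treePoint.binnedFeatures already contains bin indices rather than raw feature values; the discretization happens once, up front, when the input is converted to TreePoints (TreePoint.convertToTreeRDD). A hedged sketch of that idea for one continuous feature, with hypothetical split thresholds:

```scala
// Hypothetical binning of one continuous feature given sorted split thresholds;
// the real conversion uses the precomputed Bin/Split structures.
def toBinIndexSketch(value: Double, splitThresholds: Array[Double]): Int = {
  // The bin is the number of thresholds the value exceeds:
  // value <= t0 -> bin 0, t0 < value <= t1 -> bin 1, ..., value > t_last -> last bin.
  var bin = 0
  while (bin < splitThresholds.length && value > splitThresholds(bin)) {
    bin += 1
  }
  bin
}

// Example: thresholds (10.0, 20.0) define 3 bins; 15.0 falls into bin 1.
// toBinIndexSketch(15.0, Array(10.0, 20.0)) == 1
```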
Continuing into agg.update():
```scala
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
  // offset of this (feature, bin) pair within the single flat statistics array (allStats)
  val i = featureOffsets(featureIndex) + binIndex * statsSize
  impurityAggregator.update(allStats, i, label, instanceWeight)
}
```

Here impurityAggregator has three implementations:
```scala
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
  case Gini => new GiniAggregator(metadata.numClasses)       // Gini impurity aggregator, for classification
  case Entropy => new EntropyAggregator(metadata.numClasses) // entropy aggregator, for classification
  case Variance => new VarianceAggregator()                  // variance aggregator, for regression
  case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}
```
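Before looking at a concrete aggregator, it helps to see the offset arithmetic from update() with numbers. A small worked sketch with hypothetical sizes (the real featureOffsets is computed similarly via a prefix scan over the bin counts):

```scala
// Hypothetical layout: 2 features, feature 0 with 3 bins, feature 1 with 2 bins,
// 2 classes => statsSize = 2 doubles per bin.
val statsSize = 2
val numBins = Array(3, 2)
// featureOffsets(f) = statsSize * (total bins of all features before f),
// plus a final end offset.
val featureOffsets = numBins.scanLeft(0)((offset, bins) => offset + bins * statsSize)
// featureOffsets == Array(0, 6, 10); allStats has 10 slots in total.
val allStats = new Array[Double](featureOffsets.last)

// update(featureIndex = 1, binIndex = 1, label = 0.0, instanceWeight = 1.0)
// lands at index featureOffsets(1) + 1 * statsSize + 0 = 6 + 2 + 0 = 8.
val i = featureOffsets(1) + 1 * statsSize
allStats(i + 0 /* label.toInt */) += 1.0
```

So allStats is one contiguous array per node, holding per-class weighted counts for every (feature, bin) pair, which is what makes the element-wise merge across partitions cheap.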
Take one of them, EntropyAggregator; its update() method looks like this:
```scala
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
  if (label >= statsSize) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s" but requires label < numClasses (= $statsSize).")
  }
  allStats(offset + label.toInt) += instanceWeight // the per-class weighted counts accumulate here
}
```
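Once all samples have been aggregated and merged, each (feature, bin) slot of allStats holds per-class weighted counts, from which the impurity is computed. A minimal sketch of that last step, assuming base-2 logarithms as in MLlib's Entropy.calculate:

```scala
// Compute entropy from per-class weighted counts, mirroring what the
// entropy impurity does with the counts accumulated in allStats.
def entropySketch(counts: Array[Double]): Double = {
  val total = counts.sum
  if (total == 0.0) {
    0.0
  } else {
    counts.foldLeft(0.0) { (acc, count) =>
      if (count == 0.0) acc
      else {
        val p = count / total
        acc - p * (math.log(p) / math.log(2.0)) // entropy = -sum(p_k * log2 p_k)
      }
    }
  }
}

// Example: counts (4.0, 4.0) over two classes give maximal entropy 1.0 bit.
// entropySketch(Array(4.0, 4.0)) == 1.0
```

Evaluating this for the left and right halves of every candidate split, weighted by their sample counts, yields the information gain used to pick the best split for each node.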
*********** The End ***********