Spark 1.2.0 源码 MLlib --- 决策树-03

首先,从 DecisionTree.findBestSplits() 开始,这个方法代码很长,按照执行顺序来看,代码如下:

// Builds, per partition, an array of DTStatsAggregator (one per node index in
// 0 until numNodes), feeds every point of the partition into it, and exposes the
// result as (nodeIndex, aggregator) pairs typed RDD[(Int, DTStatsAggregator)].
// Two branches: when nodeIdCache is non-empty the input is zipped with the cached
// node ids and binSeqOpWithNodeIdCache is used; otherwise binSeqOp is used.
// NOTE(review): this excerpt appears to have lost lines during extraction -- there
// are no closing braces, the flatMap body is absent, and the first branch shows no
// pair-conversion expression. Restore from the upstream source before compiling.
val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {  // case: a node id cache is in use
input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats;
// each node will have its own nodeStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
new DTStatsAggregator(metadata, featuresForNode)

// Iterate over all instances in the current partition and update the aggregate stats.
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

// Transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those of other partitions using `reduceByKey`.
} else {  // case: no node id cache
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats;
// each node will have its own nodeStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>   // one aggregator per node, holding the statistics gathered for that node
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
new DTStatsAggregator(metadata, featuresForNode)  // create the aggregator for this node

// Iterate over all instances in the current partition and update the aggregate stats.
points.foreach(binSeqOp(nodeStatsAggregators, _))  // accumulate the statistics (node, features, bins) into the aggregators

// Transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those of other partitions using `reduceByKey`.
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator  // convert to key-value pairs, to be merged with the aggregators of other partitions

继续看关键的一步代码,points.foreach(binSeqOp(nodeStatsAggregators, _)),按分区来聚合,具体代码如下:

// Folds one bagged data point into the per-node aggregators.
// For every (treeIndex, nodeIndexToInfo) entry of treeToNodeToIndexInfo it computes
// the node index for the point with predictNodeIndex, looks that index up in
// nodeIndexToInfo (passing null when absent), and delegates to nodeBinSeqOp.
//   agg         - array of node aggregators that nodeBinSeqOp updates
//   baggedPoint - the data point, with its binned features in datum.binnedFeatures
// NOTE(review): the declared return type is Array[DTStatsAggregator], but no result
// expression or closing braces are visible in this excerpt -- presumably lost in
// extraction; confirm against the upstream source.
def binSeqOp(
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)  // determine which node this sample falls into, based on its binned feature values
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) // using that node index, add the sample's statistics to the matching node aggregator


继续跟踪 nodeBinSeqOp() 方法:

// Updates the aggregator of a single node with one data point.
// The update is guarded by nodeInfo != null (callers pass null when the point's
// node has no entry). The aggregator is selected by nodeInfo.nodeIndexInGroup and
// the point's weight for this tree is baggedPoint.subsampleWeights(treeIndex).
//   treeIndex   - index of the tree, used to pick the subsample weight
//   nodeInfo    - node bookkeeping (index within the group, feature subset), or null
//   agg         - array of node aggregators, indexed by nodeIndexInGroup
//   baggedPoint - the data point and its per-tree subsample weights
// NOTE(review): closing braces are missing from this excerpt.
def nodeBinSeqOp(
treeIndex: Int,
nodeInfo: RandomForest.NodeIndexInfo,
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Unit = {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)  // no unordered features: every feature is treated as ordered
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,  // at least one feature is unordered
instanceWeight, featuresForNode)


// Adds one data point to a node aggregator when only ordered features are involved.
// For each visited feature it reads the point's bin index and calls agg.update with
// the point's label and weight.
//   agg             - aggregator of the node being updated
//   treePoint       - the data point (label and binnedFeatures)
//   instanceWeight  - weight of this point
//   featuresForNode - optional subset of feature indices; when defined only those
//                     features are visited, otherwise all agg.metadata.numFeatures
// NOTE(review): closing braces are missing from this excerpt.
private def orderedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
val label = treePoint.label

// Iterate over features.
if (featuresForNode.nonEmpty) {  // the node uses only a subset of the features
// Use subsampled features.
var featureIndexIdx = 0
while (featureIndexIdx < featuresForNode.get.size) {
// The bin is looked up by the real feature index, but the aggregator is
// addressed by the position within the subset (featureIndexIdx).
val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
agg.update(featureIndexIdx, binIndex, label, instanceWeight)  // update the aggregator's statistics
featureIndexIdx += 1
} else {  // use every feature
// Use all features.
val numFeatures = agg.metadata.numFeatures
var featureIndex = 0
while (featureIndex < numFeatures) {
val binIndex = treePoint.binnedFeatures(featureIndex)
agg.update(featureIndex, binIndex, label, instanceWeight)  // update the aggregator's statistics
featureIndex += 1

继续查看 agg.update() 的代码,以及它所用到的 impurityAggregator 的定义:

// Records one (label, weight) observation for the given feature and bin.
// Computes the start offset of that (feature, bin) slot in allStats and delegates
// the actual accumulation to impurityAggregator.update.
// NOTE(review): the closing brace is missing from this excerpt.
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
val i = featureOffsets(featureIndex) + binIndex * statsSize  // offset of this feature's bin within the flat statistics array (allStats)
impurityAggregator.update(allStats, i, label, instanceWeight)

// Selects the ImpurityAggregator implementation matching metadata.impurity.
// Any impurity other than Gini, Entropy or Variance raises IllegalArgumentException.
// NOTE(review): the closing brace of the match is missing from this excerpt.
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
case Gini => new GiniAggregator(metadata.numClasses)  // Gini aggregator, sized by numClasses (per the author's note: used for classification)
case Entropy => new EntropyAggregator(metadata.numClasses)  // entropy aggregator, sized by numClasses (per the author's note: used for classification)
case Variance => new VarianceAggregator()  // variance aggregator (per the author's note: used for regression)
case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")

以其中一种聚合器 EntropyAggregator 为例,其 update() 方法如下:

// Adds instanceWeight to the slot of allStats at offset + label.toInt.
// Throws IllegalArgumentException when label >= statsSize.
//   allStats       - flat statistics array being accumulated into
//   offset         - start position of the current (feature, bin) slot
//   label          - label of the observation; truncated with toInt to index the slot
//   instanceWeight - weight added to the label's entry
// NOTE(review): only the upper bound of label is checked here; a negative label
// would index below `offset` -- confirm callers guarantee label >= 0.
// NOTE(review): closing braces are missing from this excerpt.
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
if (label >= statsSize) {
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
allStats(offset + label.toInt) += instanceWeight   // the weight is finally accumulated into this array

*********** The End ***********
