Spark 1.2.0 MLlib Source Code --- Decision Trees, Part 3
This part focuses on how, while splitting each tree node, the relevant data are aggregated so that node impurity and information gain can later be computed, which ultimately determines how the splits are chosen.
First, start from DecisionTree.findBestSplits(). The method is quite long; following its execution order, the first key piece of code is:
val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
  // case 1: node IDs are cached
  input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode)
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
  }
} else {
  // case 2: node IDs are not cached
  input.mapPartitions { points =>
    // Construct a nodeStatsAggregators array to hold node aggregate stats,
    // each node will have a nodeStatsAggregator
    val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
      // one aggregator per node, holding that node's statistics
      val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
        Some(nodeToFeatures(nodeIndex))
      }
      new DTStatsAggregator(metadata, featuresForNode) // create one node aggregator
    }

    // iterator all instances in current partition and update aggregate stats
    points.foreach(binSeqOp(nodeStatsAggregators, _)) // accumulate the (node, feature, bin) statistics into the aggregators

    // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
    // which can be combined with other partition using `reduceByKey`
    nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator // convert to key-value pairs, to be merged with the aggregators from other partitions
  }
}

On the first run there is only the root node (node ID 1); after that, nodes are processed level by level down the tree until the leaves are reached (or the maximum tree depth is hit).
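The per-partition-aggregate-then-merge pattern above can be illustrated without Spark. The following is a toy sketch (the names `aggregatePartition` and `merge` and the simplified one-counter "stats" are hypothetical, not Spark's API): each "partition" builds one stats array per node, and arrays with the same node index are then merged, mimicking mapPartitions followed by reduceByKey.

```scala
// Toy model: a "stats" array holds one weight counter per node; real
// DTStatsAggregators hold per-(feature, bin, class) counters instead.
object PartitionAggSketch {
  // Build one (nodeIndex, stats) pair per node from one partition's points,
  // where each point is (nodeIndex it falls into, instance weight).
  def aggregatePartition(points: Seq[(Int, Double)],
                         numNodes: Int,
                         statsSize: Int): Seq[(Int, Array[Double])] = {
    val aggs = Array.fill(numNodes)(new Array[Double](statsSize))
    points.foreach { case (nodeIndex, weight) => aggs(nodeIndex)(0) += weight }
    aggs.toSeq.zipWithIndex.map(_.swap) // (stats, idx) -> (idx, stats)
  }

  // Element-wise sum, the role played by DTStatsAggregator.merge in Spark.
  def merge(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }
}
```

Grouping the pairs by node index and reducing with `merge` then plays the role of `reduceByKey` on `partitionAggregates`.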
Next comes the key step, points.foreach(binSeqOp(nodeStatsAggregators, _)), which aggregates within each partition:
def binSeqOp(
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
  treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
    val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
      bins, metadata.unorderedFeatures) // determine which node this sample currently falls into
    nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) // using that node index, add the sample's statistics to the matching node aggregator
  }
  agg
}
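To make the first step concrete, here is a hypothetical, minimal version of what routing a sample down the partially built tree looks like (the `SplitNode` class and map-based tree here are illustrative only; Spark walks actual `Node` objects). It relies on the 1-based heap convention Spark uses for node IDs: the left child of node `id` is `2 * id` and the right child is `2 * id + 1`.

```scala
// Toy routing sketch: nodes that have already been split appear in `splits`;
// any id absent from the map is a current leaf, so routing stops there.
object RoutingSketch {
  // A split on one feature: go left when the binned value is <= threshold.
  final case class SplitNode(featureIndex: Int, threshold: Int)

  def predictNodeIndex(splits: Map[Int, SplitNode], binnedFeatures: Array[Int]): Int = {
    var id = 1 // the root node always has ID 1
    while (splits.contains(id)) {
      val s = splits(id)
      id = if (binnedFeatures(s.featureIndex) <= s.threshold) 2 * id else 2 * id + 1
    }
    id // the leaf this sample currently belongs to
  }
}
```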
Following the call into nodeBinSeqOp():
def nodeBinSeqOp(
    treeIndex: Int,
    nodeInfo: RandomForest.NodeIndexInfo,
    agg: Array[DTStatsAggregator],
    baggedPoint: BaggedPoint[TreePoint]): Unit = {
  if (nodeInfo != null) {
    val aggNodeIndex = nodeInfo.nodeIndexInGroup
    val featuresForNode = nodeInfo.featureSubset
    val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
    if (metadata.unorderedFeatures.isEmpty) {
      orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) // all features are ordered
    } else {
      mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, // some features are unordered
        instanceWeight, featuresForNode)
    }
  }
}
Looking only at the ordered-feature branch:
private def orderedBinSeqOp(
    agg: DTStatsAggregator,
    treePoint: TreePoint,
    instanceWeight: Double,
    featuresForNode: Option[Array[Int]]): Unit = {
  val label = treePoint.label
  // Iterate over features.
  if (featuresForNode.nonEmpty) {
    // Use subsampled features: this node uses only a subset of the features
    var featureIndexIdx = 0
    while (featureIndexIdx < featuresForNode.get.size) {
      val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx))
      agg.update(featureIndexIdx, binIndex, label, instanceWeight) // update the aggregator's statistics
      featureIndexIdx += 1
    }
  } else {
    // Use all features
    val numFeatures = agg.metadata.numFeatures
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      val binIndex = treePoint.binnedFeatures(featureIndex)
      agg.update(featureIndex, binIndex, label, instanceWeight) // update the aggregator's statistics
      featureIndex += 1
    }
  }
}
Continuing into agg.update():
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
  val i = featureOffsets(featureIndex) + binIndex * statsSize // offset of this (feature, bin) pair within the single flat statistics array (allStats)
  impurityAggregator.update(allStats, i, label, instanceWeight)
}

Here impurityAggregator has three implementations:
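The flat layout behind that offset arithmetic can be sketched as follows. This is a standalone illustration with hypothetical bin counts, not Spark's DTStatsAggregator itself: each feature owns a contiguous region of `allStats`, each bin inside that region owns `statsSize` slots (one per class for classification), and `featureOffsets` is the running prefix sum of region sizes.

```scala
// Sketch of the flat allStats layout: allStats index =
//   featureOffsets(feature) + bin * statsSize + label
object FlatStatsSketch {
  val numClasses = 3        // statsSize for a 3-class classification problem
  val numBins = Array(4, 2) // hypothetical: feature 0 has 4 bins, feature 1 has 2

  // featureOffsets(f) = where feature f's region starts inside allStats
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((acc, bins) => acc + bins * numClasses)

  // one flat array holds every (feature, bin, class) counter
  val allStats = new Array[Double](featureOffsets.last)

  def update(featureIndex: Int, binIndex: Int, label: Int, weight: Double): Unit = {
    val i = featureOffsets(featureIndex) + binIndex * numClasses
    allStats(i + label) += weight
  }
}
```

With these sizes, `featureOffsets` is `[0, 12, 18]`, so feature 1 / bin 1 / label 2 lands at index 12 + 3 + 2 = 17.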
val impurityAggregator: ImpurityAggregator = metadata.impurity match {
  case Gini => new GiniAggregator(metadata.numClasses) // Gini aggregator, used for classification
  case Entropy => new EntropyAggregator(metadata.numClasses) // entropy aggregator, used for classification
  case Variance => new VarianceAggregator() // variance aggregator, used for regression
  case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
}
Taking EntropyAggregator as an example, its update() method is:
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
  if (label >= statsSize) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s" but requires label < numClasses (= $statsSize).")
  }
  allStats(offset + label.toInt) += instanceWeight // everything is finally accumulated into this flat array
}
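Once the weighted class counts for a bin (or a whole node) have been accumulated this way, the impurity is computed from them. As a minimal sketch of what an entropy calculation over those counts looks like (a simplified stand-in, not Spark's Entropy.calculate; Spark's version uses base-2 logarithms as well):

```scala
object EntropySketch {
  // Entropy of a class-count vector: -sum(p_i * log2(p_i)) over nonzero classes.
  def entropy(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) 0.0
    else counts.filter(_ > 0.0).map { c =>
      val p = c / total
      -p * math.log(p) / math.log(2.0) // log base 2
    }.sum
  }
}
```

A perfectly mixed two-class node (equal counts) gives entropy 1.0, while a pure node gives 0.0; the split search in later steps compares these values across candidate splits to pick the one with the highest information gain.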
*********** The End ***********