Spark MLlib Source Code Analysis: Random Forest (Part 3)
Source: Internet. Editor: 程序博客网. Date: 2024/06/16 09:08
Spark Source Code Analysis: Random Forest (Part 1)
Spark Source Code Analysis: Random Forest (Part 2)
Spark Source Code Analysis: Random Forest (Part 4)
Spark Source Code Analysis: Random Forest (Part 5)
6. Random Forest Training
6.1. Data Structures
6.1.1. Node
Every node in a tree is a Node structure:
    class Node @Since("1.2.0") (
        @Since("1.0.0") val id: Int,
        @Since("1.0.0") var predict: Predict,
        @Since("1.2.0") var impurity: Double,
        @Since("1.0.0") var isLeaf: Boolean,
        @Since("1.0.0") var split: Option[Split],
        @Since("1.0.0") var leftNode: Option[Node],
        @Since("1.0.0") var rightNode: Option[Node],
        @Since("1.0.0") var stats: Option[InformationGainStats])
emptyNode initializes only nodeIndex; every other field gets a default value.
    def emptyNode(nodeIndex: Int): Node =
      new Node(nodeIndex, new Predict(Double.MinValue), -1.0, false, None, None, None, None)
The ids of the child nodes are computed from the id of the node:
    /**
     * Return the index of the left child of this node.
     */
    def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

    /**
     * Return the index of the right child of this node.
     */
    def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
The left child's id is the current id * 2, and the right child's id is id * 2 + 1.
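The id arithmetic can be tried out in isolation. The following is a minimal sketch, not the Spark class itself: the object name NodeIndexSketch is made up for illustration, and parentIndex is a helper that does not appear in the excerpt above and is assumed here only to show the inverse mapping.

```scala
object NodeIndexSketch {
  // Same arithmetic as leftChildIndex / rightChildIndex in the excerpt:
  // shifting left by one bit multiplies the id by 2.
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

  // Assumed helper (not in the excerpt): integer division by 2 recovers the parent id.
  def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1

  def main(args: Array[String]): Unit = {
    val root = 1 // root nodes are created with id 1
    println(s"children of $root: ${leftChildIndex(root)}, ${rightChildIndex(root)}")
    println(s"children of 3: ${leftChildIndex(3)}, ${rightChildIndex(3)}")
  }
}
```

Because the root has id 1, the ids enumerate the tree level by level, and no two nodes of the same tree share an id.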
6.1.2. Entropy
6.1.2.1. Entropy
Entropy is an object; its most important member is the calculate function:
    /**
     * :: DeveloperApi ::
     * information calculation for multiclass classification
     * @param counts Array[Double] with counts for each label
     * @param totalCount sum of counts for all labels
     * @return information value, or 0 if totalCount = 0
     */
    @Since("1.1.0")
    @DeveloperApi
    override def calculate(counts: Array[Double], totalCount: Double): Double = {
      if (totalCount == 0) {
        return 0
      }
      val numClasses = counts.length
      var impurity = 0.0
      var classIndex = 0
      while (classIndex < numClasses) {
        val classCount = counts(classIndex)
        if (classCount != 0) {
          val freq = classCount / totalCount
          impurity -= freq * log2(freq)
        }
        classIndex += 1
      }
      impurity
    }
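The entropy formula is $H = -\sum_{i} p_i \log_2 p_i$, where $p_i$ is the frequency of class $i$, that is, counts(i) / totalCount.
The argument counts therefore holds the number of occurrences of each class. The function first computes each class's frequency and then accumulates -freq * log2(freq) over all classes with a non-zero count.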
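To make the arithmetic concrete, here is a small standalone sketch of the same calculation. It is not the Spark implementation: the object name EntropySketch and the single-argument signature (the total is derived from counts) are assumptions made for illustration.

```scala
object EntropySketch {
  private def log2(x: Double): Double = math.log(x) / math.log(2.0)

  // counts(i) is the number of samples with label i.
  // Returns 0 when there are no samples, as the excerpt does.
  def entropy(counts: Array[Double]): Double = {
    val total = counts.sum
    if (total == 0.0) {
      0.0
    } else {
      counts.filter(_ != 0.0).map { c =>
        val freq = c / total
        -freq * log2(freq)
      }.sum
    }
  }

  def main(args: Array[String]): Unit = {
    println(entropy(Array(5.0, 5.0)))  // evenly split two-class node
    println(entropy(Array(10.0, 0.0))) // pure node
  }
}
```

An evenly split two-class node has entropy close to 1 bit, while a pure node has entropy 0, which is why splits that produce purer children are preferred.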
6.1.2.2. EntropyAggregator
class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses)
Its only member is the number of classes; the key function is update:
    /**
     * Update stats for one (node, feature, bin) with the given label.
     * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
     * @param offset Start index of stats for this (node, feature, bin).
     */
    def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
      if (label >= statsSize) {
        throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
          s" but requires label < numClasses (= $statsSize).")
      }
      if (label < 0) {
        throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
          s"but requires label is non-negative.")
      }
      allStats(offset + label.toInt) += instanceWeight
    }
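offset is the start index of the statistics for this (node, feature, bin). Adding the label gives the position of that class in allStats, where the instance weight (the number of occurrences) is accumulated.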
    /**
     * Get an [[ImpurityCalculator]] for a (node, feature, bin).
     * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
     * @param offset Start index of stats for this (node, feature, bin).
     */
    def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
      new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
    }
This copies out the part of allStats that belongs to one split (bin) of the feature. Its length is statsSize, which is the number of classes.
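The interplay between update and getCalculator can be sketched on a plain array. This is a simplified stand-in under stated assumptions: the object name FlatStatsSketch, the fixed statsSize of 2 (a two-class problem) and the helper name binStats are all hypothetical.

```scala
object FlatStatsSketch {
  // Assumption for this sketch: two classes, so each bin owns two adjacent slots.
  val statsSize: Int = 2

  // Add the instance weight to the slot of `label` inside the bin starting at `offset`.
  def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
    require(label >= 0 && label < statsSize, s"label $label must be in [0, $statsSize)")
    allStats(offset + label.toInt) += instanceWeight
  }

  // Copy out the statsSize slots of the bin starting at `offset`.
  def binStats(allStats: Array[Double], offset: Int): Array[Double] =
    allStats.slice(offset, offset + statsSize)

  def main(args: Array[String]): Unit = {
    val allStats = new Array[Double](6) // room for three bins
    update(allStats, 2, 1.0, 1.0)
    update(allStats, 2, 1.0, 1.0)
    update(allStats, 2, 0.0, 1.0)
    println(binStats(allStats, 2).mkString(", "))
  }
}
```

After the three updates the bin at offset 2 holds one sample of label 0 and two samples of label 1, while the neighbouring bins stay untouched.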
6.1.2.3. EntropyCalculator
    /**
     * Calculate the impurity from the stored sufficient statistics.
     */
    def calculate(): Double = Entropy.calculate(stats, stats.sum)
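Putting the functions above together, the path for computing entropy is: call getCalculator on the EntropyAggregator, which copies out the part of allStats belonging to the split, and then call calculate on the resulting EntropyCalculator, which delegates to Entropy.calculate for the actual computation.
The class also implements the prob function, which returns the probability of a label. For example, if label 0 has a count of 3 and label 1 a count of 7, the probability of label 0 is 0.3.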
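The prob example can be reproduced with a few lines. This is a sketch only: the class name EntropyCalculatorSketch is hypothetical, and returning 0 for an empty bin is an assumption about how the empty case is handled.

```scala
// Simplified stand-in for a calculator that wraps the class counts of one bin.
final class EntropyCalculatorSketch(stats: Array[Double]) {
  // Total number of (weighted) samples in the bin.
  def count: Double = stats.sum

  // Fraction of samples carrying `label`; 0 for an empty bin (assumption of this sketch).
  def prob(label: Int): Double = {
    require(label >= 0 && label < stats.length, s"label $label out of range")
    val cnt = count
    if (cnt == 0.0) 0.0 else stats(label) / cnt
  }
}

object EntropyCalculatorDemo {
  def main(args: Array[String]): Unit = {
    val calc = new EntropyCalculatorSketch(Array(3.0, 7.0))
    println(calc.prob(0))
    println(calc.prob(1))
  }
}
```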
6.1.3. DTStatsAggregator
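A short digression on which statistics a node split needs, because this drives the design of DTStatsAggregator. Take entropy as the example: when a node is split, we iterate over every split of every feature. Each split divides the samples into two parts, and to compute the entropy we need the class distribution of the left and right parts separately, then the class probabilities, and finally the entropy. The aggregator's statsSize is therefore equal to numClasses, and allStats records all the statistics, which are simply the class distributions.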
    class DTStatsAggregator(
        val metadata: DecisionTreeMetadata,
        featureSubset: Option[Array[Int]]) extends Serializable {

      /**
       * [[ImpurityAggregator]] instance specifying the impurity type.
       */
      val impurityAggregator: ImpurityAggregator = metadata.impurity match {
        case Gini => new GiniAggregator(metadata.numClasses)
        case Entropy => new EntropyAggregator(metadata.numClasses)
        case Variance => new VarianceAggregator()
        case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
      }

      /**
       * Number of elements (Double values) used for the sufficient statistics of each bin.
       */
      private val statsSize: Int = impurityAggregator.statsSize

      /**
       * Number of bins for each feature. This is indexed by the feature index.
       */
      private val numBins: Array[Int] = {
        if (featureSubset.isDefined) {
          featureSubset.get.map(metadata.numBins(_))
        } else {
          metadata.numBins
        }
      }

      /**
       * Offset for each feature for calculating indices into the [[allStats]] array.
       */
      private val featureOffsets: Array[Int] = {
        numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
      }

      /**
       * Total number of elements stored in this aggregator
       */
      private val allStatsSize: Int = featureOffsets.last

      /**
       * Flat array of elements.
       * Index for start of stats for a (feature, bin) is:
       *   index = featureOffsets(featureIndex) + binIndex * statsSize
       * Note: For unordered features,
       *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
       *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
       */
      private val allStats: Array[Double] = new Array[Double](allStatsSize)
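Each node has its own DTStatsAggregator. The constructor takes two arguments: metadata and the feature subset used by the node. The other class members are:
- impurityAggregator: Gini, Entropy and Variance are currently supported. We use Entropy as the example below; the others are similar.
- statsSize: the number of statistics each bin needs. For classification it equals numClasses, because every class is counted separately. For regression it is 3, holding the count, the sum and the sum of squares of the labels, which are what is needed to compute the variance.
- numBins: the entries of the numBins array that correspond to the features used by the node.
- featureOffsets: the start index of each feature in allStats, which depends on the number of bins of each feature and on statsSize. For example, with 3 features whose bin counts are 3, 2, 2 and a statsSize of 2, the features need 3 * 2 = 6, 2 * 2 = 4 and 2 * 2 = 4 slots respectively, so featureOffsets is 0, 6, 10, 14, the running total from left to right.
- allStatsSize: the total number of slots required.
- allStats: the slots that store the statistics.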
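f0, f1 and f2 are 3 features. f0 has 3 feature values (in fact binIndex values) 0/1/2, f1 has 2 values 0/1, and f2 has 2 values 0/1. Every feature value has statsSize slots, so there are 14 slots in total and allStatsSize = 14. Suppose we want the index of class c1 for value v1 of feature f1: the feature offset taken from featureOffsets is featureOffsets(1) = 6, the binIndex of v1 is 1, statsSize is 2 and the label is 1, so the slot index is 6 + 1 * 2 + 1 = 9, which is exactly the slot of c1 under f1v1 in the figure.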
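The worked example above can be checked directly. The sketch below hard-codes the numbers from the text (bin counts 3, 2, 2 and statsSize 2); the object name FeatureOffsetSketch and the helper slotIndex are hypothetical names introduced for illustration.

```scala
object FeatureOffsetSketch {
  val statsSize: Int = 2                  // two classes, as in the example
  val numBins: Array[Int] = Array(3, 2, 2) // bins of f0, f1, f2

  // Running total of slots, same scanLeft expression as in the excerpt.
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  val allStatsSize: Int = featureOffsets.last

  // Slot of `label` for (featureIndex, binIndex) in the flat array.
  def slotIndex(featureIndex: Int, binIndex: Int, label: Int): Int =
    featureOffsets(featureIndex) + binIndex * statsSize + label

  def main(args: Array[String]): Unit = {
    println(featureOffsets.mkString(", "))
    println(allStatsSize)
    println(slotIndex(featureIndex = 1, binIndex = 1, label = 1))
  }
}
```

Note that featureOffsets has one more entry than there are features; the extra last entry is what gives allStatsSize.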
The key functions are explained below.
    /**
     * Update the stats for a given (feature, bin) for ordered features, using the given label.
     */
    def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
      // The first term is the feature offset.
      // binIndex is the offset of the feature value within the feature. Every bin has
      // statsSize slots, so the sum is the start of the slots for this feature value.
      // Entropy's update function, for example, then adds label.toInt to reach the slot of the label.
      // This offset arithmetic shows that the bin indices of an ordered feature should be
      // contiguous, without gaps, and must start at 0.
      // If there are gaps, the corresponding slots are simply wasted.
      val i = featureOffsets(featureIndex) + binIndex * statsSize
      impurityAggregator.update(allStats, i, label, instanceWeight)
    }

    /**
     * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
     * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
     *                      from [[getFeatureOffset]].
     *                      For unordered features, this is a pre-computed
     *                      (node, feature, left/right child) offset from
     *                      [[getLeftRightFeatureOffsets]].
     */
    def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
      // The offset is computed as above, except that the feature offset is passed in
      // and does not need to be recomputed.
      impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
    }
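The two functions can be exercised together with a stripped-down aggregator. This is a sketch under stated assumptions rather than the Spark class: MiniStatsAggregator and binStats are hypothetical names, labels are plain Ints, and only the classification case (one counter per class) is modelled.

```scala
// Simplified stand-in for a per-node aggregator over ordered features.
final class MiniStatsAggregator(numBins: Array[Int], statsSize: Int) {
  private val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  private val allStats: Array[Double] = new Array[Double](featureOffsets.last)

  // Start of the slots for (featureIndex, binIndex).
  private def binStart(featureIndex: Int, binIndex: Int): Int =
    featureOffsets(featureIndex) + binIndex * statsSize

  // Count one sample with the given label in the given (feature, bin).
  def update(featureIndex: Int, binIndex: Int, label: Int, instanceWeight: Double = 1.0): Unit =
    allStats(binStart(featureIndex, binIndex) + label) += instanceWeight

  // Class distribution collected so far for the given (feature, bin).
  def binStats(featureIndex: Int, binIndex: Int): Array[Double] = {
    val start = binStart(featureIndex, binIndex)
    allStats.slice(start, start + statsSize)
  }
}

object MiniStatsAggregatorDemo {
  def main(args: Array[String]): Unit = {
    val agg = new MiniStatsAggregator(Array(3, 2, 2), statsSize = 2)
    agg.update(featureIndex = 1, binIndex = 1, label = 1)
    agg.update(featureIndex = 1, binIndex = 1, label = 1)
    agg.update(featureIndex = 1, binIndex = 1, label = 0)
    println(agg.binStats(1, 1).mkString(", "))
  }
}
```

Keeping everything in one flat Array[Double] means that aggregators for the same node can be merged by element-wise addition, which is presumably one reason for this layout.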
6.2. Training Initialization
    // FIFO queue of nodes to train: (treeIndex, node)
    val nodeQueue = new mutable.Queue[(Int, Node)]()
    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
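This constructs numTrees Nodes, each initialized with emptyNode. They serve as the root node of each tree in the training that follows. Each node is paired with its treeIndex and added to nodeQueue. Every node that still has to be split is later added to this queue and split in turn, until all nodes have hit a stopping condition, that is, until the queue is empty in the while loop that follows.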
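The initial state of the queue can be sketched without the Node class. In this simplified stand-in a node is represented only by its Int id; the object name QueueInitSketch and the function initialQueue are hypothetical.

```scala
import scala.collection.mutable

object QueueInitSketch {
  // Build the FIFO queue of (treeIndex, nodeId) pairs; every root gets id 1.
  def initialQueue(numTrees: Int): mutable.Queue[(Int, Int)] = {
    val nodeQueue = new mutable.Queue[(Int, Int)]()
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, 1)))
    nodeQueue
  }

  def main(args: Array[String]): Unit = {
    println(initialQueue(3).mkString(", "))
  }
}
```

Since all roots share the id 1, a node is only identified uniquely by the pair (treeIndex, id), which is why the queue stores both.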
6.3. Selecting Nodes to Split
This logic lives in selectNodesToSplit. It pulls the nodes to be split in this round off nodeQueue and computes the parameters of each node.
    /**
     * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
     * This tracks the memory usage for aggregates and stops adding nodes when too much memory
     * will be needed; this allows an adaptive number of nodes since different nodes may require
     * different amounts of memory (if featureSubsetStrategy is not "all").
     *
     * @param nodeQueue  Queue of nodes to split.
     * @param maxMemoryUsage  Bound on size of aggregate statistics.
     * @return  (nodesForGroup, treeToNodeToIndexInfo).
     *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
     *
     *          treeToNodeToIndexInfo holds indices selected features for each node:
     *            treeIndex --> (global) node index --> (node index in group, feature indices).
     *          The (global) node index is the index in the tree; the node index in group is the
     *           index in [0, numNodesInGroup) of the node in this group.
     *          The feature indices are None if not subsampling features.
     */
    private[tree] def selectNodesToSplit(
        nodeQueue: mutable.Queue[(Int, Node)],
        maxMemoryUsage: Long,
        metadata: DecisionTreeMetadata,
        rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
      // Collect some nodes to split:
      //  nodesForGroup(treeIndex) = nodes to split
      val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
      val mutableTreeToNodeToIndexInfo =
        new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
      var memUsage: Long = 0L
      var numNodesInGroup = 0
      while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
        val (treeIndex, node) = nodeQueue.head
        // Sample the feature set used by the node with reservoir sampling
        // (introduced in an earlier article).
        // Choose subset of features for node (if subsampling).
        val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
          Some(SamplingUtils.reservoirSampleAndCount(Range(0,
            metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
        } else {
          None
        }
        // Check if enough memory remains to add this node to the group.
        val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
        if (memUsage + nodeMemUsage <= maxMemoryUsage) {
          nodeQueue.dequeue()
          mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
          mutableTreeToNodeToIndexInfo
            .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
            = new NodeIndexInfo(numNodesInGroup, featureSubset)
        }
        numNodesInGroup += 1
        memUsage += nodeMemUsage
      }
      // Convert mutable maps to immutable ones.
      val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
      val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
      (nodesForGroup, treeToNodeToIndexInfo)
    }
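The code is fairly straightforward: subject to the memory limit, it takes the nodes that can be processed in this round off nodeQueue and puts them into nodesForGroup and treeToNodeToIndexInfo.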
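Whether the feature set is subsampled depends on whether metadata's numFeatures differs from numFeaturesPerNode: sampling takes place only when the two are not equal. Both are constructor arguments of metadata and are determined from featureSubsetStrategy in buildMetadata; see the previous parts.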
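The sampling decision can be illustrated with a standalone reservoir sampler. This sketch does not call Spark's SamplingUtils; reservoirSample is a plain textbook reservoir sampler (Algorithm R) written for illustration, and the names FeatureSubsetSketch and featureSubset are hypothetical.

```scala
import scala.util.Random

object FeatureSubsetSketch {
  // Pick k of the feature indices 0 until numFeatures uniformly at random.
  def reservoirSample(numFeatures: Int, k: Int, seed: Long): Array[Int] = {
    val rng = new Random(seed)
    val reservoir = new Array[Int](math.min(k, numFeatures))
    var i = 0
    while (i < numFeatures) {
      if (i < reservoir.length) {
        reservoir(i) = i             // fill the reservoir with the first k indices
      } else {
        val j = rng.nextInt(i + 1)   // uniform in [0, i]
        if (j < reservoir.length) reservoir(j) = i
      }
      i += 1
    }
    reservoir
  }

  // None means "use all features", mirroring the Option in the excerpt.
  def featureSubset(numFeatures: Int, numFeaturesPerNode: Int, seed: Long): Option[Array[Int]] =
    if (numFeatures != numFeaturesPerNode) {
      Some(reservoirSample(numFeatures, numFeaturesPerNode, seed))
    } else {
      None
    }

  def main(args: Array[String]): Unit = {
    println(featureSubset(10, 10, seed = 1L))
    println(featureSubset(10, 3, seed = 42L).map(_.mkString(", ")))
  }
}
```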
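nodesForGroup is a Map[Int, Array[Node]]. Its key is treeIndex and its value is an array of Nodes holding the nodes of that tree to be split in this round.
treeToNodeToIndexInfo has type Map[Int, Map[Int, NodeIndexInfo]]. The outer key is treeIndex. The key of the inner Map is node.id, which comes from the first argument passed when the Node was constructed, so in the first round every node's id is 1. The value is a NodeIndexInfo structure: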
    class NodeIndexInfo(
        val nodeIndexInGroup: Int,
        val featureSubset: Option[Array[Int]])
The first member is the index of the node within the while loop of this selection round, referred to as groupIndex; the second member is the feature subset.
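The memory-bounded selection loop can be simulated with a simplified stand-in. The assumptions are stated loudly: nodes are reduced to Int ids, every node is given the same hypothetical cost nodeMemUsage instead of a value derived from the metadata, and the names SelectNodesSketch and select are made up for this sketch.

```scala
import scala.collection.mutable

object SelectNodesSketch {
  // Dequeue (treeIndex, nodeId) pairs while the running memory estimate allows it.
  // Returns treeIndex -> node ids chosen for this round; the rest stay in the queue.
  def select(
      nodeQueue: mutable.Queue[(Int, Int)],
      nodeMemUsage: Long,
      maxMemoryUsage: Long): Map[Int, List[Int]] = {
    val group = new mutable.HashMap[Int, mutable.ArrayBuffer[Int]]()
    var memUsage = 0L
    while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
      val (treeIndex, nodeId) = nodeQueue.head
      if (memUsage + nodeMemUsage <= maxMemoryUsage) {
        nodeQueue.dequeue()
        group.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Int]()) += nodeId
      }
      // As in the excerpt, the estimate grows even when the node did not fit,
      // which is what terminates the loop.
      memUsage += nodeMemUsage
    }
    group.map { case (treeIndex, ids) => treeIndex -> ids.toList }.toMap
  }

  def main(args: Array[String]): Unit = {
    val queue = mutable.Queue((0, 1), (1, 1), (2, 1))
    println(select(queue, nodeMemUsage = 40L, maxMemoryUsage = 100L))
    println(queue.mkString(", "))
  }
}
```

With a cost of 40 per node and a budget of 100, the first two roots are selected and the third stays in the queue for the next round.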