Spark MLlib Source Code Analysis: Random Forest (Part 3)


Spark source code analysis: Random Forest (Part 1)
Spark source code analysis: Random Forest (Part 2)
Spark source code analysis: Random Forest (Part 4)
Spark source code analysis: Random Forest (Part 5)

6. Random Forest Training

6.1. Data Structures

6.1.1. Node

Each node in the tree is a Node structure:

class Node @Since("1.2.0") (
    @Since("1.0.0") val id: Int,
    @Since("1.0.0") var predict: Predict,
    @Since("1.2.0") var impurity: Double,
    @Since("1.0.0") var isLeaf: Boolean,
    @Since("1.0.0") var split: Option[Split],
    @Since("1.0.0") var leftNode: Option[Node],
    @Since("1.0.0") var rightNode: Option[Node],
    @Since("1.0.0") var stats: Option[InformationGainStats])

emptyNode initializes only the nodeIndex; every other field takes a default value:

def emptyNode(nodeIndex: Int): Node =
  new Node(nodeIndex, new Predict(Double.MinValue),
    -1.0, false, None, None, None, None)

Given a node's id, the ids of its children are computed as follows:

/**
 * Return the index of the left child of this node.
 */
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

/**
 * Return the index of the right child of this node.
 */
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

The left child's id is the current id * 2, and the right child's id is id * 2 + 1.

[Figure: tree node ids laid out heap-style, with the root at 1 and the children of node i at 2i and 2i + 1]
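As a quick sanity check, here is a standalone sketch of this heap-style numbering (not the Spark source, just the two one-liners above wrapped in a runnable object):

object ChildIndexDemo {
  def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1
  def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1

  def main(args: Array[String]): Unit = {
    // the root has id 1, so its children are 2 and 3
    println((leftChildIndex(1), rightChildIndex(1)))  // (2,3)
    // node 3's children are 6 and 7
    println((leftChildIndex(3), rightChildIndex(3)))  // (6,7)
  }
}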

6.1.2. Entropy

6.1.2.1. Entropy

Entropy is a Scala object; its most important member is the calculate function:

/**
 * :: DeveloperApi ::
 * information calculation for multiclass classification
 * @param counts Array[Double] with counts for each label
 * @param totalCount sum of counts for all labels
 * @return information value, or 0 if totalCount = 0
 */
@Since("1.1.0")
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  if (totalCount == 0) {
    return 0
  }
  val numClasses = counts.length
  var impurity = 0.0
  var classIndex = 0
  while (classIndex < numClasses) {
    val classCount = counts(classIndex)
    if (classCount != 0) {
      val freq = classCount / totalCount
      impurity -= freq * log2(freq)
    }
    classIndex += 1
  }
  impurity
}

The entropy formula is

$H = -E[\log p_i] = -\sum_{i=1}^{n} p_i \log p_i$

So the counts argument here holds the occurrence count of each class; the function first turns each count into a probability, then accumulates -freq * log2(freq) over the classes.
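To make the formula concrete, here is a self-contained sketch that mirrors the calculate function above (plain Scala, not the Spark class itself):

object EntropyDemo {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // same contract as Entropy.calculate: class counts in, entropy out
  def calculate(counts: Array[Double], totalCount: Double): Double = {
    if (totalCount == 0) return 0.0
    counts.filter(_ != 0).map { classCount =>
      val freq = classCount / totalCount
      -freq * log2(freq)
    }.sum
  }

  def main(args: Array[String]): Unit = {
    // two classes seen 3 and 7 times: H = -(0.3*log2(0.3) + 0.7*log2(0.7))
    println(calculate(Array(3.0, 7.0), 10.0))  // ~0.8813
  }
}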

6.1.2.2. EntropyAggregator
class EntropyAggregator(numClasses: Int)  extends ImpurityAggregator(numClasses)

Its only state is the number of classes; the key function is update:

/**
 * Update stats for one (node, feature, bin) with the given label.
 * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
 * @param offset    Start index of stats for this (node, feature, bin).
 */
def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = {
  if (label >= statsSize) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s" but requires label < numClasses (= $statsSize).")
  }
  if (label < 0) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s"but requires label is non-negative.")
  }
  allStats(offset + label.toInt) += instanceWeight
}

offset is the offset of the (feature, bin) slice; adding the label gives this class's position in allStats, where its occurrence count is accumulated.

/**
 * Get an [[ImpurityCalculator]] for a (node, feature, bin).
 * @param allStats  Flat stats array, with stats for this (node, feature, bin) contiguous.
 * @param offset    Start index of stats for this (node, feature, bin).
 */
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
  new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
}

This slices out the portion of allStats belonging to that (feature, bin); its length is statsSize, i.e. the number of classes.

6.1.2.3. EntropyCalculator
/**
 * Calculate the impurity from the stored sufficient statistics.
 */
def calculate(): Double = Entropy.calculate(stats, stats.sum)

Combining the functions above, the path for computing entropy is: Entropy's getCalculator slices out the portion of allStats belonging to the (feature, bin), and the resulting EntropyCalculator's calculate delegates to Entropy.calculate to compute the entropy.
EntropyCalculator also overrides the prob function, which returns the probability of a given label: for example, with 3 occurrences counted for label 0 and 7 for label 1, the probability of label 0 is 0.3.

6.1.3. DTStatsAggregator

A few words on what statistics a node split needs, since the design of DTStatsAggregator follows from them. Take entropy as the example: when a node splits, we iterate over every split of every feature, and each split partitions the samples into two parts. To compute the entropy we must collect the class distribution of the left and right parts separately, turn the counts into probabilities, and then compute the entropy. Hence statsSize in the aggregator equals numClasses, and allStats holds all the statistics, which are exactly the class distributions.

class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    featureSubset: Option[Array[Int]]) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  private val statsSize: Int = impurityAggregator.statsSize

  /**
   * Number of bins for each feature.  This is indexed by the feature index.
   */
  private val numBins: Array[Int] = {
    if (featureSubset.isDefined) {
      featureSubset.get.map(metadata.numBins(_))
    } else {
      metadata.numBins
    }
  }

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
   * Total number of elements stored in this aggregator
   */
  private val allStatsSize: Int = featureOffsets.last

  /**
   * Flat array of elements.
   * Index for start of stats for a (feature, bin) is:
   *   index = featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features,
   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
   */
  private val allStats: Array[Double] = new Array[Double](allStatsSize)

Each node owns a DTStatsAggregator. Its constructor takes two arguments, the metadata and the subset of features used by the node. The remaining members are:
- impurityAggregator: Gini, Entropy and Variance are currently supported; we use Entropy as the running example, the others are analogous
- statsSize: the number of statistics per bin. For classification it equals numClasses, since every class must be counted separately; for regression it is 3, storing the count of feature values, their sum, and their sum of squares, used to compute the variance
- numBins: the entries of the metadata numBins array for the features used by this node
- featureOffsets: where each feature starts inside allStats, determined by each feature's bin count and statsSize. For example, with 3 features whose bin counts are 3, 2 and 2 and statsSize = 2, the features need 3 * 2 = 6, 2 * 2 = 4 and 2 * 2 = 4 slots respectively, so featureOffsets is (0, 6, 10, 14), the left-to-right running total
- allStatsSize: the total number of buckets needed
- allStats: the buckets holding the statistics
[Figure: layout of the allStats array for the example below, with statsSize buckets per (feature, bin)]
f0, f1 and f2 are the 3 features: f0 has 3 feature values (really binIndex values) 0/1/2, f1 has 2 (0/1), and f2 has 2 (0/1). Every feature value owns statsSize state buckets, so there are 14 in total and allStatsSize = 14. Say we want the index of class c1 under value v1 of f1: f1's offset from featureOffsets is featureOffsets(1) = 6, v1's binIndex is 1, statsSize is 2, and the label is 1, so the bucket index = 6 + 1 * 2 + 1 = 9, which is exactly the bucket of f1/v1/c1 in the figure.
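The offset arithmetic of this example can be checked with a few lines (a sketch, not Spark code):

object OffsetDemo {
  val statsSize = 2
  val numBins = Array(3, 2, 2)
  // running totals, exactly as featureOffsets is built: Array(0, 6, 10, 14)
  val featureOffsets: Array[Int] =
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)

  def bucketIndex(featureIndex: Int, binIndex: Int, label: Int): Int =
    featureOffsets(featureIndex) + binIndex * statsSize + label

  def main(args: Array[String]): Unit = {
    println(featureOffsets.mkString(","))                            // 0,6,10,14
    println(bucketIndex(featureIndex = 1, binIndex = 1, label = 1))  // 9
  }
}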

Let's walk through the key functions of this class.

/**
 * Update the stats for a given (feature, bin) for ordered features, using the given label.
 */
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
  // The first part is the feature offset. binIndex is effectively the offset of the
  // feature value within the feature; each feature value owns statsSize buckets, so
  // adding the two gives the bucket of this feature value. Entropy's update then adds
  // label.toInt on top to reach this label's bucket.
  // This offset arithmetic also shows that an ordered feature's values should ideally
  // be contiguous, with no gaps, starting from 0; any gaps merely waste buckets.
  val i = featureOffsets(featureIndex) + binIndex * statsSize
  impurityAggregator.update(allStats, i, label, instanceWeight)
}

/**
 * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
 * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
 *                           from [[getFeatureOffset]].
 *                           For unordered features, this is a pre-computed
 *                           (node, feature, left/right child) offset from
 *                           [[getLeftRightFeatureOffsets]].
 */
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
  // Same offset arithmetic as above, except the feature offset is passed in
  // and need not be recomputed.
  impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}

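Putting update and getImpurityCalculator together, here is a hedged, self-contained sketch of the round trip for the toy layout above (it re-implements the indexing inline rather than using the real DTStatsAggregator):

object AggregatorRoundTrip {
  val statsSize = 2                            // numClasses = 2
  val featureOffsets = Array(0, 6, 10, 14)     // numBins = (3, 2, 2)
  val allStats = new Array[Double](14)

  // mirrors DTStatsAggregator.update for an ordered feature
  def update(featureIndex: Int, binIndex: Int, label: Double): Unit = {
    val i = featureOffsets(featureIndex) + binIndex * statsSize
    allStats(i + label.toInt) += 1.0           // instanceWeight = 1.0
  }

  def main(args: Array[String]): Unit = {
    // three samples landing in feature 1, bin 1, with labels 0, 1, 1
    update(1, 1, 0.0); update(1, 1, 1.0); update(1, 1, 1.0)
    // read the (feature, bin) slice back, as getCalculator does
    val offset = featureOffsets(1) + 1 * statsSize
    val stats = allStats.slice(offset, offset + statsSize)
    println(stats.mkString(","))               // 1.0,2.0 -> buckets 8 and 9
  }
}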
6.2. Training Initialization

// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

This constructs numTrees Nodes initialized as emptyNode; they serve as the root node of each tree in the training that follows. Each node is paired with its treeIndex and enqueued into nodeQueue. Later, every node awaiting a split is pushed onto this queue and split in turn, until every node hits a stopping condition, i.e. until the queue drains in the while loop described below.

6.3. Selecting Nodes to Split

This logic lives in selectNodesToSplit: it takes the nodes to be split in this round off nodeQueue and computes their parameters.

/**
 * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
 * This tracks the memory usage for aggregates and stops adding nodes when too much memory
 * will be needed; this allows an adaptive number of nodes since different nodes may require
 * different amounts of memory (if featureSubsetStrategy is not "all").
 *
 * @param nodeQueue  Queue of nodes to split.
 * @param maxMemoryUsage  Bound on size of aggregate statistics.
 * @return  (nodesForGroup, treeToNodeToIndexInfo).
 *          nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
 *
 *          treeToNodeToIndexInfo holds indices selected features for each node:
 *            treeIndex --> (global) node index --> (node index in group, feature indices).
 *          The (global) node index is the index in the tree; the node index in group is the
 *           index in [0, numNodesInGroup) of the node in this group.
 *          The feature indices are None if not subsampling features.
 */
private[tree] def selectNodesToSplit(
    nodeQueue: mutable.Queue[(Int, Node)],
    maxMemoryUsage: Long,
    metadata: DecisionTreeMetadata,
    rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
  // Collect some nodes to split:
  //  nodesForGroup(treeIndex) = nodes to split
  val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
  val mutableTreeToNodeToIndexInfo =
    new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
  var memUsage: Long = 0L
  var numNodesInGroup = 0
  while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
    val (treeIndex, node) = nodeQueue.head
    // Sample the node's feature subset with reservoir sampling (introduced in an earlier post).
    // Choose subset of features for node (if subsampling).
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      Some(SamplingUtils.reservoirSampleAndCount(Range(0,
        metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong)._1)
    } else {
      None
    }
    // Check if enough memory remains to add this node to the group.
    val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    if (memUsage + nodeMemUsage <= maxMemoryUsage) {
      nodeQueue.dequeue()
      mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
      mutableTreeToNodeToIndexInfo
        .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id)
        = new NodeIndexInfo(numNodesInGroup, featureSubset)
    }
    numNodesInGroup += 1
    memUsage += nodeMemUsage
  }
  // Convert mutable maps to immutable ones.
  val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
  val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
  (nodesForGroup, treeToNodeToIndexInfo)
}

The code is fairly straightforward: bounded by memory, it pulls as many nodes as this round can handle off nodeQueue and places them into nodesForGroup and treeToNodeToIndexInfo.
Whether the feature set is subsampled is decided by whether the metadata's numFeatures equals numFeaturesPerNode; both are constructor arguments of the metadata, fixed in buildMetadata according to featureSubsetStrategy, as described in the earlier posts.
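For a sense of scale: nodeMemUsage is aggregateSizeForNode(metadata, featureSubset) * 8L, i.e. 8 bytes per Double statistic, which corresponds to the allStatsSize of section 6.1.3. For the toy layout there (14 buckets) a node would cost only 14 * 8 = 112 bytes, but with realistic feature, bin and class counts this grows quickly, which is why the group is bounded by maxMemoryUsage.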
nodesForGroup is a Map[Int, Array[Node]] whose key is the treeIndex and whose value is the array of nodes of that tree to be split in this round.
treeToNodeToIndexInfo has type Map[Int, Map[Int, NodeIndexInfo]]: the outer key is the treeIndex, and the inner map's key is node.id, which comes from the first constructor argument of Node (in the first round every node's id is 1). The inner value is a NodeIndexInfo:

class NodeIndexInfo(
    val nodeIndexInGroup: Int,
    val featureSubset: Option[Array[Int]])

Its first member is the node's index within the while loop of this round's node selection, referred to as the group index; the second is the feature subset.
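As an illustration (hypothetical values, not output of the source), for a 2-tree forest in the first round with no feature subsampling, both roots (node id 1) would be selected and the two structures would look like:

// nodesForGroup: treeIndex -> nodes of that tree to split this round
Map(0 -> Array(rootOfTree0), 1 -> Array(rootOfTree1))

// treeToNodeToIndexInfo: treeIndex -> (node.id -> NodeIndexInfo(groupIndex, featureSubset))
Map(0 -> Map(1 -> new NodeIndexInfo(0, None)),
    1 -> Map(1 -> new NodeIndexInfo(1, None)))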
