spark mllib源码分析之随机森林(Random Forest)(五)

来源:互联网 发布:淘宝基金欠款 编辑:程序博客网 时间:2024/06/05 20:52

spark源码分析之随机森林(Random Forest)(一)
spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(四)

7. 构造随机森林

在上面的训练过程可以看到,从根节点topNode中不断向下分裂一直到触发截止条件就构造了一棵树所有的node,因此构造整个森林也是非常简单

//构造val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))//返回rf模型new RandomForestModel(strategy.algo, trees)

8. 随机森林模型

8.1. TreeEnsembleModel

随机森林RandomForestModel继承自树集合模型TreeEnsembleModel

class TreeEnsembleModel(    protected val algo: Algo,    protected val trees: Array[DecisionTreeModel],    protected val treeWeights: Array[Double],    protected val combiningStrategy: EnsembleCombiningStrategy)
  • algo:Regression/Classification
  • trees:树数组
  • treeWeights:每棵树的权重,在RF中每棵树的权重是相同的,在Adaboost可能是不同的
  • combiningStrategy:树合并时的策略,Sum/Average/Vote,分类的话应该是Vote,RF应该是Average,GBDT应该是Sum。
  • sumWeights:成员变量,不在参数表中,是treeWeights的sum

预测函数

/**   * Predicts for a single data point using the weighted sum of ensemble predictions.   *   * @param features array representing a single data point   * @return predicted category from the trained model   */  private def predictBySumming(features: Vector): Double = {    val treePredictions = trees.map(_.predict(features))    blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)  }

将每棵树的预测结果与各自的weight向量相乘

/**   * Classifies a single data point based on (weighted) majority votes.   */  private def predictByVoting(features: Vector): Double = {    val votes = mutable.Map.empty[Int, Double]    trees.view.zip(treeWeights).foreach { case (tree, weight) =>      val prediction = tree.predict(features).toInt      votes(prediction) = votes.getOrElse(prediction, 0.0) + weight    }    votes.maxBy(_._2)._1  }

将每棵树的预测class为key,将树的weight累加到Map中作为value,最后取权重和最大对应的class

8.2. RandomForestModel

RandomForestModel @Since("1.2.0") (    @Since("1.2.0") override val algo: Algo,    @Since("1.2.0") override val trees: Array[DecisionTreeModel])  extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),    combiningStrategy = if (algo == Classification) Vote else Average)

对于随机森林,其weight都是1,树合并策略如果是分类就是Vote,回归是Average。
模型生成后,如果要应用到线上,需要将训练后的模型保存下来,自己写代码解析模型文件,进行预测,因此要了解模型的保存和加载。

8.2.1. 模型保存

分为两部分,第一部分是metadata,保存了一些配置,包括模型名,模型版本,模型的algo是classification/regression,合并策略,每棵树的权重。

implicit val format = DefaultFormatsval ensembleMetadata = Metadata(model.algo.toString,    model.trees(0).algo.toString,    model.combiningStrategy.toString,     model.treeWeights)val metadata = compact(render(    ("class" -> className) ~ ("version" -> thisFormatVersion) ~    ("metadata" -> Extraction.decompose(ensembleMetadata))))sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

第二部分是随机森林的每棵树的保存

// Create Parquet data.val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>    tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))}.toDF()dataRDD.write.parquet(Loader.dataPath(path))

其中首先调用node的subtreeIterator函数,返回所有node的Iterator,然后转成DataFrame结构,再写成parquet格式的文件。我们来看subtreeIterator函数

/** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */  private[tree] def subtreeIterator: Iterator[Node] = {    Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++      rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)  }

其实就是用前序遍历的方式返回了树中的所有node的Iterrator。
我们再来看NodeData,看看每个node保存了什么数据

def apply(treeId: Int, n: Node): NodeData = {    NodeData(treeId, n.id, PredictData(n.predict), n.impurity,    n.isLeaf, n.split.map(SplitData.apply), n.leftNode.map(_.id),     n.rightNode.map(_.id), n.stats.map(_.gain))}

保存了node的预测值,impurity,是否是否叶子节点,split,左右孩子节点的id,gain。其中split中包含了特征id,特征阈值,特征类型,离散特征数组(其实就是Split结构)。

8.2.2. 模型加载

metadata的加载就是解析json,主要是树的重建

val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc,     path, metadata.treeAlgo)new RandomForestModel(Algo.fromString(metadata.algo), trees)

其中调用了loadTrees函数

/** * Load trees for an ensemble, and return them in order. * @param path path to load the model from * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's *                 algorithm). */def loadTrees(        sc: SparkContext,        path: String,        treeAlgo: String): Array[DecisionTreeModel] = {    val datapath = Loader.dataPath(path)    val sqlContext = SQLContext.getOrCreate(sc)    val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply)    val trees = constructTrees(nodes)    trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))}

先是读取数据文件,读成NodeData格式,然后调用constructTrees重建树结构

    def constructTrees(nodes: RDD[NodeData]): Array[Node] = {      val trees = nodes        .groupBy(_.treeId)        .mapValues(_.toArray)        .collect()        .map { case (treeId, data) =>          (treeId, constructTree(data))        }.sortBy(_._1)      val numTrees = trees.size      val treeIndices = trees.map(_._1).toSeq      assert(treeIndices == (0 until numTrees),        s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")      trees.map(_._2)    }

主要功能按树的id分组后,调用constructTree重建树

    /**     * Given a list of nodes from a tree, construct the tree.     * @param data array of all node data in a tree.     */    def constructTree(data: Array[NodeData]): Node = {      val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap      assert(dataMap.contains(1),        s"DecisionTree missing root node (id = 1).")      constructNode(1, dataMap, mutable.Map.empty)    }    /**     * Builds a node from the node data map and adds new nodes to the input nodes map.     */    private def constructNode(      id: Int,      dataMap: Map[Int, NodeData],      nodes: mutable.Map[Int, Node]): Node = {      if (nodes.contains(id)) {        return nodes(id)      }      val data = dataMap(id)      val node =        if (data.isLeaf) {          Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)        } else {          val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)          val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)          val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,            rightNode.impurity, leftNode.predict, rightNode.predict)          new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,            data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))        }      nodes += node.id -> node      node    }

其实就是递归的从NodeData中获取数据,重建node

从上面的分析可以看到,spark保存模型使用了parquet格式,对于我们在别的环境中使用是非常不方便的,训练完模型后,我们可以参照spark的做法,按照前序遍历的方法以json的格式保存node,在别的环境下复建树结构就可以了。

9. 坑

  • 特征id,样本是libsvm格式的,特征id从1开始,但是设置离散特征数categoricalFeaturesInfo需要从0开始,相当于样本特征id-1
  • 离散特征值,一旦在categoricalFeaturesInfo中指定了特征值的个数k,spark会认为这个特征是从0开始,连续到k-1。如果其中特征不连续,特征数应该设置成最大特征值+1
  • 对于连续特征,spark使用等频离散化方法,又对样本进行了抽样,效果其实很难保证,不知道作者是否比较过这种方法与等间隔离散化效果孰优孰劣
  • maxBins的设置需要考虑连续特征离散化效果,连续特征离散化值的个数是maxBins-1,同时maxBins必须大于categoricalFeaturesInfo中最大离散特征值的个数
  • ordered feature,之前的理解是有误的,这里的order仅仅是说这种特征是可以经过某种方式排列后变成有序,排序标准根据分类/回归而不同,在上面的文章有具体介绍。在我们的实践中,有的离散特征,例如薪资,1代表0-1000元,2代表1000-2000元,3代表2000-3000元,特征值的大小本身就表征了实际意义,这种应该直接按连续特征处理(当然也可以对比下效果决定)。

10. 结语

我们基本上是逐行分析了spark随机森林的实现,展现了其实现过程中使用的技巧,希望对大家在理解随机森林和其实现方法有所帮助。

原创粉丝点击