spark mllib源码分析之随机森林(Random Forest)(五)
来源:互联网 发布:淘宝基金欠款 编辑:程序博客网 时间:2024/06/05 20:52
spark源码分析之随机森林(Random Forest)(一)
spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(四)
7. 构造随机森林
在上面的训练过程可以看到,从根节点topNode中不断向下分裂一直到触发截止条件就构造了一棵树所有的node,因此构造整个森林也是非常简单
//构造val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))//返回rf模型new RandomForestModel(strategy.algo, trees)
8. 随机森林模型
8.1. TreeEnsembleModel
随机森林RandomForestModel继承自树集合模型TreeEnsembleModel
class TreeEnsembleModel( protected val algo: Algo, protected val trees: Array[DecisionTreeModel], protected val treeWeights: Array[Double], protected val combiningStrategy: EnsembleCombiningStrategy)
- algo:Regression/Classification
- trees:树数组
- treeWeights:每棵树的权重,在RF中每棵树的权重是相同的,在Adaboost可能是不同的
- combiningStrategy:树合并时的策略,Sum/Average/Vote,分类的话应该是Vote,RF应该是Average,GBDT应该是Sum。
- sumWeights:成员变量,不在参数表中,是treeWeights的sum
预测函数
/** * Predicts for a single data point using the weighted sum of ensemble predictions. * * @param features array representing a single data point * @return predicted category from the trained model */ private def predictBySumming(features: Vector): Double = { val treePredictions = trees.map(_.predict(features)) blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) }
将每棵树的预测结果与各自的weight向量相乘
/** * Classifies a single data point based on (weighted) majority votes. */ private def predictByVoting(features: Vector): Double = { val votes = mutable.Map.empty[Int, Double] trees.view.zip(treeWeights).foreach { case (tree, weight) => val prediction = tree.predict(features).toInt votes(prediction) = votes.getOrElse(prediction, 0.0) + weight } votes.maxBy(_._2)._1 }
将每棵树的预测class为key,将树的weight累加到Map中作为value,最后取权重和最大对应的class
8.2. RandomForestModel
RandomForestModel @Since("1.2.0") ( @Since("1.2.0") override val algo: Algo, @Since("1.2.0") override val trees: Array[DecisionTreeModel]) extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0), combiningStrategy = if (algo == Classification) Vote else Average)
对于随机森林,其weight都是1,树合并策略如果是分类就是Vote,回归是Average。
模型生成后,如果要应用到线上,需要将训练后的模型保存下来,自己写代码解析模型文件,进行预测,因此要了解模型的保存和加载。
8.2.1. 模型保存
分为两部分,第一部分是metadata,保存了一些配置,包括模型名,模型版本,模型的algo是classification/regression,合并策略,每棵树的权重。
implicit val format = DefaultFormatsval ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, model.combiningStrategy.toString, model.treeWeights)val metadata = compact(render( ("class" -> className) ~ ("version" -> thisFormatVersion) ~ ("metadata" -> Extraction.decompose(ensembleMetadata))))sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
第二部分是随机森林的每棵树的保存
// Create Parquet data.val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))}.toDF()dataRDD.write.parquet(Loader.dataPath(path))
其中首先调用node的subtreeIterator函数,返回所有node的Iterator,然后转成DataFrame结构,再写成parquet格式的文件。我们来看subtreeIterator函数
/** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */ private[tree] def subtreeIterator: Iterator[Node] = { Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) }
其实就是用前序遍历的方式返回了树中的所有node的Iterrator。
我们再来看NodeData,看看每个node保存了什么数据
def apply(treeId: Int, n: Node): NodeData = { NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), n.stats.map(_.gain))}
保存了node的预测值,impurity,是否是否叶子节点,split,左右孩子节点的id,gain。其中split中包含了特征id,特征阈值,特征类型,离散特征数组(其实就是Split结构)。
8.2.2. 模型加载
metadata的加载就是解析json,主要是树的重建
val trees = TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)new RandomForestModel(Algo.fromString(metadata.algo), trees)
其中调用了loadTrees函数
/** * Load trees for an ensemble, and return them in order. * @param path path to load the model from * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's * algorithm). */def loadTrees( sc: SparkContext, path: String, treeAlgo: String): Array[DecisionTreeModel] = { val datapath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) val nodes = sqlContext.read.parquet(datapath).map(NodeData.apply) val trees = constructTrees(nodes) trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))}
先是读取数据文件,读成NodeData格式,然后调用constructTrees重建树结构
def constructTrees(nodes: RDD[NodeData]): Array[Node] = { val trees = nodes .groupBy(_.treeId) .mapValues(_.toArray) .collect() .map { case (treeId, data) => (treeId, constructTree(data)) }.sortBy(_._1) val numTrees = trees.size val treeIndices = trees.map(_._1).toSeq assert(treeIndices == (0 until numTrees), s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") trees.map(_._2) }
主要功能按树的id分组后,调用constructTree重建树
/** * Given a list of nodes from a tree, construct the tree. * @param data array of all node data in a tree. */ def constructTree(data: Array[NodeData]): Node = { val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap assert(dataMap.contains(1), s"DecisionTree missing root node (id = 1).") constructNode(1, dataMap, mutable.Map.empty) } /** * Builds a node from the node data map and adds new nodes to the input nodes map. */ private def constructNode( id: Int, dataMap: Map[Int, NodeData], nodes: mutable.Map[Int, Node]): Node = { if (nodes.contains(id)) { return nodes(id) } val data = dataMap(id) val node = if (data.isLeaf) { Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) } else { val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, rightNode.impurity, leftNode.predict, rightNode.predict) new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) } nodes += node.id -> node node }
其实就是递归的从NodeData中获取数据,重建node
从上面的分析可以看到,spark保存模型使用了parquet格式,对于我们在别的环境中使用是非常不方便的,训练完模型后,我们可以参照spark的做法,按照前序遍历的方法以json的格式保存node,在别的环境下复建树结构就可以了。
9. 坑
- 特征id,样本是libsvm格式的,特征id从1开始,但是设置离散特征数categoricalFeaturesInfo需要从0开始,相当于样本特征id-1
- 离散特征值,一旦在categoricalFeaturesInfo中指定了特征值的个数k,spark会认为这个特征是从0开始,连续到k-1。如果其中特征不连续,特征数应该设置成最大特征值+1
- 对于连续特征,spark使用等频离散化方法,又对样本进行了抽样,效果其实很难保证,不知道作者是否比较过这种方法与等间隔离散化效果孰优孰劣
- maxBins的设置需要考虑连续特征离散化效果,连续特征离散化值的个数是maxBins-1,同时maxBins必须大于categoricalFeaturesInfo中最大离散特征值的个数
- ordered feature,之前的理解是有误的,这里的order仅仅是说这种特征是可以经过某种方式排列后变成有序,排序标准根据分类/回归而不同,在上面的文章有具体介绍。在我们的实践中,有的离散特征,例如薪资,1代表0-1000元,2代表1000-2000元,3代表2000-3000元,特征值的大小本身就表征了实际意义,这种应该直接按连续特征处理(当然也可以对比下效果决定)。
10. 结语
我们基本上是逐行分析了spark随机森林的实现,展现了其实现过程中使用的技巧,希望对大家在理解随机森林和其实现方法有所帮助。
- spark mllib源码分析之随机森林(Random Forest)(五)
- spark mllib源码分析之随机森林(Random Forest)(一)
- spark mllib源码分析之随机森林(Random Forest)(二)
- spark mllib源码分析之随机森林(Random Forest)(三)
- spark mllib源码分析之随机森林(Random Forest)(四)
- 随机森林(Random Forest)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 随机森林回归(Random Forest)算法原理及Spark MLlib调用实例(Scala/Java/python)
- 随机森林(Random Forest)
- 随机森林(Random Forest)
- random forest(随机森林)
- 随机森林(Random Forest)
- 随机森林Random Forest
- 随机森林--Random Forest
- 随机森林Random Forest
- 随机森林(Random Forest)
- random forest(随机森林)
- 随机森林Random Forest
- Random Forest(随机森林)
- Linux 网络性能测试工具 iperf 的安装和使用
- 计算机视觉领域的一些牛人博客,超有实力的研究机构等的网站链接
- 前端开发-数据分页请求和删除
- JPA动态查询代码封装
- HTTPS的配置
- spark mllib源码分析之随机森林(Random Forest)(五)
- PackBits算法
- linux top命令中各cpu占用率含义及案例分析
- Masonry 自动布局使用案例
- React this.state
- C#之转义符
- codeforces——581A——Vasya the Hipster
- 介绍几种简单的文件加密方法,挺有意思的
- 使用VirtualBOX自带的共享文件夹功能