Printing a Spark DecisionTreeModel

Source: Internet  Editor: 程序博客网  Date: 2024/06/05 05:53


Spark version: 1.6.1


When building a classification model with Spark's DecisionTree, you may want to print the learned tree. The model's toDebugString method prints a string like the following:
DecisionTreeModel classifier of depth 7 with 45 nodes
  If (feature 22 <= 114.2)
   If (feature 27 <= 0.1108)
    If (feature 13 <= 45.19)
     If (feature 21 <= 32.84)
      Predict: 0.0
     Else (feature 21 > 32.84)
      If (feature 1 <= 22.61)
       If (feature 0 <= 11.49)
        Predict: 0.0
       Else (feature 0 > 11.49)
        Predict: 1.0
      Else (feature 1 > 22.61)
       Predict: 0.0
    Else (feature 13 > 45.19)
     If (feature 21 <= 22.13)
      Predict: 0.0
     Else (feature 21 > 22.13)
      If (feature 14 <= 0.004571)
       Predict: 0.0
      Else (feature 14 > 0.004571)
       Predict: 1.0
   Else (feature 27 > 0.1108)
    If (feature 21 <= 25.72)
     If (feature 24 <= 0.1786)
      If (feature 23 <= 809.7)
       Predict: 0.0
      Else (feature 23 > 809.7)
       If (feature 0 <= 14.02)
        Predict: 1.0
       Else (feature 0 > 14.02)
        Predict: 0.0
     Else (feature 24 > 0.1786)
      Predict: 1.0
    Else (feature 21 > 25.72)
     If (feature 7 <= 0.05266)
      If (feature 20 <= 15.5)
       Predict: 0.0
      Else (feature 20 > 15.5)
       If (feature 4 <= 0.09073)
        If (feature 10 <= 0.2406)
         Predict: 1.0
        Else (feature 10 > 0.2406)
         Predict: 0.0
       Else (feature 4 > 0.09073)
        Predict: 1.0
     Else (feature 7 > 0.05266)
      If (feature 12 <= 1.539)
       Predict: 0.0
      Else (feature 12 > 1.539)
       Predict: 1.0
  Else (feature 22 > 114.2)
   If (feature 27 <= 0.1397)
    If (feature 1 <= 14.96)
     Predict: 0.0
    Else (feature 1 > 14.96)
     If (feature 20 <= 18.79)
      If (feature 17 <= 0.009753)
       Predict: 1.0
      Else (feature 17 > 0.009753)
       If (feature 0 <= 17.3)
        Predict: 0.0
       Else (feature 0 > 17.3)
        Predict: 1.0
     Else (feature 20 > 18.79)
      Predict: 1.0
   Else (feature 27 > 0.1397)
    Predict: 1.0
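The layout of that string is simple. Each level is indented by one more space, and the root starts at two spaces. The left child of a node is its "<= threshold" branch and the right child is its "> threshold" branch.

The sketch below reproduces that layout so the format is easy to see. It assumes a hypothetical, Spark-free tree type (DebugStringSketch, Leaf, SplitNode). This is not Spark's Node class or Spark's actual implementation:

```scala
// Hypothetical, Spark-free model of the tree (NOT org.apache.spark.mllib.tree.model.Node),
// used only to illustrate the layout of the toDebugString output.
object DebugStringSketch {
  sealed trait Tree
  final case class Leaf(predict: Double) extends Tree
  final case class SplitNode(feature: Int, threshold: Double, left: Tree, right: Tree) extends Tree

  // Each level is indented one more space; the left child is the "<= threshold"
  // branch and the right child is the "> threshold" branch.
  // The default indent of 2 matches the root line of the output.
  def render(node: Tree, indent: Int = 2): String = {
    val prefix = " " * indent
    node match {
      case Leaf(p) =>
        prefix + s"Predict: $p\n"
      case SplitNode(f, t, left, right) =>
        prefix + s"If (feature $f <= $t)\n" + render(left, indent + 1) +
          prefix + s"Else (feature $f > $t)\n" + render(right, indent + 1)
    }
  }
}
```

Consider the call render(SplitNode(0, 11.49, Leaf(0.0), Leaf(1.0)), 7). It yields the four "feature 0 ... 11.49" lines of the output, indented 7 and 8 spaces.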



import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.tree.model.Node

/**
 * Printing utility (visualization tool).
 */
object PrintUtils {

  /**
   * Print a decision tree.
   * @param model the trained DecisionTreeModel
   * @return the model summary followed by a sideways ASCII drawing of the tree
   */
  def printDecisionTree(model: DecisionTreeModel): String = {
    model.toString() + "\n" +
      printTree(model.topNode)
  }

  // Root node: the right (> threshold) subtree is drawn above it, the left (<= threshold) subtree below it.
  def printTree(root: Node): String = {
    val right: String = if (root.rightNode.isDefined) {
      printTree(root.rightNode.get, true, "")
    } else {
      ""
    }
    val rootStr = printNodeValue(root)
    val left: String = if (root.leftNode.isDefined) {
      printTree(root.leftNode.get, false, "")
    } else {
      ""
    }
    right + rootStr + left
  }

  // Leaf: its prediction, e.g. "1.0 (prob = 1.0)"; internal node: its (continuous) split.
  def printNodeValue(root: Node): String = {
    val rootStr: String = if (root.split.isEmpty) {
      if (root.isLeaf) {
        root.predict.toString()
      } else {
        ""
      }
    } else {
      "Feature:" + root.split.get.feature + " > " + root.split.get.threshold
    }
    rootStr + "\n"
  }

  // Child node: isRight selects the "/" or "\" connector; indent holds the bars drawn so far.
  def printTree(root: Node, isRight: Boolean, indent: String): String = {
    val right: String = if (root.rightNode.isDefined) {
      printTree(root.rightNode.get, true, indent + (if (isRight) "        " else " |      "))
    } else {
      ""
    }
    val connector = if (isRight) " /" else " \\"
    val dash = "----- "
    val rootStr = printNodeValue(root)
    val left: String = if (root.leftNode.isDefined) {
      printTree(root.leftNode.get, false, indent + (if (isRight) " |      " else "        "))
    } else {
      ""
    }
    right + indent + connector + dash + rootStr + left
  }
}

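import org.apache.spark.mllib.tree.model.DecisionTreeModel

// sc is the SparkContext (e.g. the one provided by spark-shell)
val modelPath = "..."
val model = DecisionTreeModel.load(sc, modelPath)
println(model.toDebugString)
val str = PrintUtils.printDecisionTree(model)
println(str)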

DecisionTreeModel classifier of depth 7 with 45 nodes
         /----- 1.0 (prob = 1.0)
 /----- Feature:27 > 0.1397
 |       |               /----- 1.0 (prob = 1.0)
 |       |       /----- Feature:20 > 18.79
 |       |       |       |               /----- 1.0 (prob = 1.0)
 |       |       |       |       /----- Feature:0 > 17.3
 |       |       |       |       |       \----- 0.0 (prob = 1.0)
 |       |       |       \----- Feature:17 > 0.009753
 |       |       |               \----- 1.0 (prob = 1.0)
 |       \----- Feature:1 > 14.96
 |               \----- 0.0 (prob = 1.0)
Feature:22 > 114.2
 |                               /----- 1.0 (prob = 1.0)
 |                       /----- Feature:12 > 1.539
 |                       |       \----- 0.0 (prob = 1.0)
 |               /----- Feature:7 > 0.05266
 |               |       |               /----- 1.0 (prob = 1.0)
 |               |       |       /----- Feature:4 > 0.09073
 |               |       |       |       |       /----- 0.0 (prob = 1.0)
 |               |       |       |       \----- Feature:10 > 0.2406
 |               |       |       |               \----- 1.0 (prob = 1.0)
 |               |       \----- Feature:20 > 15.5
 |               |               \----- 0.0 (prob = 1.0)
 |       /----- Feature:21 > 25.72
 |       |       |       /----- 1.0 (prob = 1.0)
 |       |       \----- Feature:24 > 0.1786
 |       |               |               /----- 0.0 (prob = 1.0)
 |       |               |       /----- Feature:0 > 14.02
 |       |               |       |       \----- 1.0 (prob = 1.0)
 |       |               \----- Feature:23 > 809.7
 |       |                       \----- 0.0 (prob = 1.0)
 \----- Feature:27 > 0.1108
         |                       /----- 1.0 (prob = 1.0)
         |               /----- Feature:14 > 0.004571
         |               |       \----- 0.0 (prob = 1.0)
         |       /----- Feature:21 > 22.13
         |       |       \----- 0.0 (prob = 1.0)
         \----- Feature:13 > 45.19
                 |               /----- 0.0 (prob = 1.0)
                 |       /----- Feature:1 > 22.61
                 |       |       |       /----- 1.0 (prob = 1.0)
                 |       |       \----- Feature:0 > 11.49
                 |       |               \----- 0.0 (prob = 1.0)
                 \----- Feature:21 > 32.84
                         \----- 0.0 (prob = 1.0)
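The drawing is a reverse in-order traversal: right subtree, then the node, then left subtree. That is why the "> threshold" branches appear above each split. A child drawn on the right side leaves a bar (" |") under it on its left side, and vice versa, so the connectors line up with the parent.

Below is a minimal, runnable sketch of the same idea without Spark. It folds the two printTree overloads into one recursion over a hypothetical tree type (SidewaysSketch, Leaf, Inner) whose nodes already carry the label that printNodeValue would produce. It is an illustration, not a drop-in replacement for PrintUtils:

```scala
// Hypothetical, Spark-free tree (NOT Spark's Node): each node carries the label that
// printNodeValue would produce, e.g. "Feature:0 > 11.49" or "1.0 (prob = 1.0)".
object SidewaysSketch {
  sealed trait Tree { def label: String }
  final case class Leaf(label: String) extends Tree
  final case class Inner(label: String, left: Tree, right: Tree) extends Tree

  private val Blank = "        " // 8 spaces: no vertical bar in this column
  private val Bar = " |      "   // 8 chars: a bar that continues towards the parent

  // Root: right subtree first (drawn above), then the root label, then the left subtree (below).
  def draw(root: Tree): String = root match {
    case Leaf(l) => l + "\n"
    case Inner(l, left, right) =>
      child(right, isRight = true, "") + l + "\n" + child(left, isRight = false, "")
  }

  // Same walk for a child; isRight picks the " /" or " \" connector and decides
  // on which side of this node the bar must continue.
  private def child(node: Tree, isRight: Boolean, indent: String): String = {
    val line = indent + (if (isRight) " /" else " \\") + "----- " + node.label + "\n"
    node match {
      case Leaf(_) => line
      case Inner(_, left, right) =>
        child(right, isRight = true, indent + (if (isRight) Blank else Bar)) +
          line +
          child(left, isRight = false, indent + (if (isRight) Bar else Blank))
    }
  }
}
```

Drawing Inner("Feature:0 > 11.49", Leaf("0.0"), Leaf("1.0")) prints the "1.0" leaf above the split line and the "0.0" leaf below it, as in the output above.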



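1. A Node in DecisionTreeModel can also split on a categorical (discrete) feature, but printNodeValue above only prints the continuous form "Feature:f > threshold".
2. When building the model, the train method takes an RDD[LabeledPoint], whose features are a numeric Vector, so discrete values are passed in as Double.

3. My guess is that later Spark versions will provide corresponding support.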
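Regarding point 1: for a categorical split, Spark sends a sample to the left child when the feature value is in the split's category list, and to the right child otherwise. Spark's own toDebugString labels such branches "in {...}" and "not in {...}". In Spark 1.6 the node's split carries featureType (FeatureType.Continuous or FeatureType.Categorical) and categories. Categorical features are declared at training time through the categoricalFeaturesInfo argument of DecisionTree.trainClassifier.

Below is a minimal sketch of how the node label could distinguish the two cases. It uses hypothetical case classes (SplitLabelSketch, Continuous, Categorical) instead of Spark's Split and FeatureType, so it runs without Spark:

```scala
// Hypothetical mirror of the two kinds of split (NOT Spark's Split / FeatureType classes).
object SplitLabelSketch {
  sealed trait SplitInfo
  final case class Continuous(feature: Int, threshold: Double) extends SplitInfo
  final case class Categorical(feature: Int, categories: List[Double]) extends SplitInfo

  // Label in the style of printNodeValue, i.e. the condition of the right-hand branch
  // drawn above the node. For a categorical split the left child receives the listed
  // categories, so the right child is the "not in {...}" side.
  def nodeLabel(split: SplitInfo): String = split match {
    case Continuous(f, t)     => s"Feature:$f > $t"
    case Categorical(f, cats) => s"Feature:$f not in ${cats.mkString("{", ",", "}")}"
  }
}
```

In the real PrintUtils, the equivalent change would branch on root.split.get.featureType inside printNodeValue and format root.split.get.categories the same way.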



