Spark2 机器学习之决策树分类Decision tree classifier
来源:互联网 发布:java设计登录窗口 编辑:程序博客网 时间:2024/04/30 14:54
分类决策树代码
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{ VectorAssembler, IndexToString, StringIndexer, VectorIndexer }

// Decision-tree classification example using the Spark 2.x ML Pipeline API.
// Steps: build a DataFrame, recode the string columns to 0/1 doubles, assemble a
// feature vector, index the label and categorical features, train a
// DecisionTreeClassifier, measure accuracy on a held-out split, and inspect the tree.

// The config key/value is a placeholder; nothing in this script reads it.
val spark = SparkSession.builder()
  .appName("Spark decision tree classifier")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()

// For implicit conversions like converting RDDs to DataFrames
import spark.implicits._

// Single seed shared by the train/test split and the tree, so a rerun reproduces
// the same split, the same tree and the same reported error.
val seed = 123456L

// Sample rows only; the full data source is described at
// http://blog.csdn.net/hadoop_spark_storm/article/details/53412598
// NOTE: every sample row below has affairs = 0, so with these rows alone the label
// has a single class. Use the full dataset for a meaningful model.
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0, "male", 37, 10, "no", 3, 18, 7, 4),
  (0, "female", 27, 4, "no", 4, 14, 6, 4),
  (0, "female", 32, 15, "yes", 1, 12, 1, 4),
  (0, "male", 57, 15, "yes", 5, 18, 6, 5),
  (0, "male", 22, 0.75, "no", 2, 17, 6, 3),
  (0, "female", 32, 1.5, "no", 2, 17, 5, 5))

val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
data.createOrReplaceTempView("data")

// Recode the target and the string columns to numeric doubles:
//   label    = 0 when affairs = 0, otherwise 1 (binary target)
//   gender   = 0 for 'female', otherwise 1
//   children = 0 for 'no', otherwise 1
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"
val dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")

val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Pack the feature columns into a single vector column named "features".
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Index the label, which adds label metadata to the column.
// The indexer is fit on the whole dataset so that every label value gets an index.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
labelIndexer.transform(vecDF).show(10, truncate = false)

// Detect categorical features automatically and index them.
// Features with more than 8 distinct values are treated as continuous; the rest are
// re-indexed as categorical. This includes ordinal columns such as religiousness and
// rating, which take values 1-5 in the data shown.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(8).fit(vecDF)
featureIndexer.transform(vecDF).show(10, truncate = false)

// Split the data into training and test sets (30% for testing).
// The split is seeded so that it, and the test error below, are reproducible.
val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3), seed)

// Configure the decision tree.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setImpurity("entropy")       // impurity measure used to choose splits
  .setMaxBins(100)              // max bins when discretizing continuous features; must be >= 2 and >= the category count of any categorical feature
  .setMaxDepth(5)               // maximum tree depth
  .setMinInfoGain(0.01)         // minimum information gain required to split a node; must be >= 0
  .setMinInstancesPerNode(10)   // a split is discarded if either child would get fewer instances than this
  .setSeed(seed)

// Convert the indexed predictions back to the original label values.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain the indexers and the tree in a Pipeline.
// labelIndexer and featureIndexer are already-fitted models, so the pipeline only
// applies them as transformers.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train the model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions on the held-out data.
val predictions = model.transform(testData)

// Show a few example rows.
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)

// Compare (prediction, true label) in index space and compute the test error.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// Find the fitted tree by type rather than by its position in the pipeline, so
// reordering the stages does not silently break a hard-coded cast.
val treeModel: DecisionTreeClassificationModel = model.stages
  .collectFirst { case m: DecisionTreeClassificationModel => m }
  .getOrElse(sys.error("pipeline model has no DecisionTreeClassificationModel stage"))

// Print the fitted model's parameters. Bare expressions are discarded outside an
// interactive REPL line, so the values are printed explicitly.
Seq(
  "labelCol" -> treeModel.getLabelCol,
  "featuresCol" -> treeModel.getFeaturesCol,
  "featureImportances" -> treeModel.featureImportances.toString,
  "predictionCol" -> treeModel.getPredictionCol,
  "probabilityCol" -> treeModel.getProbabilityCol,
  "numClasses" -> treeModel.numClasses.toString,
  "numFeatures" -> treeModel.numFeatures.toString,
  "depth" -> treeModel.depth.toString,
  "numNodes" -> treeModel.numNodes.toString,
  "impurity" -> treeModel.getImpurity,
  "maxBins" -> treeModel.getMaxBins.toString,
  "maxDepth" -> treeModel.getMaxDepth.toString,
  "maxMemoryInMB" -> treeModel.getMaxMemoryInMB.toString,
  "minInfoGain" -> treeModel.getMinInfoGain.toString,
  "minInstancesPerNode" -> treeModel.getMinInstancesPerNode.toString
).foreach { case (name, value) => println(s"$name = $value") }

// Print the learned tree.
println("Learned classification tree model:\n" + treeModel.toDebugString)
代码执行结果
import org.apache.spark.sql.SparkSessionimport org.apache.spark.sql.Datasetimport org.apache.spark.sql.Rowimport org.apache.spark.sql.DataFrameimport org.apache.spark.sql.Columnimport org.apache.spark.sql.DataFrameReaderimport org.apache.spark.rdd.RDDimport org.apache.spark.sql.catalyst.encoders.ExpressionEncoderimport org.apache.spark.sql.Encoderimport org.apache.spark.sql.DataFrameStatFunctionsimport org.apache.spark.sql.functions._import org.apache.spark.ml.Pipelineimport org.apache.spark.ml.classification.DecisionTreeClassificationModelimport org.apache.spark.ml.classification.DecisionTreeClassifierimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimport org.apache.spark.ml.feature.{ VectorAssembler, IndexToString, StringIndexer, VectorIndexer }val spark = SparkSession.builder().appName("Spark decision tree classifier").config("spark.some.config.option", "some-value").getOrCreate()// For implicit conversions like converting RDDs to DataFramesimport spark.implicits._// 这里仅仅是示例数据,完整的数据源,请参考我的博客http://blog.csdn.net/hadoop_spark_storm/article/details/53412598val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List( (0, "male", 37, 10, "no", 3, 18, 7, 4), (0, "female", 27, 4, "no", 4, 14, 6, 4), (0, "female", 32, 15, "yes", 1, 12, 1, 4), (0, "male", 57, 15, "yes", 5, 18, 6, 5), (0, "male", 22, 0.75, "no", 2, 17, 6, 3), (0, "female", 32, 1.5, "no", 2, 17, 5, 5))val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating") data: org.apache.spark.sql.DataFrame = [affairs: double, gender: string ... 
7 more fields]data.printSchema() root |-- affairs: double (nullable = false) |-- gender: string (nullable = true) |-- age: double (nullable = false) |-- yearsmarried: double (nullable = false) |-- children: string (nullable = true) |-- religiousness: double (nullable = false) |-- education: double (nullable = false) |-- occupation: double (nullable = false) |-- rating: double (nullable = false) data.show(10,truncate=false)+-------+------+----+------------+--------+-------------+---------+----------+------+|affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating|+-------+------+----+------------+--------+-------------+---------+----------+------+|0.0 |male |37.0|10.0 |no |3.0 |18.0 |7.0 |4.0 ||0.0 |female|27.0|4.0 |no |4.0 |14.0 |6.0 |4.0 ||0.0 |female|32.0|15.0 |yes |1.0 |12.0 |1.0 |4.0 ||0.0 |male |57.0|15.0 |yes |5.0 |18.0 |6.0 |5.0 ||0.0 |male |22.0|0.75 |no |2.0 |17.0 |6.0 |3.0 ||0.0 |female|32.0|1.5 |no |2.0 |17.0 |5.0 |5.0 ||0.0 |female|22.0|0.75 |no |2.0 |12.0 |1.0 |3.0 ||0.0 |male |57.0|15.0 |yes |2.0 |14.0 |4.0 |4.0 ||0.0 |female|32.0|15.0 |yes |4.0 |16.0 |1.0 |2.0 ||0.0 |male |22.0|1.5 |no |4.0 |14.0 |4.0 |5.0 |+-------+------+----+------------+--------+-------------+---------+----------+------+only showing top 10 rows// 查看数据分布情况data.describe("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating").show(10,truncate=false)+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+|summary|affairs |gender|age |yearsmarried |children|religiousness |education |occupation |rating |+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+|count |601 |601 |601 |601 |601 |601 |601 |601 |601 ||mean |1.4559068219633944|null |32.48752079866888|8.17769550748752 |null 
|3.1164725457570714|16.16638935108153|4.194675540765391|3.9317803660565724||stddev |3.298757728494681 |null |9.28876170487667 |5.571303149963791|null |1.1675094016730692|2.402554565766698|1.819442662708579|1.1031794920503795||min |0.0 |female|17.5 |0.125 |no |1.0 |9.0 |1.0 |1.0 ||max |12.0 |male |57.0 |15.0 |yes |5.0 |20.0 |7.0 |5.0 |+-------+------------------+------+-----------------+-----------------+--------+------------------+-----------------+-----------------+------------------+data.createOrReplaceTempView("data")// 字符类型转换成数值val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"labelWhere: String = case when affairs=0 then 0 else cast(1 as double) end as labelval genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"genderWhere: String = case when gender='female' then 0 else cast(1 as double) end as genderval childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"childrenWhere: String = case when children='no' then 0 else cast(1 as double) end as childrenval dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")dataLabelDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 7 more fields]val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")featuresArray: Array[String] = Array(gender, age, yearsmarried, children, religiousness, education, occupation, rating)// 字段转换成特征向量val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_6e2c6bdd631eval vecDF: DataFrame = assembler.transform(dataLabelDF)vecDF: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 
8 more fields]vecDF.show(10,truncate=false)+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+|0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]||0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] ||0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]||0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]||0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]||0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] ||0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]||0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]||0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]||0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+only showing top 10 rows// 索引标签,将元数据添加到标签列中val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_d00cad619cd5labelIndexer.transform(vecDF).show(10,truncate=false)+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedLabel|+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+|0.0 
|1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|0.0 ||0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |0.0 ||0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|0.0 ||0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|0.0 ||0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|0.0 ||0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |0.0 ||0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|0.0 ||0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|0.0 ||0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|0.0 ||0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |0.0 |+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+only showing top 10 rows// 自动识别分类的特征,并对它们进行索引// 具有大于8个不同的值的特征被视为连续。val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(8).fit(vecDF)featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_8fbcad97fb60featureIndexer.transform(vecDF).show(10,truncate=false)+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features |indexedFeatures |+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+|0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|[1.0,37.0,6.0,0.0,2.0,5.0,6.0,3.0]||0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] 
|[0.0,27.0,4.0,0.0,3.0,2.0,5.0,3.0]||0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|[0.0,32.0,7.0,1.0,0.0,1.0,0.0,3.0]||0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|[1.0,57.0,7.0,1.0,4.0,5.0,5.0,4.0]||0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|[1.0,22.0,2.0,0.0,1.0,4.0,5.0,2.0]||0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |[0.0,32.0,3.0,0.0,1.0,4.0,4.0,4.0]||0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|[0.0,22.0,2.0,0.0,1.0,1.0,0.0,2.0]||0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|[1.0,57.0,7.0,1.0,1.0,2.0,3.0,3.0]||0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|[0.0,32.0,7.0,1.0,3.0,3.0,0.0,1.0]||0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |[1.0,22.0,3.0,0.0,3.0,2.0,3.0,4.0]|+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+----------------------------------+only showing top 10 rows// 将数据分为训练和测试集(30%进行测试)val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3))trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields]testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 
8 more fields]// 训练决策树模型val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setImpurity("entropy") // 不纯度.setMaxBins(100) // 离散化"连续特征"的最大划分数.setMaxDepth(5) // 树的最大深度.setMinInfoGain(0.01) //一个节点分裂的最小信息增益,值为[0,1].setMinInstancesPerNode(10) //每个节点包含的最小样本数 .setSeed(123456)// 将索引标签转换回原始标签val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_2598e79a1d08// Chain indexers and tree in a Pipeline.val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))// Train model. This also runs the indexers.val model = pipeline.fit(trainingData)// 作出预测val predictions = model.transform(testData)predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 14 more fields]// 选择几个示例行展示predictions.select("predictedLabel", "label", "features").show(10,truncate=false)+--------------+-----+-------------------------------------+|predictedLabel|label|features |+--------------+-----+-------------------------------------+|0.0 |0.0 |[0.0,22.0,0.125,0.0,2.0,14.0,4.0,5.0]||0.0 |0.0 |[0.0,22.0,0.125,0.0,2.0,16.0,6.0,3.0]||0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]||0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]||0.0 |0.0 |[0.0,22.0,0.75,0.0,2.0,16.0,5.0,5.0] ||0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] ||0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,14.0,5.0,4.0] ||0.0 |0.0 |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0] ||0.0 |0.0 |[0.0,22.0,1.5,0.0,3.0,16.0,6.0,5.0] ||0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0] |+--------------+-----+-------------------------------------+only showing top 10 rows// 选择(预测标签,实际标签),并计算测试误差。val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)accuracy: Double = 0.7032967032967034println("Test Error = " + (1.0 
- accuracy))Test Error = 0.29670329670329665// 这里的stages(2)中的“2”对应pipeline中的“dt”,将model强制转换为DecisionTreeClassificationModel类型val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]treeModel: org.apache.spark.ml.classification.DecisionTreeClassificationModel = DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodestreeModel.getLabelColres53: String = indexedLabeltreeModel.getFeaturesColres54: String = indexedFeaturestreeModel.featureImportancesres55: org.apache.spark.ml.linalg.Vector = (8,[0,2,3,4,5,6,7],[0.0640344247735859,0.1052957011097811,0.05343872372010684,0.17367191628391196,0.20372870264756315,0.2063093687074741,0.1935211627575769])treeModel.getPredictionColres56: String = predictiontreeModel.getProbabilityColres57: String = probabilitytreeModel.numClassesres58: Int = 2treeModel.numFeaturesres59: Int = 8treeModel.depthres60: Int = 5treeModel.numNodesres61: Int = 33treeModel.getImpurityres62: String = entropytreeModel.getMaxBinsres63: Int = 100treeModel.getMaxDepthres64: Int = 5treeModel.getMaxMemoryInMBres65: Int = 256treeModel.getMinInfoGainres66: Double = 0.01treeModel.getMinInstancesPerNoderes67: Int = 10 // 查看决策树println("Learned classification tree model:\n" + treeModel.toDebugString)Learned classification tree model:DecisionTreeClassificationModel (uid=dtc_7a8baf97abe7) of depth 5 with 33 nodes If (feature 2 in {0.0,1.0,2.0,3.0}) If (feature 5 in {3.0,6.0}) Predict: 0.0 Else (feature 5 not in {3.0,6.0}) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 3 in {0.0}) If (feature 6 in {0.0,4.0,5.0}) Predict: 0.0 Else (feature 6 not in {0.0,4.0,5.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 2 not in {0.0,1.0,2.0,3.0}) If (feature 4 in {0.0,1.0,3.0,4.0}) If (feature 7 in {0.0,1.0,2.0}) If (feature 6 in {0.0,1.0,6.0}) If (feature 4 in {1.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,4.0}) Predict: 0.0 Else (feature 6 not in {0.0,1.0,6.0}) If (feature 7 in 
{0.0,2.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0}) If (feature 5 in {0.0,1.0}) Predict: 0.0 Else (feature 5 not in {0.0,1.0}) If (feature 6 in {0.0,1.0,2.0,5.0,6.0}) Predict: 0.0 Else (feature 6 not in {0.0,1.0,2.0,5.0,6.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0,4.0}) If (feature 5 in {0.0,1.0,2.0,3.0,5.0,6.0}) If (feature 0 in {0.0}) If (feature 7 in {3.0}) Predict: 0.0 Else (feature 7 not in {3.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 7 in {0.0,2.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) Predict: 1.0 Else (feature 5 not in {0.0,1.0,2.0,3.0,5.0,6.0}) Predict: 1.0
0 0
- Spark2 机器学习之决策树分类Decision tree classifier
- Spark2 ML包之决策树分类Decision tree classifier详细解说
- 【机器学习】分类算法之决策树(Decision tree)
- 机器学习之决策树(Decision Tree)
- 机器学习之:决策树(Decision Tree)
- 机器学习之决策树 Decision Tree(一)
- 机器学习之决策树 Decision Tree(二)Python实现
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 分类算法之决策树(Decision tree)
- 【机器学习】决策树(Decision Tree)
- centos6.8 升级libc
- weblogic 修改 应用上下文的两种方式
- cannot find -lxxx
- 393. UTF-8 Validation
- Caffe学习笔记1:linux下建立自己的数据库训练和测试caffe中已有网络
- Spark2 机器学习之决策树分类Decision tree classifier
- 学习shader之前必须知道的东西之计算机图形学-渲染管线
- 用户实时行为数据采集
- Android热插拔事件处理流程--Vold
- 【android】getCacheDir()、getFilesDir()、getExternalFilesDir()、getExternalCacheDir()的作用
- 操作系统学习笔记2----进程管理
- eclipse创建maven项目,DynamicWebModule默认为2.3修改不了
- poj_2115 C Looooops(模线性方程+扩展欧几里得)
- struts2入门(搭建环境、配置、示例)