Introduction to the One-vs-Rest Algorithm, with Spark MLlib Usage Examples (Scala/Java/Python)

Source: Internet  Editor: 程序博客网  Date: 2024/06/06 19:58

One-vs-Rest

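Algorithm overview:

OneVsRest extends a given binary classification algorithm to multiclass problems; the strategy is also known as "One-vs-All". OneVsRest is an Estimator. It takes a base Classifier and creates one binary classification problem for each of the k classes. The binary classifier for class i is trained to predict whether the label is i or not i, that is, to separate class i from all other classes. At prediction time, each of the k binary classifiers is evaluated in turn, and the index of the most confident classifier is output as the predicted label.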
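The reduction described above can be sketched in plain Python, without Spark. This is only an illustration under stated assumptions: the centroid-based scorer `fit_binary` is a hypothetical stand-in for the base Classifier, and the function names and toy data are made up for this example; it is not how Spark MLlib implements OneVsRest internally.

```python
from typing import Callable, Dict, List, Sequence

Point = Sequence[float]
Scorer = Callable[[Point], float]


def centroid(points: List[Point]) -> List[float]:
    """Mean of a non-empty list of points, computed per coordinate."""
    return [sum(col) / len(points) for col in zip(*points)]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def fit_binary(points: List[Point], targets: List[int]) -> Scorer:
    """Toy binary learner (assumption, stands in for the base Classifier).

    Returns a scorer whose value is higher the closer a point is to the
    positive centroid relative to the negative centroid.
    """
    pos = centroid([p for p, t in zip(points, targets) if t == 1])
    neg = centroid([p for p, t in zip(points, targets) if t == 0])
    return lambda x: sq_dist(x, neg) - sq_dist(x, pos)


def fit_one_vs_rest(points: List[Point], labels: List[int]) -> Dict[int, Scorer]:
    """Train one binary scorer per class: class c versus everything else."""
    classes = sorted(set(labels))
    return {
        c: fit_binary(points, [1 if y == c else 0 for y in labels])
        for c in classes
    }


def predict_one_vs_rest(models: Dict[int, Scorer], x: Point) -> int:
    """Evaluate every binary scorer and return the most confident class."""
    return max(models, key=lambda c: models[c](x))


# Toy data: three well-separated clusters, one per class.
points = [(0, 0), (0, 1), (1, 0), (5, 0), (5, 1), (6, 0), (0, 5), (1, 5), (0, 6)]
labels = [0, 0, 0, 1, 1, 1, 2, 2, 2]

models = fit_one_vs_rest(points, labels)
predictions = [predict_one_vs_rest(models, x) for x in [(0.2, 0.3), (5.5, 0.5), (0.5, 5.5)]]
print(predictions)
```

With k = 3 classes, `fit_one_vs_rest` trains exactly three scorers, and the prediction step is an argmax over their confidence scores.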

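Parameters:

featuresCol:

Type: string.

Meaning: name of the features column.

labelCol:

Type: string.

Meaning: name of the label column.

predictionCol:

Type: string.

Meaning: name of the prediction output column.

classifier:

Type: Classifier.

Meaning: the base binary classifier.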
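As a small sketch of how these four parameters fit together in PySpark, the helper below passes them as keyword arguments to `OneVsRest`. The helper name `build_one_vs_rest` and the dictionary `OVR_COLUMNS` are hypothetical and introduced only for illustration; the column values shown are the PySpark defaults, written out explicitly.

```python
# Column parameters for OneVsRest; the values are the PySpark defaults.
OVR_COLUMNS = {
    "featuresCol": "features",
    "labelCol": "label",
    "predictionCol": "prediction",
}


def build_one_vs_rest(base_classifier):
    """Wrap `base_classifier` in a OneVsRest estimator (hypothetical helper)."""
    # Imported lazily so this module loads even where PySpark is not installed.
    from pyspark.ml.classification import OneVsRest

    return OneVsRest(classifier=base_classifier, **OVR_COLUMNS)
```

With an active SparkSession, a call such as `build_one_vs_rest(LogisticRegression(maxIter=10))` would return an estimator that can then be fitted on a DataFrame containing the `features` and `label` columns.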

Examples:

Scala:

// Assumes an existing SparkSession named `spark` (for example, in spark-shell).
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// load data file.
val inputData = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt")

// generate the train/test split.
val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))

// instantiate the base classifier.
val classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true)

// instantiate the One Vs Rest Classifier.
val ovr = new OneVsRest().setClassifier(classifier)

// train the multiclass model.
val ovrModel = ovr.fit(train)

// score the model on test data.
val predictions = ovrModel.transform(test)

// obtain evaluator.
val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy")

// compute the classification error on test data.
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error : ${1 - accuracy}")

Java:

// Assumes an existing SparkSession named `spark`.
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// load data file.
Dataset<Row> inputData = spark.read().format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt");

// generate the train/test split.
Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = tmp[0];
Dataset<Row> test = tmp[1];

// configure the base classifier.
LogisticRegression classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true);

// instantiate the One Vs Rest Classifier.
OneVsRest ovr = new OneVsRest().setClassifier(classifier);

// train the multiclass model.
OneVsRestModel ovrModel = ovr.fit(train);

// score the model on test data.
Dataset<Row> predictions = ovrModel.transform(test)
  .select("prediction", "label");

// obtain evaluator.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy");

// compute the classification error on test data.
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error : " + (1 - accuracy));

Python:

# Assumes an existing SparkSession named `spark` (for example, in the pyspark shell).
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# load data file.
inputData = spark.read.format("libsvm") \
    .load("data/mllib/sample_multiclass_classification_data.txt")

# generate the train/test split.
(train, test) = inputData.randomSplit([0.8, 0.2])

# instantiate the base classifier.
lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)

# instantiate the One Vs Rest Classifier.
ovr = OneVsRest(classifier=lr)

# train the multiclass model.
ovrModel = ovr.fit(train)

# score the model on test data.
predictions = ovrModel.transform(test)

# obtain evaluator.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

# compute the classification error on test data.
accuracy = evaluator.evaluate(predictions)
print("Test Error : " + str(1 - accuracy))
