用Scala调用MLLib之二元分类

来源:互联网 发布:孚盟软件logo 编辑:程序博客网 时间:2024/04/29 23:42

下面的代码段演示了如何导入一份样本数据集,使用算法对象中的静态方法在训练集上执行训练算法,在所得的模型上进行预测并计算训练误差。

import org.apache.spark.SparkContextimport org.apache.spark.mllib.classification.SVMWithSGDimport org.apache.spark.mllib.regression.LabeledPoint// Load and parse the data fileval data = sc.textFile("mllib/data/sample_svm_data.txt")val parsedData = data.map { line =>val parts = line.split(' ')LabeledPoint(parts(0).toDouble, parts.tail.map(x => x.toDouble).toArray)}// Run training algorithm to build the modelval numIterations = 20val model = SVMWithSGD.train(parsedData, numIterations)// Evaluate model on training examples and compute training errorval labelAndPreds = parsedData.map { point =>val prediction = model.predict(point.features)(point.label, prediction)}val trainErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / parsedData.countprintln("Training Error = " + trainErr)

默认情况下,这个SVMWithSGD.train()方法使用正则参数为 1.0 的 L2 正则项。如果我们想配置这个算法,我们可以通过直接新建一个新的对象,并调用setter的方法,进一步个性化设置SVMWithSGD。所有其他的 MLlib 算法也是通过这样的方法来支持个性化的设置。比如,下面的代码给出了一个正则参数为0.1的 L1 正则化SVM变体,并且让这个训练算法迭代200遍。

import org.apache.spark.mllib.optimization.L1Updaterval svmAlg = new SVMWithSGD()svmAlg.optimizer.setNumIterations(200).setRegParam(0.1).setUpdater(new L1Updater)val modelL1 = svmAlg.run(parsedData)
0 0
原创粉丝点击