关于spark的mllib学习总结(Java版)

来源:互联网 发布:遗传算法java实现 编辑:程序博客网 时间:2024/04/20 09:14
本篇博客主要讲述如何利用Spark的mliib构建机器学习模型并预测新的数据,具体的流程如下图所示: 

基本流程

加载数据

对于数据的加载或保存,mllib提供了MLUtils包,其作用是Helper methods to load,save and pre-process data used in MLLib.博客中的数据是采用spark中提供的数据sample_libsvm_data.txt,其有一百个数据样本,658个特征。具体的数据形式如图所示: 
数据格式

加载libsvm

JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
  • 1
  • 1

LabeledPoint数据类型是对应与libsvmfile格式文件, 具体格式为: 
Lable(double类型),vector(Vector类型)

转化dataFrame数据类型

JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());StructType schema = new StructType(new StructField[]{                    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),                    new StructField("features", new VectorUDT(), false, Metadata.empty()),        });SQLContext jsql = new SQLContext(sc);DataFrame df = jsql.createDataFrame(jrow, schema);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

DataFrame:DataFrame是一个以命名列方式组织的分布式数据集。在概念上,它跟关系型数据库中的一张表或者1个Python(或者R)中的data frame一样,但是比他们更优化。DataFrame可以根据结构化的数据文件、Hive表、外部数据库或者已经存在的RDD构造。

SQLContext:spark sql所有功能的入口是SQLContext类,或者SQLContext的子类。为了创建一个基本的SQLContext,需要一个SparkContext。

特征提取

特征归一化处理

StandardScaler scaler = new StandardScaler().setInputCol("features").setOutputCol("normFeatures").setWithStd(true);DataFrame scalerDF = scaler.fit(df).transform(df);scaler.save(this.scalerModelPath);
  • 1
  • 2
  • 3
  • 1
  • 2
  • 3

利用卡方统计做特征提取

ChiSqSelector selector = new ChiSqSelector().setNumTopFeatures(500).setFeaturesCol("normFeatures").setLabelCol("label").setOutputCol("selectedFeatures");ChiSqSelectorModel chiModel = selector.fit(scalerDF);DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");chiModel.save(this.featureSelectedModelPath);
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

训练机器学习模型(以SVM为例)

//转化为LabeledPoint数据类型, 训练模型JavaRDD<Row> selectedrows = selectedDF.javaRDD();JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());//训练SVM模型, 并保存int numIteration = 200;SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);model.clearThreshold();model.save(sc, this.mlModelPath);// LabeledPoint数据类型转化为Rowstatic class LabeledPointToRow implements Function<LabeledPoint, Row> {        public Row call(LabeledPoint p) throws Exception {            double label = p.label();            Vector vector = p.features();            return RowFactory.create(label, vector);        }    }//Rows数据类型转化为LabeledPointstatic class RowToLabel implements Function<Row, LabeledPoint> {        public LabeledPoint call(Row r) throws Exception {            Vector features = r.getAs(1);            double label = r.getDouble(0);            return new LabeledPoint(label, features);        }    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

测试新的样本

测试新的样本前,需要将样本做数据的转化和特征提取的工作,所有刚刚训练模型的过程中,除了保存机器学习模型,还需要保存特征提取的中间模型。具体代码如下:

//初始化sparkSparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");conf.set("spark.testing.memory", "2147480000");SparkContext sc = new SparkContext(conf);//加载测试数据JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();//转化DataFrame数据类型JavaRDD<Row> jrow =testData.map(new LabeledPointToRow());        StructType schema = new StructType(new StructField[]{                    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),                    new StructField("features", new VectorUDT(), false, Metadata.empty()),        });SQLContext jsql = new SQLContext(sc);DataFrame df = jsql.createDataFrame(jrow, schema);        //数据规范化StandardScaler scaler = StandardScaler.load(this.scalerModelPath);DataFrame scalerDF = scaler.fit(df).transform(df);        //特征选取ChiSqSelectorModel chiModel = ChiSqSelectorModel.load( this.featureSelectedModelPath);DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

测试数据集

SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel)) ;predictResult.collect();static class Prediction implements Function<LabeledPoint, Tuple2<Double , Double>> {        SVMModel model;        public Prediction(SVMModel model){            this.model = model;        }        public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {            Double score = model.predict(p.features());            return new Tuple2<Double , Double>(score, p.label());        }    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

计算准确率

double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();System.out.println(accuracy);static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {        public Boolean call(Tuple2<Double, Double> t) throws Exception {            double score = t._1();            double label = t._2();            System.out.print("score:" + score + ", label:"+ label);            if(score >= 0.0 && label >= 0.0) return true;            else if(score < 0.0 && label < 0.0) return true;            else return false;        }    }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

具体的代码,放在我的github上:https://github.com/Quincy1994/MachineLearning/

3
0
 
 

我的同类文章

  • 利用spark做文本聚类分析2017-02-07
  • 在ubuntu中使用java版的spark2016-08-25
  • 利用java的spark做高斯混合模型聚类2017-02-06

参考知识库

img

机器学习知识库

img

Python知识库

img

Hive知识库

img

人工智能机器学习知识库

img

软件测试知识库

img

MySQL知识库

img

Apache Spark知识库

猜你在找
8小时学会HTML网页开发
Android入门实战教程
Swift视频教程(第三季)
Swift视频教程(第七季)
Swift视频教程(第六季)
多层感知机MLP算法原理及Spark MLlib调用实例ScalaJavaPython
梯度迭代树GBDT算法原理及Spark MLlib调用实例ScalaJavapython
Pipeline详解及Spark MLlib使用示例ScalaJavaPython
spark MLlib 学习
spark mllib机器学习之五 LinearRegressionWithSGD
0 0
原创粉丝点击