【spark】采用LogisticRegression(ML API篇)对MNIST的0-1数字进行识别
来源:互联网 发布:淘宝上下架工具 编辑:程序博客网 时间:2024/06/07 12:08
:ROC曲线概念
http://blog.csdn.net/abcjennifer/article/details/7359370
:Recall-Precision概念
http://blog.csdn.net/pirage/article/details/9851339
:下载MNIST数据集
http://yann.lecun.com/exdb/mnist/
:Logistic Regression:从入门到精通
http://www.tianyancha.com/research/LR_intro.pdf
:加载MNIST数据类
package com.bbw5.ml.spark.data

import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vector

/**
 * Loader for the MNIST IDX files (http://yann.lecun.com/exdb/mnist/).
 *
 * Files are read from `dataDir` + one of the relative file names below.
 * Only samples whose label is strictly less than `numCount` are returned by
 * [[loadData]], so the default `numCount = 2` keeps the digits 0 and 1.
 *
 * @param dataDir  directory that contains the unpacked MNIST sub-directories
 * @param numCount exclusive upper bound on the labels to keep
 */
class MNISTData(val dataDir: String, val numCount: Int = 2) {
  val trainLabeFileName = "/train-labels-idx1-ubyte/train-labels.idx1-ubyte"
  val trainImageFileName = "/train-images-idx3-ubyte/train-images.idx3-ubyte"
  val testLabeFileName = "/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"
  val testImageFileName = "/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  /**
   * Runs `f` on `resource` and always closes the resource afterwards,
   * whether `f` returns normally or throws.
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B = {
    // The structural bound makes close() a reflective call; enable it explicitly.
    import scala.language.reflectiveCalls
    try {
      f(resource)
    } finally {
      resource.close()
    }
  }

  /**
   * Reads the whole file `dataDir + filename` into memory.
   *
   * The stream handed to `using` is the one that is actually read, so it is
   * closed on every path, and `readFully` keeps reading until the buffer is
   * full instead of trusting a single `read` call.
   */
  private def readAllBytes(filename: String): Array[Byte] = {
    val file = new File(dataDir + filename)
    val bytes = new Array[Byte](file.length.toInt)
    using(new java.io.DataInputStream(new FileInputStream(file))) { in =>
      in.readFully(bytes)
    }
    bytes
  }

  /** Decodes the 32-bit big-endian integer stored at `offset` in `bytes`. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes, offset, 4).getInt

  /** Loads the raw bytes (header included) of the training label file. */
  def loadTrainLabelData(): Array[Byte] = {
    loadLabelData(trainLabeFileName)
  }

  /** Loads the raw bytes (header included) of the test label file. */
  def loadTestLabelData(): Array[Byte] = {
    loadLabelData(testLabeFileName)
  }

  /**
   * Loads an MNIST label file and prints its header plus the first labels.
   *
   * @param filename path relative to `dataDir`
   * @return the complete file content; the labels start at byte offset 8
   */
  def loadLabelData(filename: String): Array[Byte] = {
    val labelDS = readAllBytes(filename)
    // 32 bit integer 0x00000801 (2049) magic number (MSB first, big endian)
    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    // 32 bit integer: number of items (60000 for the training set)
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print a few sample labels as a sanity check
    for ((e, index) <- labelDS.drop(8).take(3).zipWithIndex) {
      println(s"image$index is $e")
    }
    labelDS
  }

  /** Loads the raw bytes (header included) of the training image file. */
  def loadTrainImageData(): Array[Byte] = {
    loadImageData(trainImageFileName)
  }

  /** Loads the raw bytes (header included) of the test image file. */
  def loadTestImageData(): Array[Byte] = {
    loadImageData(testImageFileName)
  }

  /**
   * Loads an MNIST image file and prints its header fields.
   *
   * @param filename path relative to `dataDir`
   * @return the complete file content; the pixel data starts at byte offset 16
   */
  def loadImageData(filename: String): Array[Byte] = {
    val trainingDS = readAllBytes(filename)
    // 32 bit integer 0x00000803 (2051) magic number
    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    // 32 bit integer: number of items (60000 for the training set)
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    // 32 bit integer: number of rows (28)
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    // 32 bit integer: number of columns (28)
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")
    // Actual payload size versus the size implied by the header (rows * cols per image)
    println(s"numOfItems=" + (trainingDS.length - 16) + "=" + (numOfItems * numOfRows * numOfCols))
    trainingDS
  }

  /** Loads the training set as (label, pixel vector) pairs, filtered by `numCount`. */
  def loadTrainData(): Array[(Double, Vector)] = {
    loadData(() => loadTrainLabelData(), () => loadTrainImageData())
  }

  /** Loads the test set as (label, pixel vector) pairs, filtered by `numCount`. */
  def loadTestData(): Array[(Double, Vector)] = {
    loadData(() => loadTestLabelData(), () => loadTestImageData())
  }

  /**
   * Pairs every label with its image and keeps the samples whose label is
   * below `numCount`.
   *
   * @param loadLabelFunc supplies the raw label file bytes (8-byte header + labels)
   * @param loadImageFunc supplies the raw image file bytes (16-byte header + pixels)
   * @return (label, dense pixel vector) pairs; pixel values are in 0..255
   */
  def loadData(loadLabelFunc: () => Array[Byte], loadImageFunc: () => Array[Byte]): Array[(Double, Vector)] = {
    val labels = loadLabelFunc().drop(8)
    val imageBytes = loadImageFunc()
    val numOfItems = readInt(imageBytes, 4)
    // Take the image size from the header (28 * 28 for MNIST) rather than hard-coding it
    val imageSize = readInt(imageBytes, 8) * readInt(imageBytes, 12)
    val images = Array.tabulate(numOfItems) { i =>
      val start = 16 + i * imageSize
      imageBytes.slice(start, start + imageSize)
    }
    println("numOfImages=" + images.length)
    val data = labels.zip(images)
    // Print how many samples there are per digit
    println("image digit count:")
    data.groupBy(_._1).mapValues(_.size).foreach(println)
    // Keep only the requested digits (0/1 by default)
    data.filter(_._1 < numCount).map {
      case (label, pixels) =>
        // Pixels are unsigned bytes; mask so 128..255 do not turn into negative values
        (label.toDouble, Vectors.dense(pixels.map(b => (b & 0xFF).toDouble)))
    }
  }
}

object MNISTData {
  /** Convenience wrapper: loads the filtered training set found under `dataDir`. */
  def loadTrainData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    new MNISTData(dataDir, numCount).loadTrainData()
  }

  /** Convenience wrapper: loads the filtered test set found under `dataDir`. */
  def loadTestData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    new MNISTData(dataDir, numCount).loadTestData()
  }
}
:使用ML API进行分类
package com.bbw5.ml.spark

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
import org.apache.spark.ml.classification.LogisticRegressionModel
import java.util.Date

/**
 * Classifies the 0 and 1 digits of the MNIST handwritten-digit data set with
 * LogisticRegression.
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {

  /** Entry point: creates the Spark contexts, runs [[tvSplit]] and stops Spark. */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new org.apache.spark.sql.SQLContext(sc)
      tvSplit(sc, sqlContext)
    } finally {
      // Release the context even when training fails
      sc.stop()
    }
  }

  /**
   * Saves the model and prints its parameters together with the metrics held
   * by its summary.
   *
   * NOTE(review): the values come from `bestModel.summary`, which Spark
   * documents as the summary of the model on the training set, so the printed
   * areaUnderROC is not a test-set metric. Confirm against the Spark version
   * in use.
   *
   * @param bestModel the fitted logistic regression model to report on
   */
  def printResult(bestModel: LogisticRegressionModel): Unit = {
    bestModel.save("D:/Develop/Model/MNIST-LR-" + System.currentTimeMillis())
    println("bestModel.params:" + bestModel.extractParamMap)

    val trainingSummary = bestModel.summary
    // Objective value per iteration
    println("print lost in every step:")
    trainingSummary.objectiveHistory.foreach(loss => println(loss))

    // Match instead of casting so a non-binary summary cannot raise ClassCastException
    trainingSummary match {
      case binarySummary: BinaryLogisticRegressionSummary =>
        // Receiver-operating characteristic as a DataFrame; show every row
        println("print roc in every step:")
        binarySummary.roc.show(binarySummary.roc.count.toInt)
        println("print recall,precision every step:")
        binarySummary.pr.show(binarySummary.pr.count.toInt)
        // Value observed in the author's run: 0.9409024010447935
        println("areaUnderROC=" + binarySummary.areaUnderROC)
      case other =>
        println("binary summary not available, got: " + other.getClass.getName)
    }
  }

  /**
   * Tunes a LogisticRegression with TrainValidationSplit on the MNIST 0/1
   * training data, prints the predictions on the test data, then reports the
   * best model through [[printResult]].
   */
  def tvSplit(sc: SparkContext, sqlContext: SQLContext): Unit = {
    val dataDir = "I:/DM-dataset/MNIST/"
    import sqlContext.implicits._

    val training = sc.parallelize(MNISTData.loadTrainData(dataDir), 4).toDF("label", "features").cache()
    training.describe("label").show
    val test = sc.parallelize(MNISTData.loadTestData(dataDir), 4).toDF("label", "features").cache()
    test.describe("label").show

    val lr = new LogisticRegression()
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.0001, 0.01, 1.0))
      .addGrid(lr.fitIntercept)
      .addGrid(lr.maxIter, Array(100))
      .addGrid(lr.elasticNetParam, Array(0.1, 0.5, 1.0))
      .build()

    // 80% of the data will be used for training and the remaining 20% for validation.
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8)

    // Run train validation split, and choose the best set of parameters.
    val model = trainValidationSplit.fit(training)

    // Make predictions on test data. model is the model with the combination
    // of parameters that performed best.
    val testDF = model.transform(test)
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()

    printResult(model.bestModel.asInstanceOf[LogisticRegressionModel])
  }
}
:在测试集上测试后的结果AUC为 0.9409024010447935,准确率为97%
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  1.0|       1.0| 1106|
|  0.0|       0.0|  962|
|  0.0|       1.0|   18|
|  1.0|       0.0|   29|
+-----+----------+-----+
1 0
- 【spark】采用LogisticRegression(ML API篇)对MNIST的0-1数字进行识别
- 【spark+python】采用LogisticRegression(MLLib)对MNIST的0-1数字进行识别
- 【spark】采用MultilayerPerceptron对MNIST的0-9数字进行识别
- 使用Spark MLlib的逻辑回归(LogisticRegression)进行用户分类预测识别
- TensorFlow学习笔记(1):使用softmax对手写体数字(MNIST数据集)进行识别
- 勉强算升2级吧----用mnist训练好的model对自己手写的数字进行分类识别
- 使用logisticRegression识别手写数字
- tensorflow下对MNIST数据集进行识别的程序代码
- tensorflow进行MNIST手写数字识别-CNN
- tensorflow进行MNIST手写数字识别-LSTM
- 使用PCA + KNN对MNIST数据集进行手写数字识别 python
- 用卷积神经网络对mnist进行数字识别程序(tensorflow)
- MNIST手写数字的识别——CNN篇
- MNIST手写数字的识别——DNN篇
- MNIST手写数字的识别——kNN篇
- 基于tensorflow的MNIST手写数字识别--入门篇
- 基于tensorflow的MNIST手写数字识别
- 基于tensorflow的MNIST手写数字识别
- VS 报cmath(19): error C2061: 语法错误: 标识符“acosf” 错误
- 读书笔记--推荐系统实践(2)
- 第2周项目2-就拿胖子说事
- linux添加系统调用总结(内核版本4.4.4)
- Python爬虫学习笔记(2)-单线程爬虫
- 【spark】采用LogisticRegression(ML API篇)对MNIST的0-1数字进行识别
- JavaScript 原型概念深入理解
- HDU5131(模拟)
- hdu5222(拓扑排序+并查集)
- Debian下iceweasel(FireFox)缓存目录下的视频文件
- Python爬虫学习笔记(3)-XPath与多线程爬虫
- bzoj 3625 小朋友和二叉树 多项式开根
- UI学习第09天
- POJ 2362 Square (搜索 + 剪枝)