【spark】采用LogisticRegression(ML API篇)对MNIST的0-1数字进行识别

来源:互联网 发布:淘宝上下架工具 编辑:程序博客网 时间:2024/06/07 12:08

:ROC曲线概念

http://blog.csdn.net/abcjennifer/article/details/7359370

:Recall-Precision概念

http://blog.csdn.net/pirage/article/details/9851339

:下载MNIST数据集

http://yann.lecun.com/exdb/mnist/

:Logistic Regression:从入门到精通

http://www.tianyancha.com/research/LR_intro.pdf

:加载MNIST数据类

package com.bbw5.ml.spark.data

import java.io.File
import java.io.FileInputStream
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.language.reflectiveCalls
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.Vector

/**
 * Loader for the MNIST handwritten-digit files in IDX format
 * (http://yann.lecun.com/exdb/mnist/).
 *
 * @param dataDir  directory that contains the unpacked MNIST files
 * @param numCount only samples whose label is strictly less than this value are
 *                 returned by [[loadData]]; the default of 2 keeps digits 0 and 1
 */
class MNISTData(val dataDir: String, val numCount: Int = 2) {
  // File locations relative to dataDir. The public names (including the
  // "Labe" spelling) are kept unchanged so existing callers keep compiling.
  val trainLabeFileName = "/train-labels-idx1-ubyte/train-labels.idx1-ubyte"
  val trainImageFileName = "/train-images-idx3-ubyte/train-images.idx3-ubyte"
  val testLabeFileName = "/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte"
  val testImageFileName = "/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte"

  /**
   * Runs `f` on `resource` and always closes the resource afterwards,
   * even when `f` throws.
   */
  def using[A <: { def close(): Unit }, B](resource: A)(f: A => B): B =
    try {
      f(resource)
    } finally {
      resource.close()
    }

  /**
   * Reads the whole file `dataDir + filename` into memory.
   *
   * The stream handed to `using` is the one that is read, so it is always
   * closed. A single `InputStream.read` call may return fewer bytes than
   * requested, so reading is repeated until the buffer is full or EOF is hit.
   */
  private def readFile(filename: String): Array[Byte] = {
    val file = new File(dataDir + filename)
    val buffer = new Array[Byte](file.length.toInt)
    using(new FileInputStream(file)) { source =>
      var offset = 0
      var eof = false
      while (!eof && offset < buffer.length) {
        val read = source.read(buffer, offset, buffer.length - offset)
        if (read < 0) eof = true else offset += read
      }
    }
    buffer
  }

  /** Reads the big-endian 32-bit integer stored at `offset` in `bytes`. */
  private def readInt(bytes: Array[Byte], offset: Int): Int =
    ByteBuffer.wrap(bytes.slice(offset, offset + 4)).getInt

  /** Loads the raw training label file (header included). */
  def loadTrainLabelData(): Array[Byte] = {
    loadLabelData(trainLabeFileName)
  }

  /** Loads the raw test label file (header included). */
  def loadTestLabelData(): Array[Byte] = {
    loadLabelData(testLabeFileName)
  }

  /**
   * Loads an MNIST label file and prints its header plus the first three labels.
   *
   * @param filename path of the label file relative to `dataDir`
   * @return the complete file content, including the 8-byte header
   */
  def loadLabelData(filename: String): Array[Byte] = {
    val labelDS = readFile(filename)
    // 32 bit integer  0x00000801(2049) magic number (MSB first)
    val magicLabelNum = readInt(labelDS, 0)
    println(s"magicLabelNum=$magicLabelNum")
    // 32 bit integer  number of items
    val numOfLabelItems = readInt(labelDS, 4)
    println(s"numOfLabelItems=$numOfLabelItems")
    // Print a few sample labels as a sanity check.
    for ((e, index) <- labelDS.slice(8, 11).zipWithIndex) {
      println(s"image$index is $e")
    }
    labelDS
  }

  /** Loads the raw training image file (header included). */
  def loadTrainImageData(): Array[Byte] = {
    loadImageData(trainImageFileName)
  }

  /** Loads the raw test image file (header included). */
  def loadTestImageData(): Array[Byte] = {
    loadImageData(testImageFileName)
  }

  /**
   * Loads an MNIST image file and prints the values found in its header.
   *
   * @param filename path of the image file relative to `dataDir`
   * @return the complete file content, including the 16-byte header
   */
  def loadImageData(filename: String): Array[Byte] = {
    val trainingDS = readFile(filename)
    // 32 bit integer  0x00000803(2051) magic number
    val magicNum = readInt(trainingDS, 0)
    println(s"magicNum=$magicNum")
    // 32 bit integer  number of items
    val numOfItems = readInt(trainingDS, 4)
    println(s"numOfItems=$numOfItems")
    // 32 bit integer  number of rows
    val numOfRows = readInt(trainingDS, 8)
    println(s"numOfRows=$numOfRows")
    // 32 bit integer  number of columns
    val numOfCols = readInt(trainingDS, 12)
    println(s"numOfCols=$numOfCols")
    // Payload size versus the size implied by the header (items * rows * cols).
    println(s"numOfItems=" + (trainingDS.length - 16) + "=" + (numOfItems * numOfRows * numOfCols))
    trainingDS
  }

  /** Loads the training set as (label, pixel vector) pairs. */
  def loadTrainData(): Array[(Double, Vector)] = {
    loadData(() => loadTrainLabelData(), () => loadTrainImageData())
  }

  /** Loads the test set as (label, pixel vector) pairs. */
  def loadTestData(): Array[(Double, Vector)] = {
    loadData(() => loadTestLabelData(), () => loadTestImageData())
  }

  /**
   * Pairs every label with its image and keeps the samples whose label is
   * below `numCount`.
   *
   * @param loadLabelFunc supplies the raw label file content (8-byte header first)
   * @param loadImageFunc supplies the raw image file content (16-byte header first)
   * @return (label, features) pairs; features hold one value in 0..255 per pixel
   */
  def loadData(loadLabelFunc: () => Array[Byte], loadImageFunc: () => Array[Byte]): Array[(Double, Vector)] = {
    val labelDS: Array[Byte] = loadLabelFunc()
    val labels = labelDS.drop(8)
    val trainingDS: Array[Byte] = loadImageFunc()
    val numOfItems = readInt(trainingDS, 4)
    // Take the image dimensions from the header instead of assuming 28 x 28.
    val imageSize = readInt(trainingDS, 8) * readInt(trainingDS, 12)
    // Image i occupies bytes [16 + i * imageSize, 16 + (i + 1) * imageSize).
    val itemsArray = Array.tabulate(numOfItems) { i =>
      trainingDS.slice(16 + i * imageSize, 16 + (i + 1) * imageSize)
    }
    println("numOfImages=" + itemsArray.length)
    val data = labels.zip(itemsArray)
    // Print how many samples there are per digit.
    println("image digit count:")
    data.groupBy(a => a._1).mapValues(b => b.size).foreach(println)
    // Keep only the requested digits (0/1 by default). Pixels are unsigned
    // bytes, so mask with 0xFF; a plain Byte.toDouble would turn intensities
    // 128..255 into negative numbers.
    data
      .filter(p => p._1 < numCount)
      .map(p => (p._1.toDouble, Vectors.dense(p._2.map(c => (c & 0xFF).toDouble))))
  }
}

object MNISTData {
  /** Convenience wrapper around [[MNISTData.loadTrainData]]. */
  def loadTrainData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    new MNISTData(dataDir, numCount).loadTrainData()
  }

  /** Convenience wrapper around [[MNISTData.loadTestData]]. */
  def loadTestData(dataDir: String, numCount: Int = 2): Array[(Double, Vector)] = {
    new MNISTData(dataDir, numCount).loadTestData()
  }
}
:使用ML API进行分类

package com.bbw5.ml.spark

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import com.bbw5.ml.spark.data.MNISTData
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.TrainValidationSplit
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
import org.apache.spark.ml.classification.LogisticRegressionModel
import java.util.Date

/**
 * Classifies the digits 0 and 1 of the MNIST handwritten-digit data set with
 * LogisticRegression (Spark ML API).
 *
 * @author baibaiw5
 */
object LogisticRegression4MNIST {
  /** Directory the MNIST files are read from unless a caller overrides it. */
  val DefaultDataDir = "I:/DM-dataset/MNIST/"

  /** Directory the fitted model is saved under unless a caller overrides it. */
  val DefaultModelDir = "D:/Develop/Model"

  /** Entry point: creates the Spark contexts, runs [[tvSplit]] and always stops Spark. */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("LogisticRegression4MNIST")
    val sc = new SparkContext(sparkConf)
    try {
      val sqlContext = new SQLContext(sc)
      tvSplit(sc, sqlContext)
    } finally {
      // Release the context even when training or saving fails.
      sc.stop()
    }
  }

  /**
   * Saves the model and prints its parameters and its training summary
   * (objective history, ROC curve, precision-recall curve, area under ROC).
   *
   * @param bestModel fitted binary logistic regression model
   * @param modelDir  directory the model is saved under, as
   *                  `modelDir/MNIST-LR-<current time in millis>`
   */
  def printResult(bestModel: LogisticRegressionModel, modelDir: String = DefaultModelDir): Unit = {
    bestModel.save(modelDir + "/MNIST-LR-" + System.currentTimeMillis())
    println("bestModel.params:" + bestModel.extractParamMap)

    val trainingSummary = bestModel.summary
    // Obtain the objective per iteration.
    val objectiveHistory = trainingSummary.objectiveHistory
    println("print lost in every step:")
    objectiveHistory.foreach(loss => println(loss))

    val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary]
    // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
    println("print roc in every step:")
    binarySummary.roc.show(binarySummary.roc.count.toInt)
    println("print recall,precision every step:")
    binarySummary.pr.show(binarySummary.pr.count.toInt)
    // The run recorded by the author printed 0.9409024010447935.
    println("areaUnderROC=" + binarySummary.areaUnderROC)
  }

  /**
   * Tunes a LogisticRegression with a train/validation split over a parameter
   * grid, prints the predictions of the best model on the MNIST test data and
   * then reports the best model through [[printResult]].
   *
   * @param sc         Spark context used to parallelize the loaded samples
   * @param sqlContext SQL context providing the DataFrame conversions
   * @param dataDir    directory containing the MNIST files
   * @param modelDir   directory the best model is saved under
   */
  def tvSplit(
      sc: SparkContext,
      sqlContext: SQLContext,
      dataDir: String = DefaultDataDir,
      modelDir: String = DefaultModelDir): Unit = {
    import sqlContext.implicits._

    val training = sc.parallelize(MNISTData.loadTrainData(dataDir), 4).toDF("label", "features").cache()
    training.describe("label").show()
    val test = sc.parallelize(MNISTData.loadTestData(dataDir), 4).toDF("label", "features").cache()
    test.describe("label").show()

    val lr = new LogisticRegression()
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.0001, 0.01, 1.0))
      .addGrid(lr.fitIntercept)
      .addGrid(lr.maxIter, Array(100))
      .addGrid(lr.elasticNetParam, Array(0.1, 0.5, 1.0))
      .build()

    // 80% of the data will be used for training and the remaining 20% for validation.
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setEstimatorParamMaps(paramGrid)
      .setTrainRatio(0.8)

    // Run train validation split, and choose the best set of parameters.
    val model = trainValidationSplit.fit(training)

    // Make predictions on test data. model is the model with the combination of
    // parameters that performed best.
    val testDF = model.transform(test)
    testDF.select("label", "prediction").show()
    testDF.groupBy("label", "prediction").count().show()

    printResult(model.bestModel.asInstanceOf[LogisticRegressionModel], modelDir)
  }
}
:运行结果:模型训练摘要(bestModel.summary,基于训练数据)给出的AUC为 0.9409024010447935;在测试集上的预测统计如下,准确率约为97.8%(2068/2115)

        +-----+----------+-----+
        |label|prediction|count|
        +-----+----------+-----+
        |  1.0|       1.0| 1106|
        |  0.0|       0.0|  962|
        |  0.0|       1.0|   18|
        |  1.0|       0.0|   29|
        +-----+----------+-----+



1 0
原创粉丝点击