深度学习Deeplearning4j 入门实战(2):Deeplearning4j 手写体数字识别Spark实现

来源:互联网 发布:js 去掉换行符 编辑:程序博客网 时间:2024/05/10 05:21

在前两天的博客中,我们用Deeplearning4j做了Mnist数据集的分类。算是第一个深度学习的应用。像Mnist数据集这样图片尺寸不大,而且是黑白的开源图片集在本地完成训练是可以的,毕竟我们用了Lenet这样相对简单的网络结构,而且本地的机器配置也有8G左右的内存。但实际生产中,图片的数量要多得多,尺寸也大得多,用的网络也会是AlexNet、GoogLenet这样更多层数的网络,所以往往我们需要用集群来解决计算资源的问题。由于Deeplearning4j本身基于Spark实现了神经网络的分布式训练,所以我们就以此作为我们的解决方案。

我们还是以Mnist数据集为例来做Deeplearning4j的第一个Spark版本的应用。首先需要在上一篇博客的基础上,在pom里面加入新的依赖:

[html] view plain copy
  1. <dependency>  
  2. <groupId>org.nd4j</groupId>  
  3. <artifactId>nd4j-kryo_${scala.binary.version}</artifactId>  
  4. <version>${nd4j.version}</version>  
  5. </dependency>  
这个是为了将Nd4j的序列化形式从Java默认的形式转到kryo的格式,以此提高序列化的效率。如果在代码中不为该类注册kryo的序列化格式,那么训练的时候会抛异常。
接着代码分为2个部分,一个部分是将Mnist数据集在本地以JavaRDD<DataSet>的形式存到磁盘并最终推到HDFS上作为Spark job的输入数据源。另一个部分则是模型的训练和保存。

第一部分的逻辑大致如下:本地建立Spark任务-->获取所有Mnist图片的路径-->读取图片并提取特征,打上标注,以DataSet的形式作为一张图片的wrapper-->将所有图片构成的JavaRDD<DataSet>存储下来。

这里原始的Mnist数据集是以图片形式存在,不再是二进制格式的数据。这个例子这样处理,也是方便日后用同样的方式读取一般的图片。Mnist的图片如下:


[java] view plain copy
  1. SparkConf conf = new SparkConf()  
  2.                 .setMaster("local[*]")  //local mode  
  3.                 .set("spark.kryo.registrator""org.nd4j.Nd4jRegistrator")  
  4.                 .setAppName("Mnist Java Spark (Java)");  
  5. JavaSparkContext jsc = new JavaSparkContext(conf);  
  6.   
  7. final List<String> lstLabelNames = Arrays.asList("零","一","二","三","四","五","六","七","八","九");  //Chinese Label  
  8. final ImageLoader imageLoader = new ImageLoader(28281);             //Load Image  
  9. final DataNormalization scaler = new ImagePreProcessingScaler(01);    //Normalize  
  10.   
  11. String srcPath = args[0];  
  12. FileSystem hdfs = FileSystem.get(URI.create(srcPath),jsc.hadoopConfiguration());    //hdfs read local file system  
  13. FileStatus[] fileList = hdfs.listStatus(new Path(srcPath));  
  14. List<String> lstFilePath = new ArrayList<>();  
  15. for( FileStatus fileStatus :  fileList){  
  16.     lstFilePath.add(srcPath + "/" + fileStatus.getPath().getName());  
  17. }  
  18. JavaRDD<String> javaRDDImagePath = jsc.parallelize(lstFilePath);  
  19. JavaRDD<DataSet> javaRDDImageTrain = javaRDDImagePath.map(new Function<String, DataSet>() {  
  20.   
  21.     @Override  
  22.     public DataSet call(String imagePath) throws Exception {  
  23.         FileSystem fs = FileSystem.get(new Configuration());  
  24.         DataInputStream in = fs.open(new Path(imagePath));  
  25.         INDArray features = imageLoader.asRowVector(in);            //features tensor  
  26.         String[] tokens = imagePath.split("\\/");  
  27.         String label = tokens[tokens.length-1].split("\\.")[0];       
  28.         int intLabel = Integer.parseInt(label);  
  29.         INDArray labels = Nd4j.zeros(10);                           //labels tensor                       
  30.         labels.putScalar(0, intLabel, 1.0);  
  31.         DataSet trainData = new DataSet(features, labels);          //DataSet, wrapper of features and labels  
  32.         trainData.setLabelNames(lstLabelNames);  
  33.         scaler.preProcess(trainData);                               //normalize  
  34.         fs.close();  
  35.         return trainData;  
  36.     }  
  37. });  
  38. javaRDDImageTrain.saveAsObjectFile("mnistNorm.dat");        //save training data  
这里有几点需要解释。
1.用hdfs.filesystem来获取文件。用Java原生态的File来操作也是完全可以的。只不过,这样读取文件的方式,同时适用于读取本地和HDFS上的文件。

2.ImageLoader类。这个类是用来读取图片文件的。类似的还有一个类,叫NativeImageLoader。不同的在于,NativeImageLoader是调用了OpenCV的相关方法来对图片做处理的,效率更高,因此推荐使用NativeImageLoader

保存的RDD的形式如下图:



然后,讲下模型训练任务的逻辑。读取HDFS上的以DataSet形式存储的Mnist文件-->定义参数中心服务-->定义神经网络结构(Lenet)--> 训练网络-->保存训练好的模型。首先看前两步的操作:
[java] view plain copy
  1. SparkConf conf = new SparkConf()  
  2.                       .set("spark.kryo.registrator""org.nd4j.Nd4jRegistrator")  //register kryo for nd4j  
  3.                       .setAppName("Mnist Java Spark (Java)");  
  4.   final String imageFilePath = args[0];  
  5.   final int numEpochs = Integer.parseInt(args[1]);  
  6.   final String modelPath = args[2];  
  7.   final int numBatch = Integer.parseInt(args[3]);  
  8.   //  
  9.   JavaSparkContext jsc = new JavaSparkContext(conf);  
  10.   //  
  11.   JavaRDD<DataSet> javaRDDImageTrain = jsc.objectFile(imageFilePath);     //load image data from hdfs  
  12.   ParameterAveragingTrainingMaster trainMaster = new ParameterAveragingTrainingMaster.Builder(numBatch)   //weight average service  
  13.                                                       .workerPrefetchNumBatches(0)  
  14.                                                       .saveUpdater(true)  
  15.                                                       .averagingFrequency(5)  
  16.                                                       .batchSizePerWorker(numBatch)  
这里我们获取传入的一些参数,如文件的hdfs路径,最后保存model的路径,mini-batch的大小(一般32,62,128这样的值为好,可以自行尝试),总的训练的轮次epoch。
这里需要解释的是ParameterAveragingTrainingMaster这个类。这个类的作用是用于将spark worker节点上各自计算的权重收回到driver节点上进行加权平均,并将最新的权重广播到worker节点上。也即为:将各个工作节点的参数的均值作为全局参数值。这种分布式机器学习中,数据并行化的一种操作。

下面是定义神经网络结构和训练网络:

[java] view plain copy
  1. int nChannels = 1;  
  2. int outputNum = 10;  
  3. int iterations = 1;  
  4. int seed = 123;  
  5. MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()  //define lenent  
  6.         .seed(seed)  
  7.         .iterations(iterations)  
  8.         .regularization(true).l2(0.0005)  
  9.         .learningRate(0.1)  
  10.         .learningRateScoreBasedDecayRate(0.5)  
  11.         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)  
  12.         .updater(Updater.ADAM)  
  13.         .list()  
  14.         .layer(0new ConvolutionLayer.Builder(55)  
  15.                 .nIn(nChannels)  
  16.                 .stride(11)  
  17.                 .nOut(20)  
  18.                 .weightInit(WeightInit.XAVIER)  
  19.                 .activation("relu")  
  20.                 .build())  
  21.         .layer(1new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)  
  22.                 .kernelSize(22)  
  23.                 .build())  
  24.         .layer(2new ConvolutionLayer.Builder(55)  
  25.                 .nIn(20)  
  26.                 .nOut(50)  
  27.                 .stride(2,2)  
  28.                 .weightInit(WeightInit.XAVIER)  
  29.                 .activation("relu")  
  30.                 .build())  
  31.         .layer(3new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)  
  32.                 .kernelSize(22)  
  33.                 .build())  
  34.         .layer(4new DenseLayer.Builder().activation("relu")  
  35.                 .weightInit(WeightInit.XAVIER)  
  36.                 .nOut(500).build())  
  37.         .layer(5new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)  
  38.                 .nOut(outputNum)  
  39.                 .weightInit(WeightInit.XAVIER)  
  40.                 .activation("softmax")  
  41.                 .build())  
  42.         .backprop(true).pretrain(false);  
  43. new ConvolutionLayerSetup(builder,28,28,1);  
  44.   
  45. MultiLayerConfiguration netconf = builder.build();  
  46. MultiLayerNetwork net = new MultiLayerNetwork(netconf);  
  47. net.setListeners(new ScoreIterationListener(1));  
  48. net.init();  
  49. SparkDl4jMultiLayer sparkNetwork = new SparkDl4jMultiLayer(jsc, net, trainMaster);  
  50. //train the network on Spark  
  51. forint i = 0; i < numEpochs; ++i ){  
  52.     sparkNetwork.fit(javaRDDImageTrain);  
  53.     System.out.println("----- Epoch " + i + " complete -----");  
  54.     Evaluation evalActual = sparkNetwork.evaluate(javaRDDImageTrain);  
  55.     System.out.println(evalActual.stats());  
  56. }  
这部分没有什么特别的地方,和单机的形式差不太多。值得说明的就是,我们在每一轮次的训练后,直接预测全部的训练数据来做评估,并没有做交叉验证。当然,做交叉验证也是完全可以的。
最后一部分是保存模型到hdfs上:

[java] view plain copy
  1. //save model  
  2. FileSystem hdfs = FileSystem.get(jsc.hadoopConfiguration());  
  3. Path hdfsPath = new Path(modelPath);  
  4. FSDataOutputStream outputStream = hdfs.create(hdfsPath);  
  5. MultiLayerNetwork trainedNet = sparkNetwork.getNetwork();  
  6. ModelSerializer.writeModel(trainedNet, outputStream, true);  
到此coding的部分就结束了,我们构建了在Spark进行分布式深度神经网络的训练并保存了模型。Spark的提交命令如下:

spark-submit --master yarn-cluster --executor-memory 5g --num-executors 16 --driver-memory 8g --conf "spark.executor.extraJavaOptions=-Dorg.bytedeco.javacpp.maxbytes=2921225472"  --conf spark.yarn.executor.memoryOverhead=5000

需要说明的是--conf后面的内容,因为Nd4j在计算的时候,实际需要两部分的内存:on-heap memory和off-heap memory。前者就是jvm为开辟对象所需内存,后者是C++的内存。Nd4j为了效率,在底层是通过JavaCPP调用C++进行计算的。如果不显示地申请C++的内存,那默认会从on-heap中分出10%给off-heap,但这样可能会不够。所以我们显示地申请off-heap内存。
下面这张图是正常的Spark UI显示的Deeplearning4j的训练过程:


然后,我们看下训练的结果:

[plain] view plain copy
  1. ----- Epoch 149 complete -----  
  2.   
  3. Examples labeled as 0 classified by model as 0: 4011 times  
  4. Examples labeled as 0 classified by model as 1: 2 times  
  5. Examples labeled as 0 classified by model as 2: 14 times  
  6. Examples labeled as 0 classified by model as 4: 9 times  
  7. Examples labeled as 0 classified by model as 5: 11 times  
  8. Examples labeled as 0 classified by model as 6: 28 times  
  9. Examples labeled as 0 classified by model as 7: 6 times  
  10. Examples labeled as 0 classified by model as 8: 40 times  
  11. Examples labeled as 0 classified by model as 9: 11 times  
  12. Examples labeled as 1 classified by model as 0: 1 times  
  13. Examples labeled as 1 classified by model as 1: 4598 times  
  14. Examples labeled as 1 classified by model as 2: 20 times  
  15. Examples labeled as 1 classified by model as 3: 7 times  
  16. Examples labeled as 1 classified by model as 4: 12 times  
  17. Examples labeled as 1 classified by model as 5: 3 times  
  18. Examples labeled as 1 classified by model as 6: 8 times  
  19. Examples labeled as 1 classified by model as 7: 10 times  
  20. Examples labeled as 1 classified by model as 8: 20 times  
  21. Examples labeled as 1 classified by model as 9: 5 times  
  22. Examples labeled as 2 classified by model as 0: 13 times  
  23. Examples labeled as 2 classified by model as 1: 20 times  
  24. Examples labeled as 2 classified by model as 2: 3910 times  
  25. Examples labeled as 2 classified by model as 3: 63 times  
  26. Examples labeled as 2 classified by model as 4: 22 times  
  27. Examples labeled as 2 classified by model as 5: 5 times  
  28. Examples labeled as 2 classified by model as 6: 4 times  
  29. Examples labeled as 2 classified by model as 7: 70 times  
  30. Examples labeled as 2 classified by model as 8: 54 times  
  31. Examples labeled as 2 classified by model as 9: 16 times  
  32. Examples labeled as 3 classified by model as 0: 2 times  
  33. Examples labeled as 3 classified by model as 1: 10 times  
  34. Examples labeled as 3 classified by model as 2: 55 times  
  35. Examples labeled as 3 classified by model as 3: 4104 times  
  36. Examples labeled as 3 classified by model as 4: 5 times  
  37. Examples labeled as 3 classified by model as 5: 53 times  
  38. Examples labeled as 3 classified by model as 6: 2 times  
  39. Examples labeled as 3 classified by model as 7: 42 times  
  40. Examples labeled as 3 classified by model as 8: 56 times  
  41. Examples labeled as 3 classified by model as 9: 22 times  
  42. Examples labeled as 4 classified by model as 0: 5 times  
  43. Examples labeled as 4 classified by model as 1: 6 times  
  44. Examples labeled as 4 classified by model as 2: 5 times  
  45. Examples labeled as 4 classified by model as 4: 3960 times  
  46. Examples labeled as 4 classified by model as 5: 3 times  
  47. Examples labeled as 4 classified by model as 6: 22 times  
  48. Examples labeled as 4 classified by model as 7: 9 times  
  49. Examples labeled as 4 classified by model as 8: 16 times  
  50. Examples labeled as 4 classified by model as 9: 46 times  
  51. Examples labeled as 5 classified by model as 0: 5 times  
  52. Examples labeled as 5 classified by model as 1: 7 times  
  53. Examples labeled as 5 classified by model as 2: 5 times  
  54. Examples labeled as 5 classified by model as 3: 40 times  
  55. Examples labeled as 5 classified by model as 4: 8 times  
  56. Examples labeled as 5 classified by model as 5: 3626 times  
  57. Examples labeled as 5 classified by model as 6: 27 times  
  58. Examples labeled as 5 classified by model as 7: 5 times  
  59. Examples labeled as 5 classified by model as 8: 66 times  
  60. Examples labeled as 5 classified by model as 9: 6 times  
  61. Examples labeled as 6 classified by model as 0: 9 times  
  62. Examples labeled as 6 classified by model as 1: 6 times  
  63. Examples labeled as 6 classified by model as 2: 5 times  
  64. Examples labeled as 6 classified by model as 3: 2 times  
  65. Examples labeled as 6 classified by model as 4: 47 times  
  66. Examples labeled as 6 classified by model as 5: 34 times  
  67. Examples labeled as 6 classified by model as 6: 3990 times  
  68. Examples labeled as 6 classified by model as 8: 43 times  
  69. Examples labeled as 6 classified by model as 9: 1 times  
  70. Examples labeled as 7 classified by model as 0: 6 times  
  71. Examples labeled as 7 classified by model as 1: 15 times  
  72. Examples labeled as 7 classified by model as 2: 57 times  
  73. Examples labeled as 7 classified by model as 3: 45 times  
  74. Examples labeled as 7 classified by model as 4: 22 times  
  75. Examples labeled as 7 classified by model as 5: 4 times  
  76. Examples labeled as 7 classified by model as 7: 4168 times  
  77. Examples labeled as 7 classified by model as 8: 21 times  
  78. Examples labeled as 7 classified by model as 9: 63 times  
  79. Examples labeled as 8 classified by model as 0: 15 times  
  80. Examples labeled as 8 classified by model as 1: 11 times  
  81. Examples labeled as 8 classified by model as 2: 23 times  
  82. Examples labeled as 8 classified by model as 3: 17 times  
  83. Examples labeled as 8 classified by model as 4: 19 times  
  84. Examples labeled as 8 classified by model as 5: 27 times  
  85. Examples labeled as 8 classified by model as 6: 35 times  
  86. Examples labeled as 8 classified by model as 7: 15 times  
  87. Examples labeled as 8 classified by model as 8: 3848 times  
  88. Examples labeled as 8 classified by model as 9: 53 times  
  89. Examples labeled as 9 classified by model as 0: 21 times  
  90. Examples labeled as 9 classified by model as 1: 3 times  
  91. Examples labeled as 9 classified by model as 2: 8 times  
  92. Examples labeled as 9 classified by model as 3: 26 times  
  93. Examples labeled as 9 classified by model as 4: 109 times  
  94. Examples labeled as 9 classified by model as 5: 23 times  
  95. Examples labeled as 9 classified by model as 6: 6 times  
  96. Examples labeled as 9 classified by model as 7: 62 times  
  97. Examples labeled as 9 classified by model as 8: 42 times  
  98. Examples labeled as 9 classified by model as 9: 3888 times  
  99.   
  100.   
  101. ==========================Scores========================================  
  102.  Accuracy:        0.9548  
  103.  Precision:       0.9546  
  104.  Recall:          0.9547  
  105.  F1 Score:        0.9547  
  106. ========================================================================  
在150轮的训练过后,模型的准确率达到了95.48%。误判的情况也列在上面了。

到此,在Spark上进行Mnist数据集的训练和评估就完成了。总结一下就是,先将数据以RDD的形式保存到HDFS上,然后建模读取RDD并训练模型。其实,将图片存在HDFS上也是一种方案,但是HDFS的一个block可能需要占用32M,64M这样的空间。因此图片这样的小文件,是很占用集群的存储空间的。并且,当图片数量很多的时候,我们会为了读取图片频繁地和HDFS建立和释放网络链接,这样同样消耗HDFS的资源。因此我们选择先在本地存储RDD的形式来处理。其实分布式的机器学习有很多策略,比如数据的并行化和模型的并行化,这里只是一笔掠过,待自己研究清楚了再写点东西。最后就是模型的调参。这里面我们也没有提,其实是极其重要的。因为目前,还没有非常权威的,或者定义的调参方案,因为训练过程每个人是不同的,所以只能结合自己的训练情况来调。一般当loss不下降的时候,调小学习率,batch-size也试着调小来看看效果,分布式的学习率较单机的要大些,这些原则去调。
0 0
原创粉丝点击