Deeplearning4j 实战(1):Deeplearning4j 手写体数字识别【转】

来源:互联网 发布:淘宝 心语星店 编辑:程序博客网 时间:2024/06/02 07:30

from:http://blog.csdn.net/wangongxi/article/details/54576594

最近这几年,深度学习很火,包括自己在内的很多对机器学习还是一知半解的小白也开始用深度学习做些应用。由于小白的等级不高,算法自己写不出来,所以就用了开源库。Deep Learning的开源库有多,如果以语言来划分的话,就有Python系列的tensowflow,theano,keras,C/C++系列的Caffe,还有Lua系列的torch等等。但咱们公司是用Java为主,大部分项目最终也是做成一个Java Web的服务,所以我最终选择了Deeplearning4j。

    Deeplearning4j是国外创业公司Skymind的产品。目前最新的版本更新到了0.7.2。源码全部公开并托管在github上(https://github.com/deeplearning4j/deeplearning4j)。从这个库的名字上可以看出,它就是转为Java程序员写的Deep Learning库。其实这个库吸引人的地方不仅仅在于它支持Java,更为重要的是它可以支持Spark。由于Deep Learning模型的训练需要大量的内存,而且原始数据的存储有时候也需要很大的外存空间,所以如果可以利用集群来处理便是最好不过了。当然,除了Deeplearning4j以外,还有一些Deep Learning的库可以支持Spark,比如yahoo/CaffeOnSpark,AMPLab/SparkNet以及Intel最近开源的BigDL。这些库我自己都没怎么用过,所以就不多说了,这里重点说说Deeplearning4j的使用。

    一般开始使用别人的代码库,都会先跑一些demo,或者说Hello World的例子,就好像学习一门编程语言一样,第一行代码都是打印Hello World。Deep Learning的Hello World的例子一般是两个,一个是Mnist数据集的分类,另一个就是Word2Vec找相似词。由于Word2Vec并不是严格意义上的深度神经网络,因此这里就用Lenet网络处理Mnist数据集来作为Deep Learning的Hello World。Mnist是开源的28x28的黑白手写体数字图片集(http://yann.lecun.com/exdb/mnist/),其中包含6W张训练图片和1W张测试图片。至于Lenet的相关结构描述,可以参考这个链接:http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf。下面就详细讲述下,利用Deeplearning4j如何进行建模、训练和预测评估。

    首先,我们建立一个maven项目。然后在pom文件里加入Deeplearning4j的一些相关依赖。最主要的有三个:deeplearning4j-core,datavec,nd4j。deeplearning4j-core是神经网络结构实现的代码,nd4j是用于做张量运算的库,通过JavaCPP来调用编译好的C++库(可选:ATAL, MKL, 和OpenBLAS),datavec则主要负责数据的ETL。具体可见代码:

[html] view plain copy
  1. <properties>  
  2.   <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>  
  3.   <nd4j.version>0.7.1</nd4j.version>  
  4.   <dl4j.version>0.7.1</dl4j.version>  
  5.   <datavec.version>0.7.1</datavec.version>  
  6.   <scala.binary.version>2.10</scala.binary.version>  
  7. </properties>  
  8. <dependencies>  
  9. <dependency>  
  10.     <groupId>org.nd4j</groupId>  
  11.     <artifactId>nd4j-native</artifactId>   
  12.     <version>${nd4j.version}</version>  
  13. </dependency>  
  14. <dependency>  
  15.     <groupId>org.deeplearning4j</groupId>  
  16.     <artifactId>dl4j-spark_2.11</artifactId>  
  17.     <version>${dl4j.version}</version>  
  18. </dependency>  
  19.      <dependency>  
  20.           <groupId>org.datavec</groupId>  
  21.           <artifactId>datavec-spark_${scala.binary.version}</artifactId>  
  22.           <version>${datavec.version}</version>  
  23.     </dependency>  
  24.       <dependency>  
  25.    <groupId>org.deeplearning4j</groupId>  
  26.    <artifactId>deeplearning4j-core</artifactId>  
  27.    <version>${dl4j.version}</version>  
  28. </dependency>  
  29. </dependencies>  

    这些依赖里面有和Spark相关的,主要是跑Spark要用到。不过没有关系,先引进来即可。

    接着,我们解释下面的代码。我们先要定义一些具体的参数,比如分类的个数(outputNum),mini-batch的数量(batchSize)等等,具体在图中已经做了注释。需要说明的是MnistDataSetIterator这个迭代器类。这个类其实是一个读取二进制Mnist数据集的high-level的封装。通过debug我们可以发现,其中包括从网络中下载Mnist数据集,读取数据和标注,再构建迭代器的过程。在源码中,默认将下载的文件放在系统的user.home目录下,具体每个人不同会有所不同。由于我自己所处的环境网络不咋的,所以很有可能在利用这种high-level的接口的时候,因为下载Mnist数据失败而抛出异常,最终无法训练。所以,大家可以先自行下载好这些数据,然后按照源码的要求,放到相应的目录下并根据源码正确命名文件,那这样就依然可以利用这种high-level的接口。具体需要参考的是MnistDataFetcher类中相关代码。

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. int nChannels = 1;      //black & white picture, 3 if color image  
  2. int outputNum = 10;     //number of classification  
  3. int batchSize = 64;     //mini batch size for sgd  
  4. int nEpochs = 10;       //total rounds of training  
  5. int iterations = 1;     //number of iteration in each traning round  
  6. int seed = 123;         //random seed for initialize weights  
  7.   
  8. log.info("Load data....");  
  9. DataSetIterator mnistTrain = null;  
  10. DataSetIterator mnistTest = null;  
  11.   
  12. mnistTrain = new MnistDataSetIterator(batchSize, true12345);  
  13. mnistTest = new MnistDataSetIterator(batchSize, false12345);  

当我们正确读取数据后,我们需要定义具体的神经网络结构,这里我用的是Lenet,Deeplearning4j的实现参考了官网(https://github.com/deeplearning4j/dl4j-examples)的例子。具体代码如下:

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()  
  2.         .seed(seed)  
  3.         .iterations(iterations)  
  4.         .regularization(true).l2(0.0005)  
  5.         .learningRate(0.01)//.biasLearningRate(0.02)  
  6.         //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)  
  7.         .weightInit(WeightInit.XAVIER)  
  8.         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)  
  9.         .updater(Updater.NESTEROVS).momentum(0.9)  
  10.         .list()  
  11.         .layer(0new ConvolutionLayer.Builder(55)  
  12.                 //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied  
  13.                 .nIn(nChannels)  
  14.                 .stride(11)  
  15.                 .nOut(20)  
  16.                 .activation("identity")  
  17.                 .build())  
  18.         .layer(1new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)  
  19.                 .kernelSize(2,2)  
  20.                 .stride(2,2)  
  21.                 .build())  
  22.         .layer(2new ConvolutionLayer.Builder(55)  
  23.                 //Note that nIn need not be specified in later layers  
  24.                 .stride(11)  
  25.                 .nOut(50)  
  26.                 .activation("identity")  
  27.                 .build())  
  28.         .layer(3new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)  
  29.                 .kernelSize(2,2)  
  30.                 .stride(2,2)  
  31.                 .build())  
  32.         .layer(4new DenseLayer.Builder().activation("relu")  
  33.                 .nOut(500).build())  
  34.         .layer(5new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)  
  35.                 .nOut(outputNum)  
  36.                 .activation("softmax")  
  37.                 .build())  
  38.         .backprop(true).pretrain(false)  
  39.         .cnnInputSize(28281);  
  40. // The builder needs the dimensions of the image along with the number of channels. these are 28x28 images in one channel  
  41. //new ConvolutionLayerSetup(builder,28,28,1);  
  42.   
  43. MultiLayerConfiguration conf = builder.build();  
  44. MultiLayerNetwork model = new MultiLayerNetwork(conf);  
  45. model.init();          
  46. model.setListeners(new ScoreIterationListener(1));         // a listener which can print loss function score after each iteration  
可以发现,神经网络需要定义很多的超参数,学习率、正则化系数、卷积核的大小、激励函数等都是需要人为设定的。不同的超参数,对结果的影响很大,其实后来发现,很多时间都花在数据处理和调参方面。毕竟自己设计网络的能力有限,一般都是参考大牛的论文,然后自己照葫芦画瓢地实现。这里实现的Lenet的结构是:卷积-->下采样-->卷积-->下采样-->全连接。和原论文的结构基本一致。卷积核的大小也是参考的原论文。具体细节可参考之前发的论文链接。这里我们设置了一个Score的监听事件,主要是可以在训练的时候获取每一次权重更新后损失函数的收敛情况。后面一会有截图。

定义完网络结构之后,我们就可以对之前读取的数据进行训练和分类准确性评估。先看下代码:

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. forint i = 0; i < nEpochs; ++i ) {  
  2.     model.fit(mnistTrain);  
  3.     log.info("*** Completed epoch " + i + "***");  
  4.   
  5.     log.info("Evaluate model....");  
  6.     Evaluation eval = new Evaluation(outputNum);  
  7.     while(mnistTest.hasNext()){  
  8.         DataSet ds = mnistTest.next();            
  9.         INDArray output = model.output(ds.getFeatureMatrix(), false);  
  10.         eval.eval(ds.getLabels(), output);  
  11.     }  
  12.     log.info(eval.stats());  
  13.     mnistTest.reset();  
  14. }  

    相信这部分是比较容易理解的。每训练完一轮后,我们会对测试集合进行评估,然后打印出类似下面的结果。图中的上半部分是具体分类的统计,包括分对的和分错的图片数量都可以看得到。然后,我们耐心等待一段时间,可以看到经过10轮训练的Lenet对于Mnist数据集的分类准确率达到99%如下:

[plain] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. Examples labeled as 0 classified by model as 0: 974 times  
  2. Examples labeled as 0 classified by model as 6: 2 times  
  3. Examples labeled as 0 classified by model as 7: 2 times  
  4. Examples labeled as 0 classified by model as 8: 1 times  
  5. Examples labeled as 0 classified by model as 9: 1 times  
  6. Examples labeled as 1 classified by model as 0: 1 times  
  7. Examples labeled as 1 classified by model as 1: 1128 times  
  8. Examples labeled as 1 classified by model as 2: 1 times  
  9. Examples labeled as 1 classified by model as 3: 2 times  
  10. Examples labeled as 1 classified by model as 5: 1 times  
  11. Examples labeled as 1 classified by model as 6: 2 times  
  12. Examples labeled as 2 classified by model as 2: 1026 times  
  13. Examples labeled as 2 classified by model as 4: 1 times  
  14. Examples labeled as 2 classified by model as 6: 1 times  
  15. Examples labeled as 2 classified by model as 7: 3 times  
  16. Examples labeled as 2 classified by model as 8: 1 times  
  17. Examples labeled as 3 classified by model as 0: 1 times  
  18. Examples labeled as 3 classified by model as 1: 1 times  
  19. Examples labeled as 3 classified by model as 2: 1 times  
  20. Examples labeled as 3 classified by model as 3: 998 times  
  21. Examples labeled as 3 classified by model as 5: 3 times  
  22. Examples labeled as 3 classified by model as 7: 1 times  
  23. Examples labeled as 3 classified by model as 8: 4 times  
  24. Examples labeled as 3 classified by model as 9: 1 times  
  25. Examples labeled as 4 classified by model as 2: 1 times  
  26. Examples labeled as 4 classified by model as 4: 973 times  
  27. Examples labeled as 4 classified by model as 6: 2 times  
  28. Examples labeled as 4 classified by model as 7: 1 times  
  29. Examples labeled as 4 classified by model as 9: 5 times  
  30. Examples labeled as 5 classified by model as 0: 2 times  
  31. Examples labeled as 5 classified by model as 3: 4 times  
  32. Examples labeled as 5 classified by model as 5: 882 times  
  33. Examples labeled as 5 classified by model as 6: 1 times  
  34. Examples labeled as 5 classified by model as 7: 1 times  
  35. Examples labeled as 5 classified by model as 8: 2 times  
  36. Examples labeled as 6 classified by model as 0: 4 times  
  37. Examples labeled as 6 classified by model as 1: 2 times  
  38. Examples labeled as 6 classified by model as 4: 1 times  
  39. Examples labeled as 6 classified by model as 5: 4 times  
  40. Examples labeled as 6 classified by model as 6: 945 times  
  41. Examples labeled as 6 classified by model as 8: 2 times  
  42. Examples labeled as 7 classified by model as 1: 5 times  
  43. Examples labeled as 7 classified by model as 2: 3 times  
  44. Examples labeled as 7 classified by model as 3: 1 times  
  45. Examples labeled as 7 classified by model as 7: 1016 times  
  46. Examples labeled as 7 classified by model as 8: 1 times  
  47. Examples labeled as 7 classified by model as 9: 2 times  
  48. Examples labeled as 8 classified by model as 0: 1 times  
  49. Examples labeled as 8 classified by model as 3: 1 times  
  50. Examples labeled as 8 classified by model as 5: 2 times  
  51. Examples labeled as 8 classified by model as 7: 2 times  
  52. Examples labeled as 8 classified by model as 8: 966 times  
  53. Examples labeled as 8 classified by model as 9: 2 times  
  54. Examples labeled as 9 classified by model as 3: 1 times  
  55. Examples labeled as 9 classified by model as 4: 2 times  
  56. Examples labeled as 9 classified by model as 5: 4 times  
  57. Examples labeled as 9 classified by model as 6: 1 times  
  58. Examples labeled as 9 classified by model as 7: 5 times  
  59. Examples labeled as 9 classified by model as 8: 3 times  
  60. Examples labeled as 9 classified by model as 9: 993 times  
  61.   
  62.   
  63. ==========================Scores========================================  
  64.  Accuracy:        0.9901  
  65.  Precision:       0.99  
  66.  Recall:          0.99  
  67.  F1 Score:        0.99  
  68. ========================================================================  
  69. [main] INFO cv.LenetMnistExample - ****************Example finished********************  

    因为图传不上去,我就直接粘帖了结果。从中我们看到最终的一个准确率,还有就是哪些图片是分类正确的,哪些是分类错误的。当然我们可以通过增加训练的轮次还有调超参数来进一步优化,不过实际上这样的结果已经可以拿到生产上去用了。

    总结一下。其实包括我自己在内的很多人都对深度学习不了解,记得当时看csdn上写的有关深度学习的博客的时候,都觉得自己不可能达到那种水平。但其实,我们都忽略了一点,深度学习自身再复杂,它也是一个算法模型,也是一种机器学习。虽然它比感知机、逻辑回归等模型复杂很多(其实逻辑回归可看作神经网络中的一个神经元,充当的是激励函数的作用,类似的激励函数很多,如tanh,relu等),但终究用它的目的依然是完成回归、分类、压缩数据等任务。所以第一步尝试还是挺重要的。当然,我们不可能从复杂的模型开始,一开始就跟上当下最流行的模型,所以就从Mnist识别的例子开始,找找感觉。以后会写一些用Deeplearning4j在Spark的案例,也还是从Mnist开始。分享的同时自己也复习一下。。。


0 0
原创粉丝点击