Implementing a Convolutional Neural Network with deeplearning4j

Source: Internet · Editor: 程序博客网 · Date: 2024/06/02 02:07

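A convolutional neural network (CNN) is quite different from ordinary machine-learning models. Its input is usually a high-dimensional matrix (for an image: height × width × channels). The input passes through convolution, pooling, convolution, pooling and so on, then through a fully connected layer (which flattens the matrix into a vector) and a softmax output; the weights are adjusted by backpropagation.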
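To make the convolution, pooling and softmax steps concrete, here is a small plain-Java sketch on arrays. No DL4J is involved, and the class and method names are hypothetical ones chosen for illustration. It assumes stride-1 "valid" convolution without padding (implemented as cross-correlation, as most deep-learning libraries do), 2×2 max pooling with stride 2, and a numerically stable softmax. It only covers the forward pass; backpropagation computes gradients through these same operations and is left to the library.

```java
import java.util.Arrays;

/** Hypothetical toy forward pass on plain arrays; DL4J does all of this internally. */
public class CnnForwardSketch {

    // "Valid" 2D convolution, stride 1, no padding, implemented as cross-correlation.
    // Output size: (h - kh + 1) x (w - kw + 1).
    public static double[][] conv2d(double[][] in, double[][] kernel) {
        int kh = kernel.length, kw = kernel[0].length;
        int oh = in.length - kh + 1, ow = in[0].length - kw + 1;
        double[][] out = new double[oh][ow];
        for (int i = 0; i < oh; i++) {
            for (int j = 0; j < ow; j++) {
                double sum = 0;
                for (int a = 0; a < kh; a++) {
                    for (int b = 0; b < kw; b++) {
                        sum += in[i + a][j + b] * kernel[a][b];
                    }
                }
                out[i][j] = sum;
            }
        }
        return out;
    }

    // 2x2 max pooling with stride 2; an odd trailing row/column is dropped.
    public static double[][] maxPool2x2(double[][] in) {
        int oh = in.length / 2, ow = in[0].length / 2;
        double[][] out = new double[oh][ow];
        for (int i = 0; i < oh; i++) {
            for (int j = 0; j < ow; j++) {
                out[i][j] = Math.max(
                        Math.max(in[2 * i][2 * j], in[2 * i][2 * j + 1]),
                        Math.max(in[2 * i + 1][2 * j], in[2 * i + 1][2 * j + 1]));
            }
        }
        return out;
    }

    // Softmax; subtracting the max first avoids overflow in Math.exp.
    public static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // 6x6 toy "image": left half dark (0), right half bright (1), i.e. one vertical edge.
        double[][] image = new double[6][6];
        for (double[] row : image) {
            Arrays.fill(row, 3, 6, 1.0);
        }
        double[][] kernel = {{-1, 0, 1}, {-1, 0, 1}, {-1, 0, 1}}; // responds to dark-to-bright edges
        double[][] feature = conv2d(image, kernel); // 6x6 -> 4x4
        double[][] pooled = maxPool2x2(feature);    // 4x4 -> 2x2
        System.out.println(Arrays.deepToString(pooled));
        // In a real CNN the pooled maps are flattened and a dense layer produces one logit
        // per class; here softmax is applied to made-up logits just to show the normalization.
        System.out.println(Arrays.toString(softmax(new double[] {2.0, 1.0, 0.1})));
    }
}
```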

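Many deep-learning frameworks can implement CNNs. Below I use the deeplearning4j package, mainly following the example provided in the project's repository on GitHub: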


package com.meituan.deeplearning4j;

import java.io.IOException;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LenetMnistExample {

    public static void main(String[] args) throws IOException {
        int nChannels = 1;   // MNIST images are grayscale: one input channel
        int outputNum = 10;  // ten digit classes, 0-9
        int batchSize = 64;
        int nEpochs = 4;     // number of passes over the training set
        int iterations = 1;
        int seed = 123;

        System.out.println("Load data....");
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);

        System.out.println("Build model....");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                // 5x5 convolution, 20 filters
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                // 2x2 max pooling
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        // Note that nIn need not be specified in later layers
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                // fully connected layer: the pooled feature maps are flattened into a vector
                .layer(4, new DenseLayer.Builder()
                        .activation(Activation.RELU)
                        .nOut(500)
                        .build())
                // softmax output over the 10 classes
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                // input is a flattened 28x28x1 image; DL4J uses this to infer each layer's nIn
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .backprop(true)
                .pretrain(false);

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        System.out.println("Train model....");
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnistTrain);
            System.out.println("Completed epoch " + i);

            System.out.println("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnistTest.hasNext()) {
                DataSet ds = mnistTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            System.out.println(eval.stats());
            mnistTest.reset();  // rewind the test iterator for the next epoch's evaluation
        }
        System.out.println("model finish");
    }
}
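The comment on layer 2 notes that nIn need not be specified in later layers: setInputType(InputType.convolutionalFlat(28, 28, 1)) lets DL4J work out each layer's input size. As a sanity check, the sketch below redoes that arithmetic by hand. It assumes zero padding (no .padding(...) is set in the listing above) and the usual no-padding formula out = (in - kernel) / stride + 1; LenetShapes is a hypothetical helper, not part of DL4J.

```java
/** Hypothetical helper: hand-computed layer shapes for the LeNet configuration above. */
public class LenetShapes {

    // Output size of a convolution or pooling window along one dimension, assuming no padding.
    public static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int size = 28;                  // MNIST images are 28x28 with 1 channel
        size = outSize(size, 5, 1);     // layer 0: 5x5 conv, 20 filters
        System.out.println("after conv1: " + size + "x" + size + "x20");
        size = outSize(size, 2, 2);     // layer 1: 2x2 max pool, stride 2
        System.out.println("after pool1: " + size + "x" + size + "x20");
        size = outSize(size, 5, 1);     // layer 2: 5x5 conv, 50 filters
        System.out.println("after conv2: " + size + "x" + size + "x50");
        size = outSize(size, 2, 2);     // layer 3: 2x2 max pool, stride 2
        System.out.println("after pool2: " + size + "x" + size + "x50");
        int denseIn = size * size * 50; // flattened input to the 500-unit dense layer
        System.out.println("dense layer nIn = " + denseIn);
    }
}
```

Running it gives spatial sizes of 24, 12, 8 and 4, so the dense layer receives 4 × 4 × 50 = 800 inputs.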



Without a GPU, training a CNN like this is indeed quite slow.

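Results after 4 epochs of training are shown below. For example, 974 images labeled 0 were classified as 0, one image labeled 0 was classified as 1, and so on: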


Examples labeled as 0 classified by model as 0: 974 times

Examples labeled as 0 classified by model as 1: 1 times

Examples labeled as 0 classified by model as 6: 1 times

Examples labeled as 0 classified by model as 7: 2 times

Examples labeled as 0 classified by model as 8: 1 times

Examples labeled as 0 classified by model as 9: 1 times

Examples labeled as 1 classified by model as 1: 1124 times

Examples labeled as 1 classified by model as 2: 4 times

Examples labeled as 1 classified by model as 3: 2 times

Examples labeled as 1 classified by model as 5: 1 times

Examples labeled as 1 classified by model as 6: 2 times

Examples labeled as 1 classified by model as 7: 1 times

Examples labeled as 1 classified by model as 8: 1 times

Examples labeled as 2 classified by model as 0: 2 times

Examples labeled as 2 classified by model as 2: 1027 times

Examples labeled as 2 classified by model as 6: 1 times

Examples labeled as 2 classified by model as 7: 2 times

Examples labeled as 3 classified by model as 0: 1 times

Examples labeled as 3 classified by model as 2: 2 times

Examples labeled as 3 classified by model as 3: 999 times

Examples labeled as 3 classified by model as 5: 3 times

Examples labeled as 3 classified by model as 7: 2 times

Examples labeled as 3 classified by model as 8: 3 times

Examples labeled as 4 classified by model as 2: 1 times

Examples labeled as 4 classified by model as 4: 975 times

Examples labeled as 4 classified by model as 6: 2 times

Examples labeled as 4 classified by model as 9: 4 times

Examples labeled as 5 classified by model as 0: 2 times

Examples labeled as 5 classified by model as 3: 5 times

Examples labeled as 5 classified by model as 5: 878 times

Examples labeled as 5 classified by model as 6: 2 times

Examples labeled as 5 classified by model as 7: 1 times

Examples labeled as 5 classified by model as 8: 3 times

Examples labeled as 5 classified by model as 9: 1 times

Examples labeled as 6 classified by model as 0: 4 times

Examples labeled as 6 classified by model as 1: 2 times

Examples labeled as 6 classified by model as 4: 1 times

Examples labeled as 6 classified by model as 5: 5 times

Examples labeled as 6 classified by model as 6: 944 times

Examples labeled as 6 classified by model as 8: 2 times

Examples labeled as 7 classified by model as 1: 4 times

Examples labeled as 7 classified by model as 2: 8 times

Examples labeled as 7 classified by model as 3: 1 times

Examples labeled as 7 classified by model as 7: 1010 times

Examples labeled as 7 classified by model as 8: 1 times

Examples labeled as 7 classified by model as 9: 4 times

Examples labeled as 8 classified by model as 0: 4 times

Examples labeled as 8 classified by model as 2: 3 times

Examples labeled as 8 classified by model as 3: 1 times

Examples labeled as 8 classified by model as 5: 1 times

Examples labeled as 8 classified by model as 7: 2 times

Examples labeled as 8 classified by model as 8: 959 times

Examples labeled as 8 classified by model as 9: 4 times

Examples labeled as 9 classified by model as 1: 2 times

Examples labeled as 9 classified by model as 2: 1 times

Examples labeled as 9 classified by model as 3: 2 times

Examples labeled as 9 classified by model as 4: 1 times

Examples labeled as 9 classified by model as 5: 4 times

Examples labeled as 9 classified by model as 7: 3 times

Examples labeled as 9 classified by model as 8: 2 times

Examples labeled as 9 classified by model as 9: 994 times



==========================Scores========================================

 Accuracy:        0.9884

 Precision:       0.9884

 Recall:          0.9883

 F1 Score:        0.9883

========================================================================

model finish
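The four summary scores can be recomputed from the confusion counts in the log. The sketch below is plain Java (ConfusionScores is a hypothetical class). It assumes Precision and Recall are macro-averaged over the 10 classes and F1 is the harmonic mean of those two averages; that is my reading of how the summary is computed rather than something taken from the DL4J documentation, but with these counts it gives 0.9884, 0.9884, 0.9883 and 0.9883, matching the log. The rows also sum to 10,000, the size of the MNIST test set.

```java
import java.util.Locale;

/** Hypothetical class: recomputes accuracy and macro-averaged scores from a confusion matrix. */
public class ConfusionScores {

    // CONFUSION[actual][predicted], copied from the "Examples labeled as X classified by model as Y" lines.
    public static final int[][] CONFUSION = {
        {974, 1, 0, 0, 0, 0, 1, 2, 1, 1},
        {0, 1124, 4, 2, 0, 1, 2, 1, 1, 0},
        {2, 0, 1027, 0, 0, 0, 1, 2, 0, 0},
        {1, 0, 2, 999, 0, 3, 0, 2, 3, 0},
        {0, 0, 1, 0, 975, 0, 2, 0, 0, 4},
        {2, 0, 0, 5, 0, 878, 2, 1, 3, 1},
        {4, 2, 0, 0, 1, 5, 944, 0, 2, 0},
        {0, 4, 8, 1, 0, 0, 0, 1010, 1, 4},
        {4, 0, 3, 1, 0, 1, 0, 2, 959, 4},
        {0, 2, 1, 2, 1, 4, 0, 3, 2, 994},
    };

    /** Returns {accuracy, macroPrecision, macroRecall, f1}. */
    public static double[] scores(int[][] m) {
        int n = m.length;
        long total = 0, correct = 0;
        double precisionSum = 0, recallSum = 0;
        for (int c = 0; c < n; c++) {
            long rowSum = 0, colSum = 0;
            for (int k = 0; k < n; k++) {
                rowSum += m[c][k]; // all examples whose true label is c
                colSum += m[k][c]; // all examples the model predicted as c
            }
            total += rowSum;
            correct += m[c][c];
            precisionSum += colSum == 0 ? 0 : (double) m[c][c] / colSum;
            recallSum += rowSum == 0 ? 0 : (double) m[c][c] / rowSum;
        }
        double precision = precisionSum / n;
        double recall = recallSum / n;
        double f1 = 2 * precision * recall / (precision + recall);
        return new double[] {(double) correct / total, precision, recall, f1};
    }

    public static void main(String[] args) {
        double[] s = scores(CONFUSION);
        // Locale.ROOT keeps '.' as the decimal separator regardless of the system locale.
        System.out.printf(Locale.ROOT, "Accuracy:  %.4f%nPrecision: %.4f%nRecall:    %.4f%nF1 Score:  %.4f%n",
                s[0], s[1], s[2], s[3]);
    }
}
```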
