Implementing a Convolutional Neural Network with deeplearning4j
Source: Internet | Editor: 程序博客网 | Date: 2024/06/02 02:07
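A convolutional neural network (CNN) is quite different from ordinary machine learning models. Its input is usually a high-dimensional matrix, which passes through alternating convolution and pooling layers (convolution, pooling, convolution, pooling, ...), then a fully connected layer (which turns the matrix into a vector) and a softmax. Backpropagation then adjusts the weights.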
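To make the pipeline above concrete, here is a toy forward pass in plain Java, with no deeplearning4j involved. It is only a minimal sketch under stated assumptions: `ToyCnnForward` and its methods are hypothetical names, the 6x6 input, the 3x3 kernel and the 2-class weight matrix `w` are made-up values, there is a single filter and no activation function, and backpropagation is left out entirely.

```java
import java.util.Arrays;

/** Toy forward pass: convolution -> max pooling -> flatten -> fully connected -> softmax. */
final class ToyCnnForward {

    // "Valid" 2-D convolution with stride 1 (computed as cross-correlation, as most DL libraries do).
    static double[][] conv2d(double[][] in, double[][] k) {
        int oh = in.length - k.length + 1;
        int ow = in[0].length - k[0].length + 1;
        double[][] out = new double[oh][ow];
        for (int i = 0; i < oh; i++) {
            for (int j = 0; j < ow; j++) {
                double s = 0;
                for (int a = 0; a < k.length; a++) {
                    for (int b = 0; b < k[0].length; b++) {
                        s += in[i + a][j + b] * k[a][b];
                    }
                }
                out[i][j] = s;
            }
        }
        return out;
    }

    // 2x2 max pooling with stride 2; an odd trailing row or column is dropped.
    static double[][] maxPool2x2(double[][] in) {
        int oh = in.length / 2, ow = in[0].length / 2;
        double[][] out = new double[oh][ow];
        for (int i = 0; i < oh; i++) {
            for (int j = 0; j < ow; j++) {
                out[i][j] = Math.max(Math.max(in[2 * i][2 * j], in[2 * i][2 * j + 1]),
                                     Math.max(in[2 * i + 1][2 * j], in[2 * i + 1][2 * j + 1]));
            }
        }
        return out;
    }

    // Flatten a matrix row by row into a vector: the input to the fully connected layer.
    static double[] flatten(double[][] m) {
        return Arrays.stream(m).flatMapToDouble(Arrays::stream).toArray();
    }

    // Numerically stable softmax: subtract the maximum before exponentiating.
    static double[] softmax(double[] z) {
        double max = Arrays.stream(z).max().orElse(0);
        double[] e = Arrays.stream(z).map(v -> Math.exp(v - max)).toArray();
        double sum = Arrays.stream(e).sum();
        return Arrays.stream(e).map(v -> v / sum).toArray();
    }

    public static void main(String[] args) {
        double[][] image = new double[6][6];  // made-up input whose values grow toward the bottom-right
        for (int i = 0; i < 6; i++)
            for (int j = 0; j < 6; j++) image[i][j] = i + j;
        double[][] kernel = {{1, 0, -1}, {1, 0, -1}, {1, 0, -1}};  // vertical-edge filter
        double[] features = flatten(maxPool2x2(conv2d(image, kernel)));  // 6x6 -> 4x4 -> 2x2 -> 4 values
        double[][] w = {{0.1, 0.2, 0.3, 0.4}, {-0.1, 0.0, 0.1, 0.2}};   // hypothetical FC weights, 2 classes
        double[] logits = new double[w.length];
        for (int c = 0; c < w.length; c++)
            for (int i = 0; i < features.length; i++) logits[c] += w[c][i] * features[i];
        System.out.println(Arrays.toString(softmax(logits)));  // class probabilities that sum to 1
    }
}
```

In a real CNN there are many filters per layer, a nonlinearity after each convolution or dense layer, and the kernel and weight values are learned by backpropagation rather than fixed by hand.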
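Many deep learning frameworks can implement CNNs today. The implementation below uses the deeplearning4j package and mainly follows the examples provided by the project on GitHub: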
package com.meituan.deeplearning4j;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class LenetMnistExample {
    public static void main(String[] args) throws IOException {
        int nChannels = 1;   // MNIST images are grayscale: one input channel
        int outputNum = 10;  // ten digit classes, 0-9
        int batchSize = 64;
        int nEpochs = 4;     // the results below were obtained after 4 passes over the training set
        int iterations = 1;
        int seed = 123;

        System.out.println("load data");
        DataSetIterator mnisTrain = new MnistDataSetIterator(batchSize, true, 12345);
        DataSetIterator mnisTest = new MnistDataSetIterator(batchSize, false, 12345);

        System.out.println("build model....");
        // Learning-rate schedule; declared but not used in this configuration
        Map<Integer, Double> lrSchedule = new HashMap<Integer, Double>();
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .regularization(true).l2(0.0005)
                .learningRate(0.01)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list()
                // Layer 0: 5x5 convolution, 20 filters
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(nChannels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                // Layer 1: 2x2 max pooling, stride 2
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                // Layer 2: 5x5 convolution, 50 filters
                // Note that nIn need not be specified in later layers
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                // Layer 3: 2x2 max pooling, stride 2
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                // Layer 4: fully connected layer with 500 ReLU units
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
                // Layer 5: softmax output over the 10 classes
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                // MNIST rows arrive as flattened 784-element vectors; this declares them as 28x28x1
                // images so DL4J can infer nIn for later layers and add the required preprocessors
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .backprop(true).pretrain(false);

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        System.out.println("train model is start....");
        for (int i = 0; i < nEpochs; i++) {
            model.fit(mnisTrain);
            System.out.println(" Completed epoch is :" + i);

            System.out.println("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (mnisTest.hasNext()) {
                DataSet ds = mnisTest.next();
                // false = inference mode (not training)
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            System.out.println(eval.stats());
            mnisTest.reset();
            mnisTrain.reset();  // rewind the training data for the next epoch
        }
        System.out.println("model finish");
    }
}
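As a sanity check on the configuration above, the feature-map sizes and parameter counts can be worked out by hand. This sketch assumes no padding (as far as I know, the default truncate convolution mode in this DL4J version), so every convolution or pooling window maps a side length `in` to `(in - kernel) / stride + 1`. `LenetShapes` and its methods are hypothetical helpers, not part of DL4J.

```java
/** Hand calculation of layer shapes and parameter counts for the LeNet configuration above. */
final class LenetShapes {

    // Output side length of a conv/pooling window without padding: (in - kernel) / stride + 1
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    // Convolution parameters: a kernel x kernel x inChannels weight block per filter, plus one bias per filter
    static int convParams(int kernel, int inChannels, int filters) {
        return kernel * kernel * inChannels * filters + filters;
    }

    // Dense/output layer parameters: a full weight matrix plus one bias per unit
    static int denseParams(int in, int out) {
        return in * out + out;
    }

    public static void main(String[] args) {
        int size = 28;                  // MNIST input: 28x28x1
        size = outSize(size, 5, 1);     // layer 0, conv 5x5, 20 filters -> 24x24x20
        size = outSize(size, 2, 2);     // layer 1, max pool 2x2         -> 12x12x20
        size = outSize(size, 5, 1);     // layer 2, conv 5x5, 50 filters -> 8x8x50
        size = outSize(size, 2, 2);     // layer 3, max pool 2x2         -> 4x4x50
        int flat = size * size * 50;    // 800 inputs to the 500-unit dense layer

        int total = convParams(5, 1, 20)      // 520
                  + convParams(5, 20, 50)     // 25,050
                  + denseParams(flat, 500)    // 400,500
                  + denseParams(500, 10);     // 5,010 (pooling layers have no parameters)
        System.out.println("dense input = " + flat + ", total params = " + total);
    }
}
```

This prints a dense-layer input of 800 and 431080 parameters in total, which can be compared with `model.numParams()` after `model.init()`. Almost all of the parameters sit in the 800x500 dense layer, not in the convolutions.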
Without a GPU, training a CNN like this is indeed quite slow.
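After 4 training epochs the evaluation results are as follows. Each line is one cell of the confusion matrix: for example, 974 test images labeled 0 were classified as 0, 1 was classified as 1, and so on:
Examples labeled as 0 classified by model as 0: 974 times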
Examples labeled as 0 classified by model as 1: 1 times
Examples labeled as 0 classified by model as 6: 1 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 1: 1124 times
Examples labeled as 1 classified by model as 2: 4 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 1 classified by model as 7: 1 times
Examples labeled as 1 classified by model as 8: 1 times
Examples labeled as 2 classified by model as 0: 2 times
Examples labeled as 2 classified by model as 2: 1027 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 2 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 2: 2 times
Examples labeled as 3 classified by model as 3: 999 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 2 times
Examples labeled as 3 classified by model as 8: 3 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 975 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 9: 4 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 5 times
Examples labeled as 5 classified by model as 5: 878 times
Examples labeled as 5 classified by model as 6: 2 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 3 times
Examples labeled as 5 classified by model as 9: 1 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 5 times
Examples labeled as 6 classified by model as 6: 944 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 4 times
Examples labeled as 7 classified by model as 2: 8 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1010 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 4 times
Examples labeled as 8 classified by model as 0: 4 times
Examples labeled as 8 classified by model as 2: 3 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 1 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 959 times
Examples labeled as 8 classified by model as 9: 4 times
Examples labeled as 9 classified by model as 1: 2 times
Examples labeled as 9 classified by model as 2: 1 times
Examples labeled as 9 classified by model as 3: 2 times
Examples labeled as 9 classified by model as 4: 1 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 7: 3 times
Examples labeled as 9 classified by model as 8: 2 times
Examples labeled as 9 classified by model as 9: 994 times
==========================Scores========================================
Accuracy: 0.9884
Precision: 0.9884
Recall: 0.9883
F1 Score: 0.9883
========================================================================
model finish
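The Scores block can be reproduced from the per-class counts above. The diagonal sums to 9884 correct predictions out of 10000 test images, which is the 0.9884 accuracy. The sketch below assumes Precision and Recall are macro-averaged (per-digit values averaged over the ten classes), which is consistent with the printed numbers; `ConfusionMetrics` is a hypothetical helper, not part of DL4J's `Evaluation` API.

```java
import java.util.Locale;

/** Recomputes accuracy and macro-averaged precision/recall from a confusion matrix. */
final class ConfusionMetrics {

    // Rows = true label, columns = predicted label; counts copied from the evaluation output above.
    static final int[][] MNIST_CONFUSION = {
        // 0     1     2    3    4    5    6     7    8    9   (predicted)
        {974,    1,    0,   0,   0,   0,   1,    2,   1,   1},  // labeled 0
        {  0, 1124,    4,   2,   0,   1,   2,    1,   1,   0},  // labeled 1
        {  2,    0, 1027,   0,   0,   0,   1,    2,   0,   0},  // labeled 2
        {  1,    0,    2, 999,   0,   3,   0,    2,   3,   0},  // labeled 3
        {  0,    0,    1,   0, 975,   0,   2,    0,   0,   4},  // labeled 4
        {  2,    0,    0,   5,   0, 878,   2,    1,   3,   1},  // labeled 5
        {  4,    2,    0,   0,   1,   5, 944,    0,   2,   0},  // labeled 6
        {  0,    4,    8,   1,   0,   0,   0, 1010,   1,   4},  // labeled 7
        {  4,    0,    3,   1,   0,   1,   0,    2, 959,   4},  // labeled 8
        {  0,    2,    1,   2,   1,   4,   0,    3,   2, 994},  // labeled 9
    };

    // Fraction of all examples that lie on the diagonal (correctly classified).
    static double accuracy(int[][] cm) {
        long correct = 0, total = 0;
        for (int i = 0; i < cm.length; i++) {
            for (int j = 0; j < cm.length; j++) {
                total += cm[i][j];
                if (i == j) correct += cm[i][j];
            }
        }
        return (double) correct / total;
    }

    // Mean over classes of TP / (number of examples predicted as that class, i.e. the column sum).
    // Assumes every class is predicted at least once, which holds here.
    static double macroPrecision(int[][] cm) {
        double sum = 0;
        for (int c = 0; c < cm.length; c++) {
            long predicted = 0;
            for (int[] row : cm) predicted += row[c];
            sum += (double) cm[c][c] / predicted;
        }
        return sum / cm.length;
    }

    // Mean over classes of TP / (number of examples truly in that class, i.e. the row sum).
    static double macroRecall(int[][] cm) {
        double sum = 0;
        for (int c = 0; c < cm.length; c++) {
            long actual = 0;
            for (int v : cm[c]) actual += v;
            sum += (double) cm[c][c] / actual;
        }
        return sum / cm.length;
    }

    public static void main(String[] args) {
        System.out.printf(Locale.ROOT, "Accuracy:  %.4f%n", accuracy(MNIST_CONFUSION));
        System.out.printf(Locale.ROOT, "Precision: %.4f%n", macroPrecision(MNIST_CONFUSION));
        System.out.printf(Locale.ROOT, "Recall:    %.4f%n", macroRecall(MNIST_CONFUSION));
    }
}
```

Running it prints 0.9884, 0.9884 and 0.9883, matching the Accuracy, Precision and Recall lines. The matrix also shows where the remaining errors concentrate, for instance 7 misread as 2 (8 times).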