深度学习Java类库deeplearning4j 学习笔记-MNIST手写数字分类问题
来源:互联网 发布:中国最新人口普查数据 编辑:程序博客网 时间:2024/04/29 07:16
deeplearning4j
这是一个用Java实现的深度学习类库。
网址: https://deeplearning4j.org
问题和数据集
Minist是一个每个学过机器学习的童鞋都熟悉的类库。这个数据集包含70,000个手写数字的图片。每张图片为28*28像素。其中包含60,000个训练数据和10,000个测试数据。图中给出了一些样例图片。
每个数据都包含一张图片,以及这张图片上的数字是几。我们希望得到这样一个工具,输入是一张图片,输出是识别出的这个图片的数字。
下面会用深度学习的方法对其进行训练和测试。
深度学习网络的结构
我们知道一个深度神经网络是由多个层构成的,这个案例中使用三层深度学习网络。输入层,隐含层(Hidden layer)和输出层。
输入层的输入为图片的原始像素数据,输入层的节点个数应该与输入数据的维度相关。在这个数据集中,每个图片是28*28的,所以输入层也就有28*28个节点。
输出层为数据的识别结果。因为手写输入有十个,所以输出层的结点个数应该为10个。
隐含层有多少个节点是由我们根据经验定义的,本例中定义为1000个。
使用DL4J实现这个类库
这个类库提供一种简便的方法来实现层的定义。它提供一个NeuralNetConfiguration.Builder类来配置整个神经网络,使用DenseLayer.Builder来配置每个层的信息。
上面说的三层神经网络,其实只有两层。 第一层的输入时原始数据,输出是隐含数据,第二层输入时隐含数据,输出是分类结果。
创建这个层的核心代码如下:
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngSeed) //include a random seed for reproducibility // use stochastic gradient descent as an optimization algorithm .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(0.006) //specify the learning rate .updater(Updater.NESTEROVS).momentum(0.9) //specify the rate of change of the learning rate. .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder() //create the first, input layer with xavier initialization .nIn(numRows * numColumns) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer .nIn(1000) .nOut(outputNum) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) //use backpropagation to adjust weights .build();
其中NeuralNetConfiguration.Builder提供很多方法来配置各种参数。
它使用seed函数配置随机数的种子。为什么要配置随机数的种子呢? 因为神经网络使用随机数来初始化每个参数的值,如果随机数种子不一样,那么初始的参数值就不确定,那么每一次执行得到的结果都可能有细微差别。设定了随机数的种子,就能丝毫不差的重复每次执行。(每次执行得到的结果完全相同),使得实验结构都是可验证的。
它使用optimizationAlgo函数指定该层使用的最优化算法,这里使用SGD梯度下降法。
iterations指定经过几次迭代,会将输出数据传递给下一层。
learningRate是学习率。
updater指定学习率的改变函数。
regularization这个函数实现规则化,防止国际和的出现。
list将上面的配置复制到每一层的配置中。
DenseLayer.Builder指定每一层的配置。这个例子中使用了2层。第一层输入为原始新昂素数据,输出为隐含数据。其输入节点个数为28*28,使用nIn函数来设定这个值,输出由nOut指定为1000个。
第二层输入为第一层的输出个数1000个,输出为10个。
activation指定激活函数 为RELU。
weightInit指定权重初始化方法。
build函数使用上面配置的信息构建一个层。
NeuralNetConfiguration.Builder的layer方法用来添加一个层。
第二个层是输出层,所以采用了SOFTMAX的激活函数。
pretrain设置预训练为不适用(false),设置backprop为使用。 最后的build根据上面的配置构建整个神经网络。
样例程序中的数据集
样例中给出了MnistDataSetIterator类用以提供数据。
//Get the DataSetIterators: DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
其中 batchSize为批次大小。为了能高效的进行训练,需要使用批次训练的方法。就是说每次训练时不适用所有数据,而是使用其中一小部分数据,下一次训练在才有第二批数据,以此类推。
第二个参数应该是指定是否为训练集。第三个参数是随机数种子。
作者和版权
作者 杨同峰 ,作者保留所有权利, 允许该文章自由转载,但请保留此版权信息。
cite: https://deeplearning4j.org/mnist-for-beginners.html
完整代码
package org.deeplearning4j.examples.feedforward.mnist;import org.nd4j.linalg.activations.Activation;import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;import org.deeplearning4j.eval.Evaluation;import org.deeplearning4j.nn.api.OptimizationAlgorithm;import org.deeplearning4j.nn.conf.MultiLayerConfiguration;import org.deeplearning4j.nn.conf.NeuralNetConfiguration;import org.deeplearning4j.nn.conf.Updater;import org.deeplearning4j.nn.conf.layers.DenseLayer;import org.deeplearning4j.nn.conf.layers.OutputLayer;import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;import org.deeplearning4j.nn.weights.WeightInit;import org.deeplearning4j.optimize.listeners.ScoreIterationListener;import org.nd4j.linalg.api.ndarray.INDArray;import org.nd4j.linalg.dataset.DataSet;import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;import org.slf4j.Logger;import org.slf4j.LoggerFactory;/**A Simple Multi Layered Perceptron (MLP) applied to digit classification for * the MNIST Dataset (http://yann.lecun.com/exdb/mnist/). * * This file builds one input layer and one hidden layer. * * The input layer has input dimension of numRows*numColumns where these variables indicate the * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit * (relu) activation function. The weights for this layer are initialized by using Xavier initialization * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/) * to avoid having a steep learning curve. This layer will have 1000 output signals to the hidden layer. * * The hidden layer has input dimensions of 1000. These are fed from the input layer. The weights * for this layer is also initialized using Xavier initialization. The activation function for this * layer is a softmax, which normalizes all the 10 outputs such that the normalized sums * add up to 1. The highest of these normalized values is picked as the predicted class. * */public class MLPMnistSingleLayerExample { private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class); public static void main(String[] args) throws Exception { //number of rows and columns in the input pictures final int numRows = 28; final int numColumns = 28; int outputNum = 10; // number of output classes int batchSize = 128; // batch size for each epoch int rngSeed = 123; // random number seed for reproducibility int numEpochs = 15; // number of epochs to perform //Get the DataSetIterators: DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed); DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed); log.info("Build model...."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(rngSeed) //include a random seed for reproducibility // use stochastic gradient descent as an optimization algorithm .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .iterations(1) .learningRate(0.006) //specify the learning rate .updater(Updater.NESTEROVS).momentum(0.9) //specify the rate of change of the learning rate. .regularization(true).l2(1e-4) .list() .layer(0, new DenseLayer.Builder() //create the first, input layer with xavier initialization .nIn(numRows * numColumns) .nOut(1000) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .build()) .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create hidden layer .nIn(1000) .nOut(outputNum) .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER) .build()) .pretrain(false).backprop(true) //use backpropagation to adjust weights .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); //print the score with every 1 iteration model.setListeners(new ScoreIterationListener(1)); log.info("Train model...."); for( int i=0; i<numEpochs; i++ ){ model.fit(mnistTrain); } log.info("Evaluate model...."); Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes while(mnistTest.hasNext()){ DataSet next = mnistTest.next(); INDArray output = model.output(next.getFeatureMatrix()); //get the networks prediction eval.eval(next.getLabels(), output); //check the prediction against the true class } log.info(eval.stats()); log.info("****************Example finished********************"); }}
- 深度学习Java类库deeplearning4j 学习笔记-MNIST手写数字分类问题
- 深度学习笔记5torch实现mnist手写数字识别
- 深度学习框架Caffe学习笔记(2)-MNIST手写数字识别例程
- 深度学习- 用Torch实现MNIST手写数字识别
- 深度学习笔记(四)用Torch实现MNIST手写数字识别
- 深度学习笔记(四)用Torch实现MNIST手写数字识别
- Tensorflow深度学习笔记(五)--手写数字识别-MNIST数据测试
- Tensorflow深度学习之八:再探CNN解决mnist手写数字识别问题
- caffe学习笔记4-- 手写数字mnist训练过程
- TensorFlow学习笔记(3)----CNN识别MNIST手写数字
- TensorFlow学习笔记(二)MNIST手写数字识别
- Caffe学习笔记(六):mnist手写数字识别训练实例
- 【深度学习】笔记2_caffe自带的第一个例子,Mnist手写数字识别代码,过程,网络详解
- TensorFlow学习---实现mnist手写数字识别
- 神经网络与深度学习 1.6 使用Python实现基于梯度下降算法的神经网络和MNIST数据集的手写数字分类程序
- 神经网络与深度学习 使用Python实现基于梯度下降算法的神经网络和自制仿MNIST数据集的手写数字分类可视化程序 web版本
- OpenCV机器学习:SVM分类器实现MNIST手写数字识别
- 深度学习第三天: LeNet在Python实现Mnist手写数字.md
- WPF架构分析
- Linux 内核文件系统关键数据结构
- KMP子字符串查找算法.java
- div中的内容垂直居中的五种方法
- Git版本控制使用方法入门教程
- 深度学习Java类库deeplearning4j 学习笔记-MNIST手写数字分类问题
- Hrbust 2315 Time ("科林明伦杯"哈理工第六届团队赛)
- Tolua使用笔记一:开始使用Tolua的准备工作与lua文件读取方法
- Android 调试中获取log
- redhat6.5安装pip问题及解决
- 五种常见的 PHP 设计模式
- php精度计算问题
- Android 内部存储与外部存储
- 怎样在jsp页面加载时首先执行某个js