deeplearning4j实现多层感知器的手写数字识别

来源:互联网 发布:劫的面具淘宝 编辑:程序博客网 时间:2024/05/20 23:35

package com.itcast.wang.test_dl;


import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.eval.Evaluation;  
import org.nd4j.linalg.api.ndarray.INDArray;
/**A Simple MLP applied to digit classification for MNIST.
 */
public class MLPMnistSingleLayerExample {
 
    private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);
 
    public static void main(String[] args) throws Exception {
 
        final int numRows = 28;//图像宽
        final int numColumns = 28;//图像长
        int outputNum = 10;//输出的类别数
        int batchSize = 128;//没128个样本参加训练
        int rngSeed = 123;//
        int numEpochs = 15;//训练集样本迭代的次数
 
        //Get the DataSetIterators:
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator mnist = new MnistDataSetIterator(batchSize, false, rngSeed);
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)//随机梯度下降
                .iterations(1)
                .learningRate(0.006)//学习率
                .updater(Updater.NESTEROVS).momentum(0.9)//运动惯量
                .regularization(true).l2(1e-4)//是否使用正则化
                .list()
                .layer(0, new DenseLayer.Builder()//第一层网络配置
                        .nIn(numRows * numColumns)//输入数目
                        .nOut(1000)//输出数目
                        .activation("relu")//激活函数 relu
                        .weightInit(WeightInit.XAVIER)//权值初始化
                        .build())
                //输出层指定误差函数
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)//误差函数
                        .nIn(1000)//输入
                        .nOut(outputNum)//输出
                        .activation("softmax")//激活函数
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .pretrain(false).backprop(true)
                .build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));
        log.info("Train model....");
        for(int i=0;i<4;i++){  
            model.fit(mnistTrain);  
            System.out.println(" Completed epoch is :" + i);  
  
                System.out.println("Evaluate model....");  
                Evaluation eval = new Evaluation(outputNum);  
                while(mnist.hasNext()){  
                    DataSet ds = mnist.next();  
                    INDArray output = model.output(ds.getFeatureMatrix(), false);  
                    eval.eval(ds.getLabels(), output);  
                }  
                System.out.println(eval.stats());  
                mnist.reset();  
              
        }  
        System.out.println("model finish");  
    }
    }
       
原创粉丝点击