deeplearning4j实现多感知器的手写数字识别
来源:互联网 发布:劫的面具淘宝 编辑:程序博客网 时间:2024/05/20 23:35
package com.itcast.wang.test_dl;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
/**A Simple MLP applied to digit classification for MNIST.
*/
public class MLPMnistSingleLayerExample {
private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);
public static void main(String[] args) throws Exception {
final int numRows = 28;//图像宽
final int numColumns = 28;//图像长
int outputNum = 10;//输出的类别数
int batchSize = 128;//没128个样本参加训练
int rngSeed = 123;//
int numEpochs = 15;//训练集样本迭代的次数
//Get the DataSetIterators:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnist = new MnistDataSetIterator(batchSize, false, rngSeed);
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)//随机梯度下降
.iterations(1)
.learningRate(0.006)//学习率
.updater(Updater.NESTEROVS).momentum(0.9)//运动惯量
.regularization(true).l2(1e-4)//是否使用正则化
.list()
.layer(0, new DenseLayer.Builder()//第一层网络配置
.nIn(numRows * numColumns)//输入数目
.nOut(1000)//输出数目
.activation("relu")//激活函数 relu
.weightInit(WeightInit.XAVIER)//权值初始化
.build())
//输出层指定误差函数
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)//误差函数
.nIn(1000)//输入
.nOut(outputNum)//输出
.activation("softmax")//激活函数
.weightInit(WeightInit.XAVIER)
.build())
.pretrain(false).backprop(true)
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
log.info("Train model....");
for(int i=0;i<4;i++){
model.fit(mnistTrain);
System.out.println(" Completed epoch is :" + i);
System.out.println("Evaluate model....");
Evaluation eval = new Evaluation(outputNum);
while(mnist.hasNext()){
DataSet ds = mnist.next();
INDArray output = model.output(ds.getFeatureMatrix(), false);
eval.eval(ds.getLabels(), output);
}
System.out.println(eval.stats());
mnist.reset();
}
System.out.println("model finish");
}
}
阅读全文
0 0
- deeplearning4j实现多感知器的手写数字识别
- 多层感知机实现mnist手写数字识别
- Deeplearning4j 实战(2):Deeplearning4j 手写体数字识别Spark实现
- 第一章 用神经网络识别手写数字(第一节 感知器)
- tensorflow实战之三:MNIST手写数字识别的优化2-多层感知器
- 手写数字识别实现
- knn算法实现的数字手写识别
- cnn 识别手写数字的实现
- TensorFlow实现识别手写数字
- cnn实现手写数字识别
- Deeplearning4j 实战(2):Deeplearning4j 手写体数字识别Spark实现【转】
- 深度学习Deeplearning4j 入门实战(2):Deeplearning4j 手写体数字识别Spark实现
- 【TensorFlow-windows】(三) 多层感知器进行手写数字识别(mnist)
- 手写数字识别的几种实现方法
- 基于K-近邻算法识别手写数字的实现
- 机器学习-kNN实现简单的手写数字识别系统
- Tensoflow+CNN实现简单的mnist手写数字识别
- 利用贝叶斯分类器实现手写数字识别
- ArrayBlockingQueue
- HDU 6191 01树合并
- centos 6.5 安装mysql-5.7.18-linux-glibc2.5-x86_64
- 17.8.31 日报
- git
- deeplearning4j实现多感知器的手写数字识别
- 年总结
- linux 修改最大文件数,进程数
- css选择器
- UISegmentedControl用法详解
- bug代码// 求最大公因数、最小公倍数
- 使用PrepareStatement
- [总结]年中总结
- Servlet生命周期 HttpServlet Mapping配置