deeplearning4j例程(一) CSVExample
来源:互联网 发布:备份通讯录的软件 编辑:程序博客网 时间:2024/06/06 09:02
这个例程比较简单,写这篇博客主要时为了做一些简单的记录,以防止后面遇到浪费不必要的时间。
这个例程包含读入CSV数据,对数据进行归一化处理,然后创建简单的神经网络,训练然后预测。
package org.deeplearning4j.examples.dataExamples;import org.datavec.api.records.reader.RecordReader;import org.datavec.api.records.reader.impl.csv.CSVRecordReader;import org.datavec.api.split.FileSplit;import org.datavec.api.util.ClassPathResource;import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;import org.deeplearning4j.eval.Evaluation;import org.deeplearning4j.nn.conf.MultiLayerConfiguration;import org.deeplearning4j.nn.conf.NeuralNetConfiguration;import org.deeplearning4j.nn.conf.layers.DenseLayer;import org.deeplearning4j.nn.conf.layers.OutputLayer;import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;import org.deeplearning4j.nn.weights.WeightInit;import org.deeplearning4j.optimize.listeners.ScoreIterationListener;import org.nd4j.linalg.activations.Activation;import org.nd4j.linalg.api.ndarray.INDArray;import org.nd4j.linalg.dataset.DataSet;import org.nd4j.linalg.dataset.SplitTestAndTrain;import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;import org.nd4j.linalg.lossfunctions.LossFunctions;import org.slf4j.Logger;import org.slf4j.LoggerFactory;/** * @author Adam Gibson */public class CSVExample { private static Logger log = LoggerFactory.getLogger(CSVExample.class); 创建log,便于打印日志 public static void main(String[] args) throws Exception { //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing int numLinesToSkip = 0; 有些文件具有表头,有些没有。即读取文件时需要跳过的行数 String delimiter = ","; 数据之间的分隔符 RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); 文件读取器 recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); 从磁盘读取文件 //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network//5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each rowint labelIndex = 4; //label所在的位置,//3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2int numClasses = 3; 分多少类//Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)int batchSize = 150;数据共有多少条?还是要批处理的数量? //将数据存入迭代器,参数分别为:读取器 批处理的量 label的位置 分多少类 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next(); 将数据转为DataSet格式 allData.shuffle(); 混洗,打乱数据 //分成训练集和测试集 SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training DataSet trainingData = testAndTrain.getTrain(); 获得训练集 DataSet testData = testAndTrain.getTest(); 获得测试集 System.out.println("allData = "+allData.numExamples()+" train = "+trainingData.numExamples()); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); 对数据进行归一化//Collect the statistics (mean/stdev) from the training data. This does not modify the input datanormalizer.fit(trainingData); 计算训练集的均值和方差
normalizer.transform(trainingData); 对训练集进行归一化
normalizer.transform(testData); 利用训练集的数据对测试集进行归一化
final int numInputs = 4; 输入数据的维度
int outputNum = 3; 分类的个数
int iterations = 1000; 迭代次数
long seed = 6; 随机数
log.info("Build model...."); 配置网络结构
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed) .iterations(iterations) .activation(Activation.TANH) 激活函数为双曲正切
.weightInit(WeightInit.XAVIER) 权重初始化
.learningRate(0.1) 学习率
.regularization(true).l2(1e-4) l2正则化
.list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) 第一层输入为4个节点,输出为3个
.build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) 输入为3个输出为3个
.build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX) 激活函数为softmax
.nIn(3).nOut(outputNum).build()) .backprop(true).pretrain(false) 反向传播
.build();
//run the model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init();
model.setListeners(new ScoreIterationListener(100)); 每迭代100次,输出一次日志
model.fit(trainingData); 开始训练 //evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatureMatrix()); 获得输入数据的特征值,并计算预测值
eval.eval(testData.getLabels(), output); 评估原始label与预测的predict
log.info(eval.stats()); 打印日志 }}
结果如下:
如有问题,请批评指正。谢谢
0 0
- deeplearning4j例程(一) CSVExample
- Deeplearning4j例程(二) 加载本地模型预测未知图像
- 学习DeepLearning4J(一、把example跑起来)
- deeplearning4j
- Deeplearning4j
- Deeplearning4j 实战(1):Deeplearning4j 手写体数字识别
- Deeplearning4j 实战(2):Deeplearning4j 手写体数字识别Spark实现
- Deeplearning4j 实战(1):Deeplearning4j 手写体数字识别【转】
- Log4J入门教程(一) 入门例程
- java例程练习(一维数组)
- IRP_完成例程(一)-返回status_success
- Log4J入门教程(一) 入门例程
- Log4J入门教程(一) 入门例程
- Creat_average_shape_model.hdev例程相关学习(一)
- 实际项目使用例程(一)
- halcon例程 -- 逐字细究(一)
- caffe学习笔记(一):MNIST例程
- Log4J入门教程(一) 入门例程
- 用OC实现一个类似java的事件监听机制
- linux内核驱动之修改wifi驱动
- Android添加一个开机完成后执行的脚本
- python os.system() 空格处理
- SSM框架搭建及源码解析--spring的BeanFactoryPostProcessor扩展(三)
- deeplearning4j例程(一) CSVExample
- css3鼠标滑过图片放大
- PHP合并两个有序数组
- Android Studio 中DDMS无法显示文件树以及data文件夹中文件无法导出解决方法
- css3写下雨效果
- [P1209]修理牛棚
- MyEclipse将项目部署到tomcat服务器上
- hibernate一对多,多对一详说
- JAVA基础知识——字符串