deeplearning4j例程(一) CSVExample

来源:互联网 发布:备份通讯录的软件 编辑:程序博客网 时间:2024/06/06 09:02

        这个例程比较简单,写这篇博客主要时为了做一些简单的记录,以防止后面遇到浪费不必要的时间。

这个例程包含读入CSV数据,对数据进行归一化处理,然后创建简单的神经网络,训练然后预测。 

package org.deeplearning4j.examples.dataExamples;import org.datavec.api.records.reader.RecordReader;import org.datavec.api.records.reader.impl.csv.CSVRecordReader;import org.datavec.api.split.FileSplit;import org.datavec.api.util.ClassPathResource;import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;import org.deeplearning4j.eval.Evaluation;import org.deeplearning4j.nn.conf.MultiLayerConfiguration;import org.deeplearning4j.nn.conf.NeuralNetConfiguration;import org.deeplearning4j.nn.conf.layers.DenseLayer;import org.deeplearning4j.nn.conf.layers.OutputLayer;import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;import org.deeplearning4j.nn.weights.WeightInit;import org.deeplearning4j.optimize.listeners.ScoreIterationListener;import org.nd4j.linalg.activations.Activation;import org.nd4j.linalg.api.ndarray.INDArray;import org.nd4j.linalg.dataset.DataSet;import org.nd4j.linalg.dataset.SplitTestAndTrain;import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;import org.nd4j.linalg.lossfunctions.LossFunctions;import org.slf4j.Logger;import org.slf4j.LoggerFactory;/** * @author Adam Gibson */public class CSVExample {    private static Logger log = LoggerFactory.getLogger(CSVExample.class); 创建log,便于打印日志    public static void main(String[] args) throws  Exception {        //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing        int numLinesToSkip = 0;  有些文件具有表头,有些没有。即读取文件时需要跳过的行数        String delimiter = ",";       数据之间的分隔符        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);    文件读取器        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));      从磁盘读取文件        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network       
        //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
  int labelIndex = 4; //label所在的位置,
        //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
int numClasses = 3; 分多少类
        //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
int batchSize = 150;数据共有多少条?还是要批处理的数量? //将数据存入迭代器,参数分别为:读取器 批处理的量 label的位置 分多少类 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next(); 将数据转为DataSet格式 allData.shuffle(); 混洗,打乱数据 //分成训练集和测试集 SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training DataSet trainingData = testAndTrain.getTrain(); 获得训练集 DataSet testData = testAndTrain.getTest(); 获得测试集 System.out.println("allData = "+allData.numExamples()+" train = "+trainingData.numExamples()); //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance): DataNormalization normalizer = new NormalizerStandardize(); 对数据进行归一化
        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
normalizer.fit(trainingData); 计算训练集的均值和方差
normalizer.transform(trainingData); 对训练集进行归一化
normalizer.transform(testData); 利用训练集的数据对测试集进行归一化
final int numInputs = 4; 输入数据的维度
int outputNum = 3; 分类的个数
 int iterations = 1000; 迭代次数
long seed = 6; 随机数
log.info("Build model...."); 配置网络结构
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed) .iterations(iterations) .activation(Activation.TANH) 激活函数为双曲正切
.weightInit(WeightInit.XAVIER) 权重初始化
.learningRate(0.1) 学习率
.regularization(true).l2(1e-4) l2正则化
.list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) 第一层输入为4个节点,输出为3个
.build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) 输入为3个输出为3个
.build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX) 激活函数为softmax
.nIn(3).nOut(outputNum).build()) .backprop(true).pretrain(false) 反向传播
.build();
//run the model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init();
model.setListeners(new ScoreIterationListener(100)); 每迭代100次,输出一次日志
model.fit(trainingData); 开始训练 //evaluate the model on the test set
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatureMatrix()); 获得输入数据的特征值,并计算预测值
eval.eval(testData.getLabels(), output); 评估原始label与预测的predict
log.info(eval.stats()); 打印日志 }}





结果如下:



如有问题,请批评指正。谢谢

0 0