The Structure of a DL4J Function-Fitting Program

Source: Internet · Editor: 程序博客网 · Date: 2024/05/22 06:19

The following uses function fitting as an example to illustrate the structure of a DL4J program. Reference source code:

org.deeplearning4j.examples.feedforward.regression.RegressionMathFunctions

1. Generating the Data

1.1 The Independent Variable

```java
// Generate a one-dimensional vector of nSamples values in the interval [-10, 10].
// nSamples is the number of samples; the official example defaults to 1000.
final INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1);
```
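To make concrete what evenly spaced values `linspace` produces and why the result is reshaped into an nSamples x 1 column (one feature per example), here is a plain-Java sketch of the same idea. The `linspace` helper below is our own stand-in for illustration, not an ND4J call:

```java
import java.util.Arrays;

public class LinspaceSketch {
    // Stand-in for Nd4j.linspace: n evenly spaced values covering [lo, hi].
    static double[] linspace(double lo, double hi, int n) {
        double[] out = new double[n];
        double step = (hi - lo) / (n - 1);
        for (int i = 0; i < n; i++) {
            out[i] = lo + i * step;
        }
        return out;
    }

    public static void main(String[] args) {
        // With n = 5 the step is 5.0; ND4J would then reshape this
        // row of values into a 5 x 1 column vector.
        double[] x = linspace(-10, 10, 5);
        System.out.println(Arrays.toString(x));  // [-10.0, -5.0, 0.0, 5.0, 10.0]
    }
}
```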

1.2 The Dependent Variable

```java
// Compute sin(x): fn = sin(x)
final DataSetIterator iterator = getTrainingData(x, fn, batchSize, rng);

// The helper getTrainingData() is defined as follows:
/** Create a DataSetIterator for training
 * @param x X values
 * @param function Function to evaluate
 * @param batchSize Batch size (number of examples for every call of DataSetIterator.next())
 * @param rng Random number generator (for repeatability)
 */
private static DataSetIterator getTrainingData(final INDArray x, final MathFunction function,
        final int batchSize, final Random rng) {
    final INDArray y = function.getFunctionValues(x);
    final DataSet allData = new DataSet(x, y);
    final List<DataSet> list = allData.asList();
    Collections.shuffle(list, rng);
    return new ListDataSetIterator(list, batchSize);
}
```
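The helper does two things: it shuffles the examples with a seeded `Random` for repeatability, then hands them to a `ListDataSetIterator` that serves them in mini-batches. A plain-Java sketch of that shuffle-then-partition pattern, using integers as hypothetical stand-ins for `DataSet` rows:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class MiniBatchSketch {
    public static void main(String[] args) {
        // Hypothetical stand-ins for the example's DataSet rows.
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            data.add(i);
        }

        // Seeded shuffle, as in Collections.shuffle(list, rng):
        // the same seed always yields the same order.
        Collections.shuffle(data, new Random(12345));

        // Partition into batches of batchSize, roughly what
        // ListDataSetIterator does on each next() call.
        int batchSize = 4;
        List<List<Integer>> batches = new ArrayList<>();
        for (int i = 0; i < data.size(); i += batchSize) {
            batches.add(data.subList(i, Math.min(i + batchSize, data.size())));
        }

        System.out.println(batches.size());  // 3 batches from 10 examples
        for (List<Integer> b : batches) {
            System.out.println(b.size());    // sizes 4, 4, and a final 2
        }
    }
}
```

Note the last batch is smaller when the sample count is not a multiple of the batch size.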

2. Configuring the Neural Network

```java
// In the main method, create the multilayer network:
final MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(1));

// The configuration conf is built as follows:
private static MultiLayerConfiguration getDeepDenseLayerNetworkConfiguration() {
    final int numHiddenNodes = 50;
    return new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(learningRate)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .list()
        .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
            .activation(Activation.TANH).build())
        .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
            .activation(Activation.TANH).build())
        .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
            .activation(Activation.IDENTITY)
            .nIn(numHiddenNodes).nOut(numOutputs).build())
        .pretrain(false).backprop(true).build();
}
```

3. Training and Testing

```java
// In the main method, run the training loop:
final INDArray[] networkPredictions = new INDArray[nEpochs / plotFrequency];
for (int i = 0; i < nEpochs; i++) {
    iterator.reset();
    net.fit(iterator);
    // Record the network's predictions every plotFrequency epochs.
    if ((i + 1) % plotFrequency == 0) {
        networkPredictions[i / plotFrequency] = net.output(x, false);
    }
}
```
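The snapshot bookkeeping in this loop relies on integer division: the array has `nEpochs / plotFrequency` slots, `(i + 1) % plotFrequency == 0` fires on every plotFrequency-th epoch, and `i / plotFrequency` maps those epochs to consecutive slots. A small sketch with made-up values (`nEpochs = 10`, `plotFrequency = 3`) that traces the same arithmetic:

```java
import java.util.Arrays;

public class PlotIndexSketch {
    public static void main(String[] args) {
        int nEpochs = 10, plotFrequency = 3;
        // Integer division: only floor(10 / 3) = 3 snapshots fit.
        String[] slots = new String[nEpochs / plotFrequency];
        for (int i = 0; i < nEpochs; i++) {
            // Fires at i = 2, 5, 8, i.e. after epochs 3, 6, 9.
            if ((i + 1) % plotFrequency == 0) {
                slots[i / plotFrequency] = "epoch " + (i + 1);
            }
        }
        System.out.println(Arrays.toString(slots));  // [epoch 3, epoch 6, epoch 9]
    }
}
```

Note that when `plotFrequency` does not divide `nEpochs` evenly (as here), the trailing epochs simply produce no snapshot; the last recorded prediction is from epoch 9, not epoch 10.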

4. Plotting the Results

```java
// In the main method, execute:
plot(fn, x, fn.getFunctionValues(x), networkPredictions);

// The plotting helper plot() is defined as follows:
private static void plot(final MathFunction function, final INDArray x, final INDArray y,
        final INDArray... predicted) {
    final XYSeriesCollection dataSet = new XYSeriesCollection();
    addSeries(dataSet, x, y, "True Function (Labels)");
    for (int i = 0; i < predicted.length; i++) {
        addSeries(dataSet, x, predicted[i], String.valueOf(i));
    }
    final JFreeChart chart = ChartFactory.createXYLineChart(
        "Regression Example - " + function.getName(),  // chart title
        "X",                                           // x axis label
        function.getName() + "(X)",                    // y axis label
        dataSet,                                       // data
        PlotOrientation.VERTICAL,
        true,                                          // include legend
        true,                                          // tooltips
        false                                          // urls
    );
    final ChartPanel panel = new ChartPanel(chart);
    final JFrame f = new JFrame();
    f.add(panel);
    f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    f.pack();
    f.setVisible(true);
}
```
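The chart compares the true labels against each recorded prediction visually; if you also want a numeric measure of fit quality, you can compute the mean squared error, which is the same quantity `LossFunction.MSE` minimizes during training. A minimal plain-Java sketch (the helper name `mse` and the toy values are our own):

```java
public class MseSketch {
    // MSE(y, yHat) = (1/n) * sum_i (y[i] - yHat[i])^2
    static double mse(double[] y, double[] yHat) {
        double s = 0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - yHat[i];
            s += d * d;
        }
        return s / y.length;
    }

    public static void main(String[] args) {
        // Toy labels vs. predictions: one prediction is off by 1.
        double[] y    = {0.0, 1.0, 2.0};
        double[] yHat = {0.0, 1.0, 1.0};
        System.out.println(mse(y, yHat));  // (0 + 0 + 1) / 3
    }
}
```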

5. Questions and Discussion

The following topics will be covered one by one in subsequent articles:
1. Matrix computation and vectorization in the ND4J framework
2. Multilayer neural network structure and its parameters
3. Strategies for training and prediction
