DeepLearning4J入门——使用LSTM进行大盘回归

来源:互联网 发布:淘宝店铺pc端装修视频 编辑:程序博客网 时间:2024/06/05 15:53

       LSTM是递归神经网络(RNN)的一个变种,相较于RNN而言,解决了记忆消失的问题,用来处理序列问题是一个很好的选择。本文主要介绍如何使用DL4J中的LSTM来执行回归分析。如果不清楚RNN和LSTM,可以先阅读 LSTM和递归网络教程 以及 通过DL4J使用递归网络 ,特别是不熟悉RNN输入和预测方式的强烈建议先阅读这两个教程。如果不太会建立DL4J的工程,建议在其样例工程中进行本实验。

      言归正传,文本通过使用 LSTM对上证指数历史数据进行回归学习,并给出一个初始序列预测之后20天的大盘收盘价格来演示如何使用LSTM处理简单的序列回归问题。首先是准备数据,可以下载例子中我使用的数据集。那么接下来的问题就分成如下几步:

1. 读入训练数据,并处理成一个DataIterator;

2. 构建一个LSTM的递归神经网络;

3. 迭代训练,并输出预测结果;

4. 调参和优化。


一.处理训练数据

我们的数据是上证指数每个交易日的基本数据,格式为:

股票代码 日期开盘价 收盘价最高价  最低价成交量  成交额涨跌幅

        这个文件中的数据是倒序的,也就是说新的数据在最前面,因此在读取数据时需要做一次倒转。我将读取文件的方法放在Dataiterator中。DL4J给出了序列数据处理的DataIterator,但是在本例中我们是自己实现一个DataIterator。代码如下:

package edu.zju.cst.krselee.example.stock;import org.deeplearning4j.datasets.iterator.DataSetIterator;import org.nd4j.linalg.api.ndarray.INDArray;import org.nd4j.linalg.dataset.DataSet;import org.nd4j.linalg.dataset.api.DataSetPreProcessor;import org.nd4j.linalg.factory.Nd4j;import java.io.BufferedReader;import java.io.FileInputStream;import java.io.IOException;import java.io.InputStreamReader;import java.util.ArrayList;import java.util.Collections;import java.util.List;import java.util.NoSuchElementException;/** * Created by kexi.lkx on 2016/8/23. */public class StockDataIterator  implements DataSetIterator {    private static final int VECTOR_SIZE = 6;    //每批次的训练数据组数    private int batchNum;    //每组训练数据长度(DailyData的个数)    private int exampleLength;    //数据集    private List<DailyData> dataList;    //存放剩余数据组的index信息    private List<Integer> dataRecord;    private double[] maxNum;    /**     * 构造方法     * */    public StockDataIterator(){        dataRecord = new ArrayList<>();    }    /**     * 加载数据并初始化     * */    public boolean loadData(String fileName, int batchNum, int exampleLength){        this.batchNum = batchNum;        this.exampleLength = exampleLength;        maxNum = new double[6];        //加载文件中的股票数据        try {            readDataFromFile(fileName);        }catch (Exception e){            e.printStackTrace();            return false;        }        //重置训练批次列表        resetDataRecord();        return true;    }    /**     * 重置训练批次列表     * */    private void resetDataRecord(){        dataRecord.clear();        int total = dataList.size()/exampleLength+1;        for( int i=0; i<total; i++ ){            dataRecord.add(i * exampleLength);        }    }    /**     * 从文件中读取股票数据     * */    public List<DailyData> readDataFromFile(String fileName) throws IOException{        dataList = new ArrayList<>();        FileInputStream fis = new FileInputStream(fileName);        BufferedReader in = new BufferedReader(new InputStreamReader(fis,"UTF-8"));        String line = in.readLine();        for(int i=0;i<maxNum.length;i++){            maxNum[i] = 0;        }        System.out.println("读取数据..");        while(line!=null){            String[] strArr = line.split(",");            if(strArr.length>=7) {                DailyData data = new DailyData();                //获得最大值信息,用于归一化                double[] nums = new double[6];                for(int j=0;j<6;j++){                    nums[j] = Double.valueOf(strArr[j+2]);                    if( nums[j]>maxNum[j] ){                        maxNum[j] = nums[j];                    }                }                //构造data对象                data.setOpenPrice(Double.valueOf(nums[0]));                data.setCloseprice(Double.valueOf(nums[1]));                data.setMaxPrice(Double.valueOf(nums[2]));                data.setMinPrice(Double.valueOf(nums[3]));                data.setTurnover(Double.valueOf(nums[4]));                data.setVolume(Double.valueOf(nums[5]));                dataList.add(data);            }            line = in.readLine();        }        in.close();        fis.close();        System.out.println("反转list...");        Collections.reverse(dataList);        return dataList;    }    public double[] getMaxArr(){        return this.maxNum;    }    public void reset(){        resetDataRecord();    }    public boolean hasNext(){        return dataRecord.size() > 0;    }    public DataSet next(){        return next(batchNum);    }    /**     * 获得接下来一次的训练数据集     * */    public DataSet next(int num){        if( dataRecord.size() <= 0 ) {            throw new NoSuchElementException();        }        int actualBatchSize = Math.min(num, dataRecord.size());        int actualLength = Math.min(exampleLength,dataList.size()-dataRecord.get(0)-1);        INDArray input = Nd4j.create(new int[]{actualBatchSize,VECTOR_SIZE,actualLength}, 'f');        INDArray label = Nd4j.create(new int[]{actualBatchSize,1,actualLength}, 'f');        DailyData nextData = null,curData = null;        //获取每批次的训练数据和标签数据        for(int i=0;i<actualBatchSize;i++){            int index = dataRecord.remove(0);            int endIndex = Math.min(index+exampleLength,dataList.size()-1);            curData = dataList.get(index);            for(int j=index;j<endIndex;j++){                //获取数据信息                nextData = dataList.get(j+1);                //构造训练向量                int c = endIndex-j-1;                input.putScalar(new int[]{i, 0, c}, curData.getOpenPrice()/maxNum[0]);                input.putScalar(new int[]{i, 1, c}, curData.getCloseprice()/maxNum[1]);                input.putScalar(new int[]{i, 2, c}, curData.getMaxPrice()/maxNum[2]);                input.putScalar(new int[]{i, 3, c}, curData.getMinPrice()/maxNum[3]);                input.putScalar(new int[]{i, 4, c}, curData.getTurnover()/maxNum[4]);                input.putScalar(new int[]{i, 5, c}, curData.getVolume()/maxNum[5]);                //构造label向量                label.putScalar(new int[]{i, 0, c}, nextData.getCloseprice()/maxNum[1]);                curData = nextData;            }            if(dataRecord.size()<=0) {                break;            }        }        return new DataSet(input, label);    }    public int batch() {        return batchNum;    }    public int cursor() {        return totalExamples() - dataRecord.size();    }    public int numExamples() {        return totalExamples();    }    public void setPreProcessor(DataSetPreProcessor preProcessor) {        throw new UnsupportedOperationException("Not implemented");    }    public int totalExamples() {        return (dataList.size()) / exampleLength;    }    public int inputColumns() {        return dataList.size();    }    public int totalOutcomes() {        return 1;    }    @Override    public List<String> getLabels() {        throw new UnsupportedOperationException("Not implemented");    }    @Override    public void remove() {        throw new UnsupportedOperationException();    }}
      StockDataIterator实现了DataIterator接口,于是需要实现几个必须的方法,例如hasNext、next、reset……用来进行每一批次DataSet的获取,loadData和readDataFromFile用来获取数据,并保存在一个DailyData类型的List中,每次调用next方法时,就会从List取出当前需要的数据,并构造成DataSet,返回给调用者。DailyData的实现如下:
package edu.zju.cst.krselee.example.stock;/** * Created by kexi.lkx on 2016/8/23. */public class DailyData {    //开盘价    private double openPrice;    //收盘价    private double closeprice;    //最高价    private double maxPrice;    //最低价    private double minPrice;    //成交量    private double turnover;    //成交额    private double volume;    public double getTurnover() {        return turnover;    }    public double getVolume() {        return volume;    }    public DailyData(){    }    public double getOpenPrice() {        return openPrice;    }    public double getCloseprice() {        return closeprice;    }    public double getMaxPrice() {        return maxPrice;    }    public double getMinPrice() {        return minPrice;    }    public void setOpenPrice(double openPrice) {        this.openPrice = openPrice;    }    public void setCloseprice(double closeprice) {        this.closeprice = closeprice;    }    public void setMaxPrice(double maxPrice) {        this.maxPrice = maxPrice;    }    public void setMinPrice(double minPrice) {        this.minPrice = minPrice;    }    public void setTurnover(double turnover) {        this.turnover = turnover;    }    public void setVolume(double volume) {        this.volume = volume;    }    @Override    public String toString(){        StringBuilder builder = new StringBuilder();        builder.append("开盘价="+this.openPrice+", ");        builder.append("收盘价="+this.closeprice+", ");        builder.append("最高价="+this.maxPrice+", ");        builder.append("最低价="+this.minPrice+", ");        builder.append("成交量="+this.turnover+", ");        builder.append("成交额="+this.volume);        return builder.toString();    }}
代码中对数据的各个维度进行了归一化处理,方法是记录每个维度的最大值,构造特征向量与标签时用原始数值除以最大值,得到0-1之间的数,归一化的好处在于使训练过程收敛变快,读者也可以试试不归一化的情况,比较两者的差别。

二.构建LSTM网络

       本例中我构造了一个两个隐含层的LSTM网络,隐含层激活函数是tanh,输出层使用identity函数来执行回归。输入单元数为6,因为单个向量是6维的(开盘价、收盘价、最高价、最低价、成交量、成交额);输出单元数为1,用于预测第二天收盘价,代码如下:
    private static final int IN_NUM = 6;    private static final int OUT_NUM = 1;    private static final int Epochs = 100;    private static final int lstmLayer1Size = 50;    private static final int lstmLayer2Size = 100;    public static MultiLayerNetwork getNetModel(int nIn,int nOut){        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)            .learningRate(0.1)            .rmsDecay(0.5)            .seed(12345)            .regularization(true)            .l2(0.001)            .weightInit(WeightInit.XAVIER)            .updater(Updater.RMSPROP)            .list()            .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayer1Size)                .activation("tanh").build())            .layer(1, new GravesLSTM.Builder().nIn(lstmLayer1Size).nOut(lstmLayer2Size)                .activation("tanh").build())            .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation("identity")                .nIn(lstmLayer2Size).nOut(nOut).build())            .pretrain(false).backprop(true)            .build();        MultiLayerNetwork net = new MultiLayerNetwork(conf);        net.init();        net.setListeners(new ScoreIterationListener(1));        return net;    }

这段代码中有很多参数可以进行调整来寻找最优的拟合效果或调整训练速率,比如隐含层单元数目、激活函数、学习速率、正则化因子……构造好网络后加入一个ScoreIterationListener来监听每次迭代训练后的得分。

三.执行迭代训练

第二部分里面我们设置了完整训练集的迭代次数Epochs为100,表示用整个数据集反复训练100次,训练部分代码如下:
 public static void train(MultiLayerNetwork net,StockDataIterator iterator){        //迭代训练        for(int i=0;i<Epochs;i++) {            DataSet dataSet = null;            while (iterator.hasNext()) {                dataSet = iterator.next();                net.fit(dataSet);            }            iterator.reset();            System.out.println();            System.out.println("=================>完成第"+i+"次完整训练");            INDArray initArray = getInitArray(iterator);            System.out.println("预测结果:");            for(int j=0;j<20;j++) {                INDArray output = net.rnnTimeStep(initArray);                System.out.print(output.getDouble(0)*iterator.getMaxArr()[1]+" ");            }            System.out.println();            net.rnnClearPreviousState();        }    }    private static INDArray getInitArray(StockDataIterator iter){        double[] maxNums = iter.getMaxArr();        INDArray initArray = Nd4j.zeros(1, 6, 1);        initArray.putScalar(new int[]{0,0,0}, 3433.85/maxNums[0]);        initArray.putScalar(new int[]{0,1,0}, 3445.41/maxNums[1]);        initArray.putScalar(new int[]{0,2,0}, 3327.81/maxNums[2]);        initArray.putScalar(new int[]{0,3,0}, 3470.37/maxNums[3]);        initArray.putScalar(new int[]{0,4,0}, 304197903.0/maxNums[4]);        initArray.putScalar(new int[]{0,5,0}, 3.8750365e+11/maxNums[5]);        return initArray;    }
      每当进行一次完整集的训练之后,我们初始化了一个初始序列进行预测之后20个序列的输出。整个程序主函数如下:

    public static void main(String[] args) {        String inputFile = StockRnnPredict.class.getClassLoader().getResource("stock/sh000001.csv").getPath();        int batchSize = 1;        int exampleLength = 30;        //初始化深度神经网络        StockDataIterator iterator = new StockDataIterator();        iterator.loadData(inputFile,batchSize,exampleLength);        MultiLayerNetwork net = getNetModel(IN_NUM,OUT_NUM);        train(net, iterator);    }


      迭代100次后,得到的输出序列如下:

3489.9679512619973 3516.991701169014 3510.4443733012677 3490.410951650143 3476.138713735342 3469.275475754738 3466.278687063456 3464.9017547094822 3464.2161934530736 3463.8574357616903 3463.670068384409 3463.582194536925 3463.5545977914335 3463.5658543586733 3463.6010765206815 3463.650460170508 3463.7067430067063 3463.764115188122 3463.8196717941764 3463.8705079042916 

四.进一步优化

本文主要介绍LSTM的使用方法,并不是真的说如此就能准确的预测大盘走势了(当然,网络是有可能真的学习到一些大盘走势特征进行),想要做到这一点需要对本例进行许多调整,比如获取更全面的每日大盘信息,选取更多合适的维度来构建特征向量,当然也可以调整预测值,不仅仅是预测收盘价而已。另外可以调整在第二部分提到的那些参数。





1 0
原创粉丝点击