Java实现LSTM和GRU做分类(以IRIS数据集为例)

来源:互联网 发布:vb取整函数函数 编辑:程序博客网 时间:2024/06/08 14:03

笔者想在JAVA项目中做机器学习的分类想使用循环神经网络的时候苦于没有找到开源的代码,最后终于找到lipiji所写的LSTM和GRU,项目GitHub链接在这:项目GitHub地址,但是这个项目的demo只是简单的做了一个文本序列的预测,无法达到自己做分类的目的,于是笔者新写了一个demo来实现分类的目的,这里所使用的数据集是Iris。Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。

数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:

package com.lipiji.mllib.rnn.gru;import com.lipiji.mllib.layers.MatIniter;import com.lipiji.mllib.rnn.lstm.Cell;import com.lipiji.mllib.rnn.lstm.LSTM;import com.lipiji.mllib.utils.LossFunction;import org.jblas.DoubleMatrix;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.*;public class gruTest {    public static double train_x[][] = new double[105][4];    public static double test_x[][] = new double[45][4];    public static double train_y[][] = new double[105][3];    public static double test_y[][] = new double[45][3];    private static GRU gru;    public static void main(String[] args) {        loadData();        int hiddenSize = 4;//隐含层数量        double lr = 0.1;        gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层        for (int i = 0; i < 2000; i++) {//迭代2000次            double error = 0;            double num = 0;            double start = System.currentTimeMillis();            Map<String, DoubleMatrix> acts = new HashMap<>();            for (int s = 0; s < train_x.length; s++) {                double newx[][] = new double[1][4];                newx[0] = train_x[s];                DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵                //System.out.println(xt.getColumns()+" "+xt.getRows());                acts.put("x" + s, xt);                gru.active(s, acts);                DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));                acts.put("py" + s, predcitYt);                double newy[][] = new double[1][3];                newy[0] = train_y[s];                DoubleMatrix trueYt = new DoubleMatrix(newy);                acts.put("y" + s, trueYt);                if(predcitYt.argmax()!=trueYt.argmax())                    error++;                // bptt                num ++;            }            gru.bptt(acts, train_x.length-1, lr);            System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");        }//结束迭代        //开始测试        int num = 0,error = 0;        Map<String, DoubleMatrix> acts = new HashMap<>();        for(int s = 0; s<test_x.length;s++){            double newx[][] = new double[1][4];            newx[0] = test_x[s];            DoubleMatrix xt = new DoubleMatrix(newx);            acts.put("x" + s, xt);            gru.active(s, acts);            DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));            acts.put("py" + s, predcitYt);            double newy[][] = new double[1][3];            newy[0] = test_y[s];            DoubleMatrix trueYt = new DoubleMatrix(newy);            acts.put("y" + s, trueYt);            if(predcitYt.argmax()!=trueYt.argmax())                error++;            // bptt            num ++;        }        System.out.println("错误数:"+error+"/"+num);    }    public static void loadData(){        List<String> list = readFileForList("data/train.txt");//训练集        for(int i = 0;i<list.size();i++){            String str[] = list.get(i).split(",");            for(int k = 0 ; k < 4;k++)                train_x[i][k]=Double.valueOf(str[k]);            train_y[i][Integer.valueOf(str[4])] = 1;//将所属类别设置为1        }        list = readFileForList("data/test.txt");//测试集        for(int i = 0;i<list.size();i++){            String str[] = list.get(i).split(",");            for(int k = 0 ; k < 4;k++)                test_x[i][k]=Double.valueOf(str[k]);            test_y[i][Integer.valueOf(str[4])] = 1;        }    }    public static List<String> readFileForList(String fileName) {//读取文件到list        File file = new File(fileName);        BufferedReader reader = null;        List<String> s = new ArrayList<String>();        try {            reader = new BufferedReader(new FileReader(file));            String tempString = null;            while ((tempString = reader.readLine()) != null) {                // 显示行号                s.add(tempString);            }            reader.close();        } catch (IOException e) {            e.printStackTrace();        } finally {            if (reader != null) {                try {                    reader.close();                } catch (IOException e1) {                }            }            return s;        }    }}


实验最终结果如下,可以看到45个测试集对了44个,笔者这里怎么调都无法达到完全的准确率,希望有做出来的可以告知一下,感谢。最后感谢lipiji提供的算法代码。


1 0
原创粉丝点击