Java实现LSTM和GRU做分类(以IRIS数据集为例)
来源:互联网 发布:vb取整函数函数 编辑:程序博客网 时间:2024/06/08 14:03
笔者想在JAVA项目中做机器学习的分类想使用循环神经网络的时候苦于没有找到开源的代码,最后终于找到lipiji所写的LSTM和GRU,项目GitHub链接在这:项目GitHub地址,但是这个项目的demo只是简单的做了一个文本序列的预测,无法达到自己做分类的目的,于是笔者新写了一个demo来实现分类的目的,这里所使用的数据集是Iris。Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理。Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性。可通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。(来源:百度百科)点击下载Iris数据集 没有积分的也可以自己去找不需要积分的数据集。
数据预处理:首先将数据集里的花的类别修改成0,1,2三类,然后将每类中取15条数据共45条做测试集,余下105个做训练集分别存在两个文件中。新写一个类放在com.lipiji.mllib.rnn.gru包中,这里的输出层有三个节点,代表三个类别。笔者这里采用的GRU实验,要做LSTM的话将GRU类改成Cell类即可。测试代码如下:
package com.lipiji.mllib.rnn.gru;import com.lipiji.mllib.layers.MatIniter;import com.lipiji.mllib.rnn.lstm.Cell;import com.lipiji.mllib.rnn.lstm.LSTM;import com.lipiji.mllib.utils.LossFunction;import org.jblas.DoubleMatrix;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.*;public class gruTest { public static double train_x[][] = new double[105][4]; public static double test_x[][] = new double[45][4]; public static double train_y[][] = new double[105][3]; public static double test_y[][] = new double[45][3]; private static GRU gru; public static void main(String[] args) { loadData(); int hiddenSize = 4;//隐含层数量 double lr = 0.1; gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层 for (int i = 0; i < 2000; i++) {//迭代2000次 double error = 0; double num = 0; double start = System.currentTimeMillis(); Map<String, DoubleMatrix> acts = new HashMap<>(); for (int s = 0; s < train_x.length; s++) { double newx[][] = new double[1][4]; newx[0] = train_x[s]; DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵 //System.out.println(xt.getColumns()+" "+xt.getRows()); acts.put("x" + s, xt); gru.active(s, acts); DoubleMatrix predcitYt = gru.decode(acts.get("h" + s)); acts.put("py" + s, predcitYt); double newy[][] = new double[1][3]; newy[0] = train_y[s]; DoubleMatrix trueYt = new DoubleMatrix(newy); acts.put("y" + s, trueYt); if(predcitYt.argmax()!=trueYt.argmax()) error++; // bptt num ++; } gru.bptt(acts, train_x.length-1, lr); System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s"); }//结束迭代 //开始测试 int num = 0,error = 0; Map<String, DoubleMatrix> acts = new HashMap<>(); for(int s = 0; s<test_x.length;s++){ double newx[][] = new double[1][4]; newx[0] = test_x[s]; DoubleMatrix xt = new DoubleMatrix(newx); acts.put("x" + s, xt); gru.active(s, acts); DoubleMatrix predcitYt = gru.decode(acts.get("h" + s)); acts.put("py" + s, predcitYt); double newy[][] = new double[1][3]; newy[0] = test_y[s]; DoubleMatrix trueYt = new DoubleMatrix(newy); acts.put("y" + s, trueYt); if(predcitYt.argmax()!=trueYt.argmax()) error++; // bptt num ++; } System.out.println("错误数:"+error+"/"+num); } public static void loadData(){ List<String> list = readFileForList("data/train.txt");//训练集 for(int i = 0;i<list.size();i++){ String str[] = list.get(i).split(","); for(int k = 0 ; k < 4;k++) train_x[i][k]=Double.valueOf(str[k]); train_y[i][Integer.valueOf(str[4])] = 1;//将所属类别设置为1 } list = readFileForList("data/test.txt");//测试集 for(int i = 0;i<list.size();i++){ String str[] = list.get(i).split(","); for(int k = 0 ; k < 4;k++) test_x[i][k]=Double.valueOf(str[k]); test_y[i][Integer.valueOf(str[4])] = 1; } } public static List<String> readFileForList(String fileName) {//读取文件到list File file = new File(fileName); BufferedReader reader = null; List<String> s = new ArrayList<String>(); try { reader = new BufferedReader(new FileReader(file)); String tempString = null; while ((tempString = reader.readLine()) != null) { // 显示行号 s.add(tempString); } reader.close(); } catch (IOException e) { e.printStackTrace(); } finally { if (reader != null) { try { reader.close(); } catch (IOException e1) { } } return s; } }}
实验最终结果如下,可以看到45个测试集对了44个,笔者这里怎么调都无法达到完全的准确率,希望有做出来的可以告知一下,感谢。最后感谢lipiji提供的算法代码。
1 0
- Java实现LSTM和GRU做分类(以IRIS数据集为例)
- theano实现RNN(GRU和LSTM)
- R语言实现分层抽样(Stratified Sampling)以iris数据集为例
- python 实现 knn分类算法 (Iris 数据集)
- Java 实现 BP 神经网络完成 Iris 数据分类
- python基础知识——数组拼接(以iris数据为例……)
- sklern使用之通用模版(以iris为数据集,knn,PCA)
- 85、使用TFLearn实现iris数据集的分类
- 【python 神经网络】BP神经网络python实现-iris数据集分类
- c#神经网络,实现对Iris数据集进行分类
- 自己实现LSTM和GRU内部的代码
- LSTM 和GRU的区别
- LSTM神经网络 和 GRU神经网络
- RNN学习笔记(六)-GRU,LSTM 代码实现
- 数据分类流程(以titanic分类为例)
- 数据挖掘-K-近邻分类器-Iris数据集分析-根据花萼长宽分类-以散点图显示(一)
- 数据挖掘-K-近邻分类器-Iris数据集分析-根据花瓣长宽分类-以散点图显示(二)
- 利用BP神经网络分类iris数据集
- Yii2.0 场景的简单使用
- 51nod 2级算法题-1119
- SSM框架本地测试没有问题,线上报错问题解决方案
- HQL(Hive query language)常用语句
- PostgreSQL递归查询
- Java实现LSTM和GRU做分类(以IRIS数据集为例)
- 说说Java中finally、final、finalize。
- WCF之调用模式
- response.write()方法将指定的字符创输出到html页面时遇到的问题
- 霍夫直线和圆检测
- LINQ的连接扩展(左连、右连、全连等)
- WWWFrom提交表单&从Web下载轻量数据
- 自定义函数练习~学生信息管理程序
- 经验分享:CSS浮动(float,clear)通俗讲解