Java实现手写数字的识别(BP神经网络的运用)
来源:互联网 发布:死神vs火影月改优化版 编辑:程序博客网 时间:2024/06/09 14:24
最近对机器学习方面的知识有点感兴趣,所以特别的对神经网络方面的知识进行了了解。然而,发现大部分的人都是通过Python或者R语言及其Matlab来进行实验的,而自己却还没有时间进行学习,而且对Java语言有一种情有独钟的感觉,所以特别的就用Java语言实现BP神经网络。PS:这个内容其实在Python和Matlab中都有已经封装好的库,直接调用就可以了的,而且效率也还不错。而我只是尝试试着用Java进行写写而已(最后的效果还不如已经封装好的代码),自己并没有进行过多的优化。
下面就介绍一个简单的例子,就是对手写的数字通过BP神经网络进行识别,我也刚接触机器学习,所以有不足的地方欢迎大家进行指正,共同学习。
在这里就不再对BP神经网络的原理进行过多的阐述了,如果大家对于这个方面不懂的话,可以自行百度的进行了解。
例子简介:
该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.
训练样本:训练集中包括60000个样本,因为所有数据集中28x28像素的灰度图片的尺寸为784,所以训练
集输出的格式为[60000, 784]。
一:训练数据和测试数据的格式
训练数据的格式如图所示:
这只是部分显示的效果,在这里进行讲解一下:其中每一行一共是784个输入,因为这是由一个28*28的图片进行像素转换过来的,而且每一列的数值都是在0-255之间。
标签结果:
当然,这也只是部分训练集的结果,其对应的就是上面训练集的识别结果,所以主要的作用也就是标签。
二:实验思路
(1):构建神经网络层次结构
由训练集数据可知,手写输入的数据维数为784维,而对应的输出结果为分别为0-9的10个数字,所以根据训练集的数据可知,在构建的神经网络的输入层的神经元的节点个数为784个,而对应的输出层的神经元个数为10个。由于在神经网络的隐层可以有一层或者多层,对此自己在本实验中,采取的是只采用了一层隐层结构。
(2):确定隐层中的神经元的个数
因为对于隐层的神经元个数的确定目前还没有什么比较完美的解决方案,所以对此经过自己查阅书籍和上网查阅资料,有以下的几种经验方式来确定隐层的神经元的个数,方式分别如下所示:
1) 一般取(输入+输出)/2
2) 隐层一般小于输入层
3)(输入层+1)/2
4) log(输入层)
5) log(输入层)+10
由于上述的也是由经验而得来的,所以自己在实验过程中分别的测试了一下,最后得到以第五种的方式得到的测试结果相对较高。
(3):设置神经元的激活函数
在《机器学习》的书中介绍了两种比较常用的函数,分别是阶跃函数和Sigmoid函数。最后自己采用了后者函数。
(4):初始化输入层和隐层之间神经元间的权值信息
采用的是使用简单的随机数分配的方法,并且两层之间的神经元权值是通过二维数组进行保留,数组的索引就代表着两层对应的神经元的索引信息
(5):初始化隐层和输出层之间神经元间的权值信息
采用的是使用简单的随机数分配的方法,并且两层之间的神经元权值是通过二维数组进行保留,数组的索引就代表着两层对应的神经元的索引信息
(6):读取CSV测试集表格信息,并加载到程序用数据保存,其中将每个维数的数据都换成了0和1的二进制数进行处理。
(7):读取CSV测试集结果表格信息,并加载到程序用数据保存
(8):计算输入层与隐层中隐层神经元的阈值
这里主要是采用了下面的方法:
Sum=sum+weight[i][j] * layer0[i];
参数的含义:将每个输入层中的神经元与神经元的权值信息weight[i][j]乘以对应的输入层神经元的阈值累加,然后再调用激活函数得到对应的隐层神经元的阈值。
(9):计算隐层与输出层中输出层的神经元的阈值
方法和上面的类似,只是相对应的把权值信息进行了修改即可。
(10):计算误差逆传播(输出层的逆误差)
采用书上P103页的方法(西瓜书)
(11):计算误差传播(隐层的逆误差)
采用书上P103页的方法(西瓜书)
(12):更新各层神经元之间的权值信息
double newVal = momentum * prevWeight[j][i] + eta * delta[i] * layer[j];
参数:其中设置momentum 为0.9,设置eta 为0.25,prevWeight[j][i]表示神经元之间的权值,layer[j]和delta[i]表示两层不同神经元的阈值。
(13):循环迭代训练5次
(14):输入测试集数据
(15):输出测试集预测结果和实际结果进行比较,得到精确度
好了,上面就是我大概的思路了,其实是不是很简单呢。当然这只是最基本的而已,因为我并没有进行优化,以至于后面的识别效果并不是很好。
预测结果:,最后的精确度平均值在85%左右(并不理想)。
虽然结果不是很好,但是其BP的原理还是实现了,此处鼓励鼓励。
好了,不多说了,把代码直接贴出来。(其中由于训练数据和测试数据都是CSV格式,所以在用Java时需要进行特别的读取处理。)
1)读取CSV格式的文件:
package shenjingwangluo2;/* * 读取后缀为csv的excell文件 * */import java.io.BufferedReader;import java.io.FileReader;import java.util.ArrayList;import java.util.Iterator;import java.util.List;public class CSVFileUtil { private String fileName = null; private BufferedReader br = null; private List<String> list = new ArrayList<String>(); public CSVFileUtil() { } public CSVFileUtil(String fileName) throws Exception { this.fileName = fileName; br = new BufferedReader(new FileReader(fileName)); String stemp; while ((stemp = br.readLine()) != null) { list.add(stemp); } } public List getList() { return list; } /** * 获取行数 * @return */ public int getRowNum() { return list.size(); } /** * 获取列数 * @return */ public int getColNum() { if (!list.toString().equals("[]")) { if (list.get(0).toString().contains(",")) {// csv为逗号分隔文件 return list.get(0).toString().split(",").length; } else if (list.get(0).toString().trim().length() != 0) { return 1; } else { return 0; } } else { return 0; } } /** * 获取制定行 * @param index * @return */ public String getRow(int index) { if (this.list.size() != 0) { return (String) list.get(index); } else { return null; } } /** * 获取指定列 * @param index * @return */ public String getCol(int index) { if (this.getColNum() == 0) { return null; } StringBuffer sb = new StringBuffer(); String tmp = null; int colnum = this.getColNum(); if (colnum > 1) { for (Iterator it = list.iterator(); it.hasNext();) { tmp = it.next().toString(); sb = sb.append(tmp.split(",")[index] + ","); } } else { for (Iterator it = list.iterator(); it.hasNext();) { tmp = it.next().toString(); sb = sb.append(tmp + ","); } } String str = new String(sb.toString()); str = str.substring(0, str.length() - 1); return str; } /** * 获取某个单元格 * @param row * @param col * @return */ public String getString(int row, int col) { String temp = null; int colnum = this.getColNum(); if (colnum > 1) { temp = list.get(row).toString().split(",")[col]; } else if(colnum == 1){ temp = list.get(row).toString(); } else { temp = null; } return temp; } public void CsvClose()throws Exception{ this.br.close(); }}2)BP神经网络构建
package shenjingwangluo2;import java.util.Random;public class BP {private final double[] input; //输入层 private final double[] hidden; //隐含层 private final double[] output; //输出层 private final double[] target; //预测的输出内容 private final double[] hidDelta; //隐含层的神经元的误差(每一个的) private final double[] optDelta; //输出层的神经元的误差(每一个的) private final double eta; //学习率 private final double momentum; //动量参数 private final double[][] iptHidWeights; //从输入层到隐含层的矩阵权值 private final double[][] hidOptWeights; //从隐含层到输出层的矩阵权值 private final double[][] iptHidPrevUptWeights; //更新之前的权值信息(输入层到隐含层) private final double[][] hidOptPrevUptWeights; //更细之前的权值信息(隐含层到输出层) public double optErrSum = 0d; public double hidErrSum = 0d; private final Random random; //主要是对权值采取的是随机产生的方法 //初始化 public BP(int inputSize, int hiddenSize, int outputSize, double eta, double momentum) { input = new double[inputSize + 1]; hidden = new double[hiddenSize + 1]; output = new double[outputSize + 1]; target = new double[outputSize + 1]; hidDelta = new double[hiddenSize + 1]; optDelta = new double[outputSize + 1]; iptHidWeights = new double[inputSize + 1][hiddenSize + 1]; hidOptWeights = new double[hiddenSize + 1][outputSize + 1]; random = new Random(100000); //使每次产生的随机数都是第一次的分配,这是有参数和没参数的区别 randomizeWeights(iptHidWeights); //分配输入层到隐含层的神经元的权值 randomizeWeights(hidOptWeights); //分配隐含层到输出层的神经元的权值 iptHidPrevUptWeights = new double[inputSize + 1][hiddenSize + 1]; //更新权值 hidOptPrevUptWeights = new double[hiddenSize + 1][outputSize + 1]; this.eta = eta; //学习率 this.momentum = momentum; //动态量 } //随机产生神经元之间的权值信息 private void randomizeWeights(double[][] matrix) { for (int i = 0, len = matrix.length; i != len; i++) for (int j = 0, len2 = matrix[i].length; j != len2; j++) { double real = random.nextDouble(); //随机分配着产生0-1之间的值 matrix[i][j] = random.nextDouble() > 0.5 ? real : -real; } } //初始化输入层,隐含层,和输出层 public BP(int inputSize, int hiddenSize, int outputSize) { this(inputSize, hiddenSize, outputSize, 0.25, 0.9); } public void train(double[] trainData, double[] target) { //训练数据 loadInput(trainData); //加载输入的数据 loadTarget(target); //加载输出的结果数据 forward(); //向前计算神经元权值(先算输入到隐含层的,然后再算隐含到输出层的权值) calculateDelta(); //计算误差逆传播值 adjustWeight(); //调整更新神经元的权值 } //测试自己弄的BP神经网络训练的效果咋样 public double[] test(double[] inData) { if (inData.length != input.length - 1) { throw new IllegalArgumentException("Size Do Not Match."); } System.arraycopy(inData, 0, input, 1, inData.length); forward(); return getNetworkOutput(); } //返回最后的输出层的结果 private double[] getNetworkOutput() { int len = output.length; double[] temp = new double[len - 1]; for (int i = 1; i != len; i++) temp[i - 1] = output[i]; return temp; } //将之前那些训练数据的结果加载进来,存放,方便训练 private void loadTarget(double[] arg) { if (arg.length != target.length - 1) { throw new IllegalArgumentException("Size Do Not Match."); } System.arraycopy(arg, 0, target, 1, arg.length); //方法和之前的输入数据一样,都是调用复制数据的方法 } //加载训练数据 private void loadInput(double[] inData) { if (inData.length != input.length - 1) { throw new IllegalArgumentException("Size Do Not Match."); } System.arraycopy(inData, 0, input, 1, inData.length); //调用系统复制数组的方法(存放输入的训练数据) } //向前计算各个神经元的权值 (输入层到隐含层的)(参数一:输入层的数据,二:隐含层的内容,三:输入到隐含的神经元的权值) //向前计算各个神经元的权值 (隐含层到输出层)(参数一:隐含层的数据,二:输出层的内容,三:隐含层到输出层的神经元的权值) private void forward(double[] layer0, double[] layer1, double[][] weight) { layer0[0] = 1.0; for (int j = 1, len = layer1.length; j != len; ++j) { double sum = 0; //保存权值 for (int i = 0, len2 = layer0.length; i != len2; ++i) sum += weight[i][j] * layer0[i]; layer1[j] = sigmoid(sum); //调用神经元的激活函数来得到结果(结果肯定是在0-1之间的) } } //向前计算(先算输入到隐含层的,然后再算隐含到输出层的权值) private void forward() { forward(input, hidden, iptHidWeights); //计算输入层到隐含层的权值 forward(hidden, output, hidOptWeights); //计算隐含层到输出层的权值 } //计算输出层的误差 private void outputErr() { double errSum = 0; //误传播值 for (int idx = 1, len = optDelta.length; idx != len; ++idx) { double o = output[idx]; optDelta[idx] = o * (1d - o) * (target[idx] - o); //书上p104的公式 errSum += Math.abs(optDelta[idx]); } optErrSum = errSum; } //计算隐含层的误差 private void hiddenErr() { double errSum = 0; //保存误差 for (int j = 1, len = hidDelta.length; j != len; ++j) { double o = hidden[j]; //神经元权值 double sum = 0; for (int k = 1, len2 = optDelta.length; k != len2; ++k) //由输出层来反向计算 sum += hidOptWeights[j][k] * optDelta[k]; hidDelta[j] = o * (1d - o) * sum; //书上的P104的(5.15)公式 errSum += Math.abs(hidDelta[j]); } hidErrSum = errSum; } //计算每一层的误差(因为在BP中,要达到使误差最小)(就是逆传播算法,书上有P101) private void calculateDelta() { outputErr(); //计算输出层的误差(因为要反过来算,所以先算输出层的) hiddenErr(); //计算隐含层的误差 } //更新每层中的神经元的权值信息(这也就是不断的训练过程,) private void adjustWeight(double[] delta, double[] layer, double[][] weight, double[][] prevWeight) { layer[0] = 1; for (int i = 1, len = delta.length; i != len; ++i) { for (int j = 0, len2 = layer.length; j != len2; ++j) { double newVal = momentum * prevWeight[j][i] + eta * delta[i] * layer[j]; //通过公式计算误差限=(动态量*之前的该神经元的阈值+学习率*误差*对应神经元的阈值),来进行更新权值 weight[j][i] += newVal; //得到新的神经元之间的权值 prevWeight[j][i] = newVal; //保存这一次得到的权值,方便下一次进行更新 } } } //更新每层中的神经元的权值信息 private void adjustWeight() { adjustWeight(optDelta, hidden, hidOptWeights, hidOptPrevUptWeights); adjustWeight(hidDelta, input, iptHidWeights, iptHidPrevUptWeights); } //我这里用的是sigmoid激活函数,当然也可以用阶跃函数,看自己选择吧 private double sigmoid(double val) { return 1d / (1d + Math.exp(-val)); //这函数书上有P98页 } }3)进行训练和测试
package shenjingwangluo2;import java.io.File;import java.io.FileNotFoundException;import java.io.IOException;import java.text.DateFormat;import java.text.SimpleDateFormat;import java.util.ArrayList;import java.util.Date;import java.util.List;import java.util.Random;import java.util.Scanner;import javax.xml.crypto.Data;public class Test {public static void main(String[] args) throws Exception { Date data=new Date();DateFormat df=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");System.out.println(df.format(data)); BP bp = new BP(784, 30, 10); //分别代表784个输入层,一层隐含层其中是784个神经元,10个输出结果 //得到训练数据的结果 // File fileresult=new File("D:\\text.et"); //训练数据的结果 CSVFileUtil util = new CSVFileUtil("D:\\text.csv"); int resulthang=util.getRowNum(); //得到训练结果行数 int resultlie=util.getColNum(); //得到训练结果列数 // String[][] numberresult=GetNumberInfo.getData(fileresult, 0); //得到训练数据的结果// File inputdata=new File("D:\\traindata.et"); //训练数据的784个像素 CSVFileUtil util2 = new CSVFileUtil("D:\\traindata.csv"); int inputhang=util2.getRowNum(); //得到训练结果行数 int inputlie=util2.getColNum(); //得到训练结果列数 // String[][] numberxunlian=GetNumberInfo.getData(inputdata, 0); for(int i=0;i<20;i++){ //训练迭代次数 for(int index=0;index<resulthang;index++){ double[] xunlianresult=new double[10]; //因为结果有0-9这10种情况 String getresult=util.getString(index, 0); //得到训练集的每个的结果 xunlianresult[Integer.parseInt(getresult)]=1; //表示当前的这个训练集的结果是对应的下标的值 //将训练数据的每一个转换成只有0和1的形式 double[] binary = new double[784]; //转成二进制进行处理(784位) int suoyinlie = 0; int value=0; while(suoyinlie<784){ value=Integer.parseInt(util2.getString(index, suoyinlie)); if(value>=255/2){ binary[suoyinlie]=1; //主要是为了让数据中只有0和1这样的灰度数据方便计算 } else{ binary[suoyinlie]=0; } suoyinlie++; } bp.train(binary,xunlianresult); //训练数据 } } /* /////////////////////////////////之前测试用的 //输入层的训练数据 Random random = new Random(); List<Integer> list = new ArrayList<Integer>(); //存取1000个随机数(主要是用来当输入层数据的) for (int i = 0; i <= 1000; i++) { int value = random.nextInt(); list.add(value); } for (int i = 0; i <= 200; i++) { for (int value : list) { //取训练数据 double[] real = new double[4]; if (value >= 0) //随机数是正数 if ((value & 1) == 1) //正奇数 real[0] = 1; else //正偶数 real[1] = 1; else if ((value & 1) == 1) //负奇数 real[2] = 1; else //负偶数 real[3] = 1; double[] binary = new double[32]; //转成二进制进行处理 int index = 31; do { binary[index--] = (value & 1); value >>>= 1; } while (value != 0); bp.train(binary, real); //训练数据 } } */ System.out.println("神经网络拓扑结构训练好了,可以进行测试"); computeTextData(bp);} /* while (true) { byte[] input = new byte[10]; System.in.read(input); Integer value = Integer.parseInt(new String(input).trim()); //输入测试数据 Scanner sc=new Scanner(System.in); int value=sc.nextInt(); int rawVal = value; double[] binary = new double[32]; int index = 31; do { binary[index--] = (value & 1); value >>>= 1; } while (value != 0); double[] result = bp.test(binary); double max = -Integer.MIN_VALUE; int idx = -1; for (int i = 0; i != result.length; i++) { if (result[i] > max) { //得到输出层中最大的权值,就可以得到它的属性是哪个数字了 max = result[i]; idx = i; //保存是属于哪一个输出层的数值,也就是代表它属于哪一类 } } switch (idx) { case 0: System.out.format("%d是一个正奇数\n", rawVal); break; case 1: System.out.format("%d是一个正偶数\n", rawVal); break; case 2: System.out.format("%d是一个负奇数\n", rawVal); break; case 3: System.out.format("%d是一个负偶数\n", rawVal); break; } } } *///输出测试的结果private static void computeTextData(BP bp) throws Exception {// File textinputdata=new File("D:\\textdata.et"); // String[][] textdata=GetNumberInfo.getData(textinputdata, 0); //得到测试数据CSVFileUtil util3 = new CSVFileUtil("D:\\textdata.csv"); int textthang=util3.getRowNum(); //得到测试数据行数 int textlie=util3.getColNum(); //得到测试数据列数 CSVFileUtil util4 = new CSVFileUtil("D:\\textresult.csv"); int textresultthang=util4.getRowNum(); //得到测试结果行数 int textresultlie=util4.getColNum(); //得到测试结果列数 int textrightnumber=0; //得到预测和实际结果相同的个数(计算正确率) for(int hanggeshu=0;hanggeshu<textthang;hanggeshu++){ double[] binary = new double[784]; //转成二进制进行处理(784位) int suoyinlie = 0; int value=0; while(suoyinlie<784){ value=Integer.parseInt(util3.getString(hanggeshu, suoyinlie)); if(value>=255/2){ binary[suoyinlie]=1; //主要是为了让数据中只有0和1这样的灰度数据方便计算 } else{ binary[suoyinlie]=0; } suoyinlie++; } double[] result = bp.test(binary); //测试的结果 double max = result[0]; int idx = -1; for(int a=0;a<10;a++){ } for (int i = 0; i <result.length; i++) { if (result[i] > max) { //得到输出层中最大的权值,就可以得到它的属性是哪个数字了 max = result[i]; idx = i; //保存是属于哪一个输出层的数值,也就是代表它属于哪一类 } } if(idx==Integer.parseInt(util4.getString(hanggeshu,0))){ //计算正确率 textrightnumber++; } switch (idx) {case 0:System.out.println("测试结果是0");break;case 1:System.out.println("测试结果是1");break;case 2:System.out.println("测试结果是2");break;case 3:System.out.println("测试结果是3");break;case 4:System.out.println("测试结果是4");break;case 5:System.out.println("测试结果是5");break;case 6:System.out.println("测试结果是6");break;case 7:System.out.println("测试结果是7");break;case 8:System.out.println("测试结果是8");break;case 9:System.out.println("测试结果是9");break;default:break;} } Date data2=new Date();DateFormat df=new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");System.out.println(df.format(data2)); System.out.println("正确率为:"+(double)(textrightnumber)/textresultthang);} }
好了,这就是大概的一个过程,其中参考了网上的一些博客资源,再次感谢,欢迎交流。
下面把数据的链接贴出来(百度云盘)。
https://pan.baidu.com/s/1boFR2Bd
- Java实现手写数字的识别(BP神经网络的运用)
- 用BP人工神经网络识别手写数字
- 神经网络实现手写数字识别(MNIST)
- 基于BP神经网络的数字识别
- 基于BP神经网络的数字识别基础系统(一)
- 基于BP神经网络的数字识别基础系统(二)
- BP神经网络应用于手写数字识别--matlab程序
- BP神经网络识别手写数字项目解析及代码
- 简单BP神经网络分类手写数字识别0-9
- BP神经网络的Java实现
- bp神经网络的java实现
- 机器学习(四):BP神经网络_手写数字识别_Python
- 机器学习之 神经网络的实现(二)-->手写识别
- bp实现手写识别
- 神经网络-tensorflow实现mnist手写数字识别
- 基于BP人工神经网络的数字字符识别及MATLAB实现
- 基于神经网络和遗传算法的【手写数字识别】机器人的实现
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- NYOJ skiing
- 【SHOI&SXOI2017】bzoj4871 摧毁“树状图”
- JSON简单使用
- 河工大校赛J 爱看电视的LsF
- BZOJ 1336: [Balkan2002]Alien最小圆覆盖 随机增量法
- Java实现手写数字的识别(BP神经网络的运用)
- Android硬件加速相关知识点总结
- struts内部错误 antlr.2.7.2 版本冲突
- 一个栈的入栈序列为ABCDEF,则不可能的出栈序列是
- 8086汇编学习之[BX],CX寄存器与loop指令,ES寄存器等
- Debian 自动运行机制
- Tomcat卡死的情况
- 上古神器sed命令(上)
- java开发之文件路径剖析