Neural Networks in Practice: Implementing Back Propagation in Java + Handwritten Digit Recognition
At work I occasionally use machine learning to solve classification problems, but usually through an existing library such as Weka or Scikit-learn, which are simple and easy to use.

For an engineer, is it enough to understand the common machine learning models (SVM, Naive Bayes, Random Forest, Neural Network, Decision Tree, and so on) plus dimensionality reduction techniques (PCA, SVD, etc.), and to solve everyday engineering problems with existing libraries?

Personally, I am not satisfied with that. For some models and their training algorithms, I have read plenty of books and articles and feel I understand how they work. But as an engineer, using something I have never implemented myself always leaves me a little uneasy. Once I have implemented it once by hand, I can use the industrial-strength libraries with peace of mind.

So in this article I will implement a classic Neural Network in Java, trained with the classic Back Propagation algorithm. A few words about BP: it uses steepest-descent gradient descent to minimize an objective function. Since that objective function is non-convex, plain gradient descent easily gets stuck in a suboptimal local minimum. One common remedy borrows the physical notion of momentum: when the search falls into a poor local minimum, the accumulated momentum gives the algorithm some ability to climb out and continue toward a better one.
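Concretely, the momentum trick modifies the plain gradient-descent step. As a sketch (standard formulation; w is one weight, η the learning rate, μ the momentum coefficient, and g the gradient of the error E with respect to w), and matching the code in this article, which reuses the previous *gradient* as the momentum term:

```latex
w \;\leftarrow\; w \;-\; \eta\, g_t \;-\; \mu\, g_{t-1},
\qquad g_t = \frac{\partial E}{\partial w}
```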
All the code in this article is on my github: https://github.com/zhangfaen/ML/tree/master/neural_network

For details, see http://en.wikipedia.org/wiki/Artificial_neural_network . There is also a slide from Andrew Moore that captures the essence of BP very well.

I worked through the derivation by hand today; the main mathematical tool involved is the chain rule for partial derivatives of composite functions.
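For reference, here is where that derivation ends up, in the same notation the implementation uses: x_i are the bias-expanded inputs, h_j the hidden activations, o_k the outputs, t_k the targets, and E = ½ Σ_k (t_k − o_k)² the per-instance error. With sigmoid activations (so s'(z) = s(z)(1 − s(z))), applying the chain rule gives the two gradients the code needs:

```latex
\frac{\partial E}{\partial wo_{jk}} = (o_k - t_k)\, o_k (1 - o_k)\, h_j,
\qquad
\frac{\partial E}{\partial wi_{ij}} =
\Big[\sum_k (o_k - t_k)\, o_k (1 - o_k)\, wo_{jk}\Big]\, h_j (1 - h_j)\, x_i
```

These are exactly the `change` terms computed in `feedOneInstance`.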
With the derivation done, let's move on to the implementation.

The Java class:
package faen;

import java.util.Arrays;
import java.util.Random;

// http://en.wikipedia.org/wiki/Artificial_neural_network
// A kind of non-linear Machine Learning model.
public class NN {
    static class Util {
        public static void CHECK(boolean condition, String message) {
            if (!condition) {
                throw new RuntimeException(message);
            }
        }
    }

    private int expandedInputNodes;
    private int hiddenNodes;
    private int outputNodes;
    // Weights matrix between input layer and hidden layer.
    private double[][] wi;
    // Weights matrix between hidden layer and output layer.
    private double[][] wo;
    // Last change in input weights, for momentum.
    private double[][] wi_momentum;
    // Last change in output weights, for momentum.
    private double[][] wo_momentum;
    // Expanded instance, whose size is the feature count + 1.
    // The last element is fixed to 1.0 (the bias input).
    private double[] expandedInstance;
    private double[] hiddenActivations;
    private double[] outputActivations;

    // The sigmoid function: s(x) = 1 / (1 + e^-x)
    // Its derivative: s'(x) = s(x) * (1 - s(x))
    private double s(double x) {
        return 1.0 / (1.0 + Math.pow(Math.E, -x));
    }

    public NN(int featuresOfInstance, int nodesOfHiddenLayer, int nodesOfOutputLayer) {
        Util.CHECK(featuresOfInstance > 0, "");
        Util.CHECK(nodesOfHiddenLayer > 0, "");
        Util.CHECK(nodesOfOutputLayer > 0, "");
        this.expandedInputNodes = featuresOfInstance + 1;
        this.hiddenNodes = nodesOfHiddenLayer;
        this.outputNodes = nodesOfOutputLayer;
        this.wi = new double[this.expandedInputNodes][this.hiddenNodes];
        this.wo = new double[this.hiddenNodes][this.outputNodes];
        this.wi_momentum = new double[this.expandedInputNodes][this.hiddenNodes];
        this.wo_momentum = new double[this.hiddenNodes][this.outputNodes];
        this.expandedInstance = new double[this.expandedInputNodes];
        this.expandedInstance[this.expandedInputNodes - 1] = 1.0;
        this.hiddenActivations = new double[this.hiddenNodes];
        this.outputActivations = new double[this.outputNodes];
    }

    // Randomly initialize the input and output weight matrices.
    private void initializeWeights() {
        Random rand = new Random();
        for (int i = 0; i < this.wi.length; i++) {
            for (int j = 0; j < this.wi[0].length; j++) {
                // [-2.0, 2.0]
                this.wi[i][j] = rand.nextDouble() * 4 - 2;
            }
        }
        for (int i = 0; i < this.wo.length; i++) {
            for (int j = 0; j < this.wo[0].length; j++) {
                // [-2.0, 2.0]
                this.wo[i][j] = rand.nextDouble() * 4 - 2;
            }
        }
    }

    private void forwardPropagate(double[] instance) {
        Util.CHECK(instance.length + 1 == this.expandedInputNodes, "");
        for (int i = 0; i < instance.length; i++) {
            // Note: the last element of this.expandedInstance stays 1.0.
            this.expandedInstance[i] = instance[i];
        }
        // Forward propagation.
        for (int j = 0; j < this.hiddenNodes; j++) {
            double tmp = 0;
            for (int i = 0; i < this.expandedInputNodes; i++) {
                tmp += this.expandedInstance[i] * this.wi[i][j];
            }
            this.hiddenActivations[j] = s(tmp);
        }
        for (int k = 0; k < this.outputNodes; k++) {
            double tmp = 0;
            for (int j = 0; j < this.hiddenNodes; j++) {
                tmp += this.hiddenActivations[j] * this.wo[j][k];
            }
            this.outputActivations[k] = s(tmp);
        }
    }

    // Predict the output for one instance.
    public double[] predicate(double[] instance) {
        forwardPropagate(instance);
        return this.outputActivations.clone();
    }

    // Update the NN weights with one instance and its label; returns the error.
    private double feedOneInstance(double[] instance, double[] target, double rate,
            double momentum) {
        forwardPropagate(instance);
        double error = 0;
        for (int k = 0; k < this.outputNodes; k++) {
            error += 0.5 * (target[k] - this.outputActivations[k])
                    * (target[k] - this.outputActivations[k]);
        }
        // Backward propagation.
        // Update the output weight matrix.
        for (int j = 0; j < this.hiddenNodes; j++) {
            for (int k = 0; k < this.outputNodes; k++) {
                // wo[j][k]
                double change = (this.outputActivations[k] - target[k])
                        * this.outputActivations[k] * (1 - this.outputActivations[k]);
                change *= this.hiddenActivations[j];
                this.wo[j][k] = this.wo[j][k] - rate * change - momentum * this.wo_momentum[j][k];
                this.wo_momentum[j][k] = change;
            }
        }
        // Update the input weight matrix.
        for (int i = 0; i < this.expandedInputNodes; i++) {
            for (int j = 0; j < this.hiddenNodes; j++) {
                // wi[i][j]
                double change = 0;
                for (int k = 0; k < this.outputNodes; k++) {
                    change += (this.outputActivations[k] - target[k])
                            * this.outputActivations[k] * (1 - this.outputActivations[k])
                            * this.wo[j][k];
                }
                change *= this.hiddenActivations[j] * (1 - this.hiddenActivations[j]);
                change *= this.expandedInstance[i];
                this.wi[i][j] = this.wi[i][j] - rate * change - momentum * this.wi_momentum[i][j];
                this.wi_momentum[i][j] = change;
            }
        }
        return error;
    }

    // Train the NN.
    public void train(double[][] instances, double[][] targets, int iterations, double rate,
            double momentum) {
        Util.CHECK(instances.length == targets.length && targets.length > 0, "");
        Util.CHECK(instances[0].length > 0, "");
        Util.CHECK(targets[0].length == this.outputNodes, "");
        initializeWeights();
        for (int it = 0; it < iterations; it++) {
            double error = 0;
            for (int index = 0; index < instances.length; index++) {
                double[] instance = instances[index];
                double[] target = targets[index];
                error += feedOneInstance(instance, target, rate, momentum);
            }
            if (it % 20 == 0) {
                System.out.println("error: " + error);
            }
        }
    }

    // XOR of two bits.
    private static void demo1() {
        double[][] instances = new double[][] { { 1, 0 }, { 1, 1 }, { 0, 1 }, { 0, 0 } };
        double[][] targets = new double[][] { { 1 }, { 0 }, { 1 }, { 0 } };
        NN nn = new NN(2, 4, 1);
        nn.train(instances, targets, 10000, 1.5, 0.2);
        System.out.println("1 xor 1: " + Arrays.toString(nn.predicate(new double[] { 1, 1 })));
        System.out.println("1 xor 0: " + Arrays.toString(nn.predicate(new double[] { 1, 0 })));
        System.out.println("0 xor 0: " + Arrays.toString(nn.predicate(new double[] { 0, 0 })));
        System.out.println("0 xor 1: " + Arrays.toString(nn.predicate(new double[] { 0, 1 })));
    }

    // Data points lie on 2 circles, of radius 2 and radius 4.
    // Points on the first circle have label 0; points on the second have label 1.
    // Note: the labels must be 0 and 1 (not, say, 2 and 4), because the output
    // unit is a sigmoid whose range is (0, 1).
    private static void demo2() {
        int m = 0;
        double step = 0.5;
        for (double i = 0; i < Math.PI * 2; i += step) {
            m++;
        }
        double[][] instances = new double[m * 2][2];
        double[][] targets = new double[m * 2][1];
        int index = 0;
        for (double i = 0; i < Math.PI * 2; i += step) {
            instances[index][0] = 2 * Math.cos(i);
            instances[index][1] = 2 * Math.sin(i);
            instances[index + m][0] = 4 * Math.cos(i);
            instances[index + m][1] = 4 * Math.sin(i);
            targets[index][0] = 0;
            targets[index + m][0] = 1;
            index++;
        }
        NN nn = new NN(2, 100, 1);
        nn.train(instances, targets, 5000, 0.5, 0.2);
        // Testing.
        for (double i = 0.2; i < Math.PI * 2; i += 1) {
            double x = 2 * Math.cos(i);
            double y = 2 * Math.sin(i);
            System.out.println("x: " + x + " y:" + y + " r:" + Math.hypot(x, y) + " result:"
                    + Arrays.toString(nn.predicate(new double[] { x, y })));
            x = 4 * Math.cos(i);
            y = 4 * Math.sin(i);
            System.out.println("x: " + x + " y:" + y + " r:" + Math.hypot(x, y) + " result:"
                    + Arrays.toString(nn.predicate(new double[] { x, y })));
        }
    }

    public static void main(String[] args) {
        demo1();
        System.out.println("-----------------------------------------");
        demo2();
    }
}
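As a quick sanity check, the identity s'(x) = s(x)(1 − s(x)) that the backward pass relies on can be verified numerically. This is a standalone sketch (the class name `SigmoidCheck` is mine, not part of the article's code); it compares the analytic derivative against a central finite difference:

```java
public class SigmoidCheck {
    // Sigmoid, same formula as NN.s().
    static double s(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Analytic derivative used by back propagation: s(x) * (1 - s(x)).
    static double ds(double x) {
        return s(x) * (1 - s(x));
    }

    public static void main(String[] args) {
        double eps = 1e-6;
        for (double x : new double[] { -2.0, -0.5, 0.0, 1.0, 3.0 }) {
            // Central finite difference approximates the true derivative.
            double numeric = (s(x + eps) - s(x - eps)) / (2 * eps);
            System.out.printf("x=%5.1f  analytic=%.6f  numeric=%.6f%n", x, ds(x), numeric);
        }
    }
}
```

The two columns should agree to about six decimal places at every test point.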
Now let's use the finished NN to recognize handwritten digits. Each digit has already been preprocessed into a 28*28 black-and-white bitmap; for example, here is a 5:
0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001110000000000000000011111111111000000000000000001111111110000000000000000001100000000000000000000000000110000000000000000000000000110000000000000000000000000011000000000000000000000000001111111100000000000000000001111111111100000000000000000010000000111000000000000000000000000001110000000000000000000000000011000000000000000000000000000100000000000000000000000000011000000000000000000000000011100000000000000011100000011100000000000000011000000011110000000000000001100000011110000000000000000011111111100000000000000000000111111000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
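A bitmap like the one above is produced by thresholding each raw grayscale pixel into 0 or 1. A minimal standalone sketch of that step (the threshold of 100 matches the full program below; the sample pixel values are made up for illustration):

```java
public class Binarize {
    static final int BLACK_WHITE_THRESHOLD = 100;

    // Map a raw grayscale value (0-255) to a binary feature.
    static int getFeatureValue(int raw) {
        return raw >= BLACK_WHITE_THRESHOLD ? 1 : 0;
    }

    public static void main(String[] args) {
        int[] rawRow = { 0, 30, 99, 100, 180, 255 }; // sample pixel values
        StringBuilder sb = new StringBuilder();
        for (int raw : rawRow) {
            sb.append(getFeatureValue(raw));
        }
        System.out.println(sb); // prints 000111
    }
}
```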
There are about 4000 such digits, roughly 7MB in total, so I can't upload them all here. But all the code and data can be found on my github: https://github.com/zhangfaen/ML/tree/master/neural_network

In any case, here is the handwritten digit recognition program.
package faen;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

public class Main {
    private static int BLACK_WHITE_THRESHOLD = 100;

    private static double[][] instances = null;
    private static double[][] targets = null;
    private static double[][] test_instances = null;
    private static double[][] test_targets = null;

    private static void printImage(int[][] pic) {
        for (int i = 0; i < pic.length; i++) {
            for (int j = 0; j < pic[0].length; j++) {
                if (pic[i][j] > 0) {
                    System.out.print(1);
                } else {
                    System.out.print(0);
                }
            }
            System.out.println();
        }
        System.out.println("---------");
    }

    private static int getFeatureValue(int raw) {
        if (raw >= BLACK_WHITE_THRESHOLD) {
            return 1;
        }
        return 0;
    }

    private static void readCsv(int m_train, int m_test) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(
                "/Users/zhangfaen/dev/ml/kaggle/neural_network/data/train.csv"));
        int total = 0;
        // One instance is a 28*28 picture.
        instances = new double[m_train][28 * 28];
        targets = new double[m_train][10];
        test_instances = new double[m_test][28 * 28];
        test_targets = new double[m_test][10];
        // The first line is the header.
        br.readLine();
        while (true) {
            String line = br.readLine();
            if (line == null) {
                break;
            }
            String[] lineArray = line.split(",");
            int m = 28;
            int n = 28;
            int[][] image = new int[m][n];
            for (int i = 0; i < m; i++) {
                for (int j = 0; j < n; j++) {
                    image[i][j] = getFeatureValue(Integer.parseInt(lineArray[i * m + j + 1]));
                    if (total < m_train) {
                        instances[total][i * m + j] = image[i][j];
                    } else {
                        test_instances[total - m_train][i * m + j] = image[i][j];
                    }
                }
            }
            if (total < m_train) {
                targets[total][Integer.parseInt(lineArray[0])] = 1;
            } else {
                test_targets[total - m_train][Integer.parseInt(lineArray[0])] = 1;
            }
            if (total % 200 == 1) {
                printImage(image);
            }
            if (++total >= m_train + m_test) {
                break;
            }
        }
        br.close();
    }

    // Return the index of the largest output value, i.e. the predicted digit.
    private static int get(double[] output) {
        int actual_index = 0;
        double actual_best = output[0];
        for (int j = 1; j < output.length; j++) {
            if (actual_best < output[j]) {
                actual_index = j;
                actual_best = output[j];
            }
        }
        return actual_index;
    }

    public static void main(String[] args) throws Exception {
        readCsv(3000, 200);
        NN nn = new NN(28 * 28, 10, 10);
        nn.train(instances, targets, 300, 1.5, 0.2);
        int correct = 0;
        int wrong = 0;
        for (int i = 0; i < test_targets.length; i++) {
            double[] actual = nn.predicate(test_instances[i]);
            System.out.println("actual: " + get(actual) + ", expected: " + get(test_targets[i]));
            if (get(actual) == get(test_targets[i])) {
                correct++;
            } else {
                wrong++;
            }
        }
        System.out.println("correct:" + correct + ", wrong:" + wrong + ", accuracy:"
                + 1.0 * correct / (correct + wrong));
    }
}