Neural Network实战:Java实现Back Propagation算法 + 手写数字识别

来源:互联网 发布:陈田拆车件市场淘宝 编辑:程序博客网 时间:2024/06/05 18:40

工作中时不时的会使用机器学习解决一些分类问题。但是一般都是使用已有的机器学习库,比如Weka,Scikit-learn等简单易用的库。

对于一个工程师来说,能够理解常见的机器学习模型(比如SVM、Naive Bayes、Random Forest、Neural Network、Decision Tree等)和降维技术(比如PCA、SVD等),并使用已有的库来解决常见的工程问题。达到这个境界是不是就够了呢?

我个人觉得还不够。对于一些模型以及训练模型的算法,虽然我看过不少书籍和文章,自我感觉能够理解其中的奥秘,但作为工程师,不是自己实现的东西,用起来总是惴惴不安。如果自己实现过一遍,再去用现成的工业级别的库,也会心安理得一点。

那好,这篇文章我就用Java来实现经典的Neural Network,训练算法使用经典的Back Propagation(BP)算法。简单说两句BP算法:该算法使用梯度下降(最速下降法)来求目标函数的最小值。该目标函数是非凸的,因此梯度下降容易陷入局部最优解。一些解决办法包括:借用物理上的动量(momentum)概念,当落入局部最优时,算法本身有一定能力冲出局部最优,继续滑向更好的解。

本篇涉及到的所有code都放到我的github上了: https://github.com/zhangfaen/ML/tree/master/neural_network

详细内容请参考:http://en.wikipedia.org/wiki/Artificial_neural_network 。另外,这里引用一张来自Andrew Moore的slide,它深刻地描述了BP的本质。


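我今天自己手工推导了一遍,涉及的主要知识点是:复合函数求偏导(链式法则)。

以单隐层、激活函数为sigmoid的网络为例:$s(x)=\frac{1}{1+e^{-x}}$,其导数 $s'(x)=s(x)\,(1-s(x))$。记输入为 $x_i$(含一个恒为1的偏置输入),隐层输出为 $h_j$,输出层输出为 $o_k$,目标值为 $t_k$,误差函数为 $E=\frac{1}{2}\sum_k (t_k-o_k)^2$。由链式法则可得:

$$\delta_k = (o_k - t_k)\,o_k\,(1-o_k),\qquad \frac{\partial E}{\partial w^{o}_{jk}} = \delta_k\,h_j$$

$$\delta_j = h_j\,(1-h_j)\sum_k \delta_k\,w^{o}_{jk},\qquad \frac{\partial E}{\partial w^{i}_{ij}} = \delta_j\,x_i$$

两组梯度都应在同一组(本次更新之前的)权重上计算,然后按 $w \leftarrow w - \eta\,\frac{\partial E}{\partial w}$ 更新权重,其中 $\eta$ 为学习率。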


完成了推导,我们开始实现。

实现一个Java类

package faen;import java.util.Arrays;import java.util.Random;// http://en.wikipedia.org/wiki/Artificial_neural_network// A kind of non linear model of Machine Learning.public class NN {    static class Util {        public static void CHECK(boolean condition, String message) {            if (!condition) {                throw new RuntimeException(message);            }        }    }    private int expandedInputNodes;    private int hiddenNodes;    private int outputNodes;    // Weights matrix between input layer and hidden layer    private double[][] wi;    // Weights matrix between hidden layer and output layer.    private double[][] wo;    // last change in weights for momentum    private double[][] wi_momentum;    // last change in weights for momentum    private double[][] wo_momentum;    // Expanded instance, whose size is this.outputSize + 1.    // The last element will be fixed to 1.0    private double[] expandedInstance;    private double[] hiddenActivations;    private double[] outputActivations;    // The sigmoid function: s(x) = 1 / (1 + (e^-x))    // The derivative of s(x): s(x) * (1 - s(x))    private double s(double x) {        return 1.0 / (1.0 + Math.pow(Math.E, -x));    }    public NN(int featuresOfInstance, int nodesOfHiddenLayer, int nodesOfOutputLayer) {        Util.CHECK(featuresOfInstance > 0, "");        Util.CHECK(nodesOfHiddenLayer > 0, "");        Util.CHECK(nodesOfOutputLayer > 0, "");        this.expandedInputNodes = featuresOfInstance + 1;        this.hiddenNodes = nodesOfHiddenLayer;        this.outputNodes = nodesOfOutputLayer;        this.wi = new double[this.expandedInputNodes][this.hiddenNodes];        this.wo = new double[this.hiddenNodes][this.outputNodes];        this.wi_momentum = new double[this.expandedInputNodes][this.hiddenNodes];        this.wo_momentum = new double[this.hiddenNodes][this.outputNodes];        this.expandedInstance = new double[this.expandedInputNodes];        this.expandedInstance[this.expandedInputNodes - 
1] = 1.0;        this.hiddenActivations = new double[this.hiddenNodes];        this.outputActivations = new double[this.outputNodes];    }    // Randomly initialize the input and output weights matrix    private void initializeWeights() {        Random rand = new Random();        for (int i = 0; i < this.wi.length; i++) {            for (int j = 0; j < this.wi[0].length; j++) {                // [-2.0, 2.0]                this.wi[i][j] = rand.nextDouble() * 4 - 2;            }        }        for (int i = 0; i < this.wo.length; i++) {            for (int j = 0; j < this.wo[0].length; j++) {                // [-2.0, 2.0]                this.wo[i][j] = rand.nextDouble() * 4 - 2;            }        }    }    private void forwardPropagate(double[] instance) {        Util.CHECK(instance.length + 1 == this.expandedInputNodes, "");        for (int i = 0; i < instance.length; i++) {            // Note: the last element of this.expandedInstance will be 1.0            this.expandedInstance[i] = instance[i];        }        // forward propagation        for (int j = 0; j < this.hiddenNodes; j++) {            double tmp = 0;            for (int i = 0; i < this.expandedInputNodes; i++) {                tmp += this.expandedInstance[i] * this.wi[i][j];            }            this.hiddenActivations[j] = s(tmp);        }        for (int k = 0; k < this.outputNodes; k++) {            double tmp = 0;            for (int j = 0; j < this.hiddenNodes; j++) {                tmp += this.hiddenActivations[j] * this.wo[j][k];            }            this.outputActivations[k] = s(tmp);        }    }    // Predicate output from one instance.    public double[] predicate(double[] instance) {        forwardPropagate(instance);        return this.outputActivations.clone();    }    // Update NN weights by one instance and its label.    
private double feedOneInstance(double[] instance, double[] target, double rate, double momentum) {        forwardPropagate(instance);        double error = 0;        for (int k = 0; k < this.outputNodes; k++) {            error += 0.5 * (target[k] - this.outputActivations[k])                    * (target[k] - this.outputActivations[k]);        }        // backward propagation        // update output weights matrix        for (int j = 0; j < this.hiddenNodes; j++) {            for (int k = 0; k < this.outputNodes; k++) {                // wo[j,k]                double change = (this.outputActivations[k] - target[k]) * this.outputActivations[k]                        * (1 - this.outputActivations[k]);                change *= this.hiddenActivations[j];                this.wo[j][k] = this.wo[j][k] - rate * change - momentum * this.wo_momentum[j][k];                this.wo_momentum[j][k] = change;            }        }        // update input weights matrix        for (int i = 0; i < this.expandedInputNodes; i++) {            for (int j = 0; j < this.hiddenNodes; j++) {                // wi[i, j]                double change = 0;                for (int k = 0; k < this.outputNodes; k++) {                    change += (this.outputActivations[k] - target[k]) * this.outputActivations[k]                            * (1 - this.outputActivations[k]) * this.wo[j][k];                }                change *= this.hiddenActivations[j] * (1 - this.hiddenActivations[j]);                change *= this.expandedInstance[i];                this.wi[i][j] = this.wi[i][j] - rate * change - momentum * this.wi_momentum[i][j];                this.wi_momentum[i][j] = change;            }        }        return error;    }    // Train the NN    public void train(double[][] instances, double[][] targets, int iterations, double rate,            double momentum) {        Util.CHECK(instances.length == targets.length && targets.length > 0, "");        Util.CHECK(instances[0].length > 0, "");     
   Util.CHECK(targets[0].length == this.outputNodes, "");        initializeWeights();        for (int it = 0; it < iterations; it++) {            double error = 0;            for (int index = 0; index < instances.length; index++) {                double[] instance = instances[index];                double[] target = targets[index];                error += feedOneInstance(instance, target, rate, momentum);            }            if (it % 20 == 0) {                System.out.println("error: " + error);            }        }    }    // Bits XOR    private static void demo1() {        double[][] instances = new double[][] { { 1, 0 }, { 1, 1 }, { 0, 1 }, { 0, 0 } };        double[][] targets = new double[][] { { 1 }, { 0 }, { 1 }, { 0 } };        NN nn = new NN(2, 4, 1);        nn.train(instances, targets, 10000, 1.5, 0.2);        System.out.println("1 xor 1: " + Arrays.toString(nn.predicate(new double[] { 1, 1 })));        System.out.println("1 xor 0: " + Arrays.toString(nn.predicate(new double[] { 1, 0 })));        System.out.println("0 xor 0: " + Arrays.toString(nn.predicate(new double[] { 0, 0 })));        System.out.println("0 xor 1: " + Arrays.toString(nn.predicate(new double[] { 0, 1 })));    }    // Data points are along 2 circles: x^2 + y^2 = 2 and x^2 + y^2 = 4    // The data points are along the first circle have label 0.    // The data points are along the second circle have label 1.    // Note: we should not use labels 2 and 4, because the output is an    // activation function,    // "> 0" means activated.    
private static void demo2() {        int m = 0;        double step = 0.5;        for (double i = 0; i < Math.PI * 2; i += step) {            m++;        }        double[][] instances = new double[m * 2][2];        double[][] targets = new double[m * 2][1];        int index = 0;        for (double i = 0; i < Math.PI * 2; i += step) {            instances[index][0] = 2 * Math.cos(i);            instances[index][1] = 2 * Math.sin(i);            instances[index + m][0] = 4 * Math.cos(i);            instances[index + m][1] = 4 * Math.sin(i);            targets[index][0] = 0;            targets[index + m][0] = 1;            index++;        }        NN nn = new NN(2, 100, 1);        nn.train(instances, targets, 5000, 0.5, 0.2);        // Testing.        for (double i = 0.2; i < Math.PI * 2; i += 1) {            double x = 2 * Math.cos(i);            double y = 2 * Math.sin(i);            System.out.println("x: " + x + " y:" + y + " r:" + (Math.hypot(x, y)) + " result:"                    + Arrays.toString(nn.predicate(new double[] { x, y })));            x = 4 * Math.cos(i);            y = 4 * Math.sin(i);            System.out.println("x: " + x + " y:" + y + " r:" + (Math.hypot(x, y)) + " result:"                    + Arrays.toString(nn.predicate(new double[] { x, y })));        }    }    public static void main(String[] args) {        demo1();        System.out.println("-----------------------------------------");        demo2();    }}

接下来,我们用实现好的NN来识别手写数字。手写数字都已经处理为28×28的黑白图片,比如下面这个数字5:

0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001110000000000000000011111111111000000000000000001111111110000000000000000001100000000000000000000000000110000000000000000000000000110000000000000000000000000011000000000000000000000000001111111100000000000000000001111111111100000000000000000010000000111000000000000000000000000001110000000000000000000000000011000000000000000000000000000100000000000000000000000000011000000000000000000000000011100000000000000011100000011100000000000000011000000011110000000000000001100000011110000000000000000011111111100000000000000000000111111000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000


这样的数字一共有4000个,约7MB,没法全部贴上来。不过所有的代码和数据都可以在我的GitHub上找到: https://github.com/zhangfaen/ML/tree/master/neural_network

不管怎样,还是把手写数字识别的程序贴在下面。

package faen;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

/**
 * Handwritten digit recognition with {@link NN}.
 *
 * <p>Reads labelled digit images from a CSV file with one header line followed by one row per
 * image: the label (0-9) and then 28 * 28 raw pixel values in row-major order. Pixels are
 * binarized, the network is trained on the first rows, and its accuracy is reported on the rows
 * that follow.
 */
public class Main {
    // Raw pixel values at or above this threshold become 1, all others become 0.
    private static final int BLACK_WHITE_THRESHOLD = 100;
    // Width and height of the square input images.
    private static final int IMAGE_SIZE = 28;
    // Number of digit classes, i.e. the length of a one-hot target vector.
    private static final int DIGITS = 10;
    // CSV location used when no path is passed as the first command-line argument.
    private static final String DEFAULT_CSV =
            "/Users/zhangfaen/dev/ml/kaggle/neural_network/data/train.csv";

    // Training and held-out data (features + one-hot targets), filled by readCsv.
    private static double[][] instances = null;
    private static double[][] targets = null;
    private static double[][] testInstances = null;
    private static double[][] testTargets = null;

    // Prints a binarized image as rows of '0'/'1' characters followed by a separator line.
    private static void printImage(int[][] pic) {
        for (int[] row : pic) {
            StringBuilder line = new StringBuilder(row.length);
            for (int pixel : row) {
                line.append(pixel > 0 ? '1' : '0');
            }
            System.out.println(line);
        }
        System.out.println("---------");
    }

    // Binarizes one raw pixel value.
    private static int getFeatureValue(int raw) {
        return raw >= BLACK_WHITE_THRESHOLD ? 1 : 0;
    }

    /**
     * Loads {@code trainCount} training rows followed by {@code testCount} test rows from the CSV
     * at {@code path} into the static data arrays. Every 200th row (total % 200 == 1) is printed.
     *
     * @throws IOException if the file cannot be read, has too few rows, a row has too few fields,
     *     or a label is outside 0-9
     */
    private static void readCsv(String path, int trainCount, int testCount) throws IOException {
        if (trainCount < 0 || testCount < 0) {
            throw new IllegalArgumentException(
                    "row counts must be non-negative: " + trainCount + ", " + testCount);
        }
        int features = IMAGE_SIZE * IMAGE_SIZE;
        instances = new double[trainCount][features];
        targets = new double[trainCount][DIGITS];
        testInstances = new double[testCount][features];
        testTargets = new double[testCount][DIGITS];

        int wanted = trainCount + testCount;
        int total = 0;
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            // The first line is the header.
            br.readLine();
            String line;
            while (total < wanted && (line = br.readLine()) != null) {
                int lineNumber = total + 2; // 1-based, counting the header line
                String[] fields = line.split(",");
                if (fields.length < features + 1) {
                    throw new IOException("line " + lineNumber + " has " + fields.length
                            + " fields, expected " + (features + 1));
                }
                boolean training = total < trainCount;
                double[] instance = training ? instances[total] : testInstances[total - trainCount];
                double[] target = training ? targets[total] : testTargets[total - trainCount];

                int[][] image = new int[IMAGE_SIZE][IMAGE_SIZE];
                for (int row = 0; row < IMAGE_SIZE; row++) {
                    for (int col = 0; col < IMAGE_SIZE; col++) {
                        // Row-major: stride is the number of columns; field 0 is the label.
                        int pixel = row * IMAGE_SIZE + col;
                        image[row][col] = getFeatureValue(Integer.parseInt(fields[pixel + 1]));
                        instance[pixel] = image[row][col];
                    }
                }

                int label = Integer.parseInt(fields[0]);
                if (label < 0 || label >= DIGITS) {
                    throw new IOException("line " + lineNumber + ": label " + label
                            + " is outside 0-" + (DIGITS - 1));
                }
                target[label] = 1;

                if (total % 200 == 1) {
                    printImage(image);
                }
                total++;
            }
        }
        if (total < wanted) {
            // Otherwise the remaining rows would stay all-zero and silently distort the results.
            throw new IOException(path + " has only " + total + " data rows, expected " + wanted);
        }
    }

    // Returns the index of the largest value in output (the first one on ties).
    private static int argMax(double[] output) {
        int bestIndex = 0;
        for (int j = 1; j < output.length; j++) {
            if (output[j] > output[bestIndex]) {
                bestIndex = j;
            }
        }
        return bestIndex;
    }

    /**
     * Trains on 3000 digits and evaluates on the next 200.
     *
     * @param args optional; {@code args[0]} overrides the CSV path
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : DEFAULT_CSV;
        readCsv(path, 3000, 200);
        // 28 * 28 inputs, 10 hidden nodes, one output node per digit.
        NN nn = new NN(IMAGE_SIZE * IMAGE_SIZE, 10, DIGITS);
        nn.train(instances, targets, 300, 1.5, 0.2);
        int correct = 0;
        int wrong = 0;
        for (int i = 0; i < testTargets.length; i++) {
            int actual = argMax(nn.predicate(testInstances[i]));
            int expected = argMax(testTargets[i]);
            System.out.println("actual: " + actual + ", expected: " + expected);
            if (actual == expected) {
                correct++;
            } else {
                wrong++;
            }
        }
        System.out.println("correct:" + correct + ", wrong:" + wrong + ", accuracy:" + 1.0
                * correct / (correct + wrong));
    }
}



0 2