多层前馈神经元网络

来源:互联网 编辑:程序博客网 时间:2024/04/30 19:59

实现了一个多层前馈神经网络分类器(逐样本反向传播训练)。最初把学习率(步长)设为迭代轮数的倒数,效果不好;后来改为固定值 0.2,效果较好。用一个以抛物线 y = 4(x−0.5)² 为分类边界的三分类例子测试,准确率在 96% 以上。

/**
 * A fully connected multi-layer feed-forward neural network with sigmoid activations, trained by
 * per-sample (online) back-propagation.
 *
 * <p>Layer 0 is the input layer and layer {@code depth - 1} is the output layer. A network can be
 * trained once: after {@link #train} returns, further calls to {@code train} throw
 * {@link IllegalStateException}. Instances are not thread-safe, because both {@link #predict} and
 * {@link #train} overwrite per-neurode state.
 */
public final class NeuroNetwork {

    /** Learning rate used by {@link #train(double[][], double[][], int, double, double)}. */
    public static final double DEFAULT_LEARNING_RATE = 0.2;

    /** State of a single neurode. */
    private static class Neurode {
        // back-propagated error term from the most recent training sample
        double err;
        // sigmoid activation from the most recent forward pass (raw input for layer 0)
        double output;
        // bias added to the weighted input sum; never read for layer 0
        double theta;
    }

    /** Lifecycle of the network: training is only allowed while NEW. */
    private enum Status {
        NEW,
        TRAINED
    }

    // status of this network, either NEW or TRAINED
    private Status status;
    // depth of network, layer 0 is input layer
    private final int depth;
    // neurodes[d][i] is the i-th neurode of layer d
    private final Neurode[][] neurodes;
    // weights[d][i][j] is the weight from neurode i of layer d to neurode j of layer d+1
    private final double[][][] weights;

    /**
     * Creates a network with randomly initialized weights and biases in [0, 1).
     *
     * @param depth the number of layers, including input and output layers; must be at least 1
     * @param numNeurodes numNeurodes[d] is the number of neurodes in layer d; must have at least
     *     {@code depth} entries, each positive (entries beyond {@code depth} are ignored)
     * @throws IllegalArgumentException if {@code depth < 1}, {@code numNeurodes} is too short, or a
     *     layer size is not positive
     */
    public NeuroNetwork(int depth, int[] numNeurodes) {
        if (depth < 1 || numNeurodes.length < depth) {
            throw new IllegalArgumentException(
                    "depth must be >= 1 and numNeurodes must have at least depth entries: depth="
                            + depth + ", numNeurodes.length=" + numNeurodes.length);
        }
        this.depth = depth;

        // create and initialize neurodes
        neurodes = new Neurode[depth][];
        for (int d = 0; d < depth; d++) {
            if (numNeurodes[d] <= 0) {
                throw new IllegalArgumentException(
                        "layer " + d + " must have at least one neurode: " + numNeurodes[d]);
            }
            neurodes[d] = new Neurode[numNeurodes[d]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                neurodes[d][i] = new Neurode();
                neurodes[d][i].theta = Math.random();
            }
        }

        // initialize weights; there are depth-1 weight matrices between consecutive layers
        weights = new double[depth - 1][][];
        for (int d = 0; d < depth - 1; d++) {
            weights[d] = new double[numNeurodes[d]][numNeurodes[d + 1]];
            for (int i = 0; i < numNeurodes[d]; i++) {
                for (int j = 0; j < numNeurodes[d + 1]; j++) {
                    weights[d][i][j] = Math.random();
                }
            }
        }
        status = Status.NEW;
    }

    /**
     * Runs a forward pass, storing each neurode's activation in its {@code output} field.
     *
     * @param data input vector; its length is assumed to equal the input layer size (callers check)
     */
    private void calculateOutput(double[] data) {
        // layer 0 simply passes the input through
        for (int i = 0; i < neurodes[0].length; i++) {
            neurodes[0][i].output = data[i];
        }
        // every later layer: sigmoid(weighted sum of previous layer + bias)
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double input = 0.0;
                for (int i = 0; i < neurodes[d - 1].length; i++) {
                    input += neurodes[d - 1][i].output * weights[d - 1][i][j];
                }
                input += neurodes[d][j].theta;
                neurodes[d][j].output = 1.0 / (1.0 + Math.exp(-input));
            }
        }
    }

    /**
     * Classifies one input vector.
     *
     * @param data input vector; length must equal the input layer size
     * @param output caller-supplied buffer that receives the activations of the output layer;
     *     length must equal the output layer size
     * @return index of the output neurode with the largest activation (lowest index on ties)
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public int predict(double[] data, double[] output) {
        if (data.length != neurodes[0].length || output.length != neurodes[depth - 1].length) {
            throw new IllegalArgumentException(
                    "expected input length " + neurodes[0].length + " and output length "
                            + neurodes[depth - 1].length + ", got " + data.length + " and "
                            + output.length);
        }
        calculateOutput(data);
        double best = neurodes[depth - 1][0].output;
        int label = 0;
        for (int i = 0; i < neurodes[depth - 1].length; i++) {
            output[i] = neurodes[depth - 1][i].output;
            if (best < output[i]) {
                best = output[i];
                label = i;
            }
        }
        return label;
    }

    /**
     * Trains the network using {@link #DEFAULT_LEARNING_RATE}.
     *
     * @see #train(double[][], double[][], int, double, double, double)
     */
    public boolean train(double[][] data, double[][] target, int maxIteration, double threshold,
            double errorRate) {
        return train(data, target, maxIteration, threshold, errorRate, DEFAULT_LEARNING_RATE);
    }

    /**
     * Trains the network by online back-propagation. Each round presents every sample once. Training
     * stops when the largest per-weight gradient seen in a round drops below {@code threshold}, or
     * after {@code maxIteration} rounds. Progress is printed to standard output once per round. The
     * network can only be trained once.
     *
     * @param data data[i] is the i-th input vector; each must match the input layer size
     * @param target target[i] is the desired output vector for data[i]; each must match the output
     *     layer size
     * @param maxIteration maximum number of rounds
     * @param threshold convergence threshold on the largest absolute gradient of a round
     * @param errorRate currently not consulted by the implementation
     * @param learningRate step size applied to every weight and bias update; must be positive
     * @return {@code true} if training converged before {@code maxIteration} rounds were exhausted
     * @throws IllegalStateException if the network has already been trained
     * @throws IllegalArgumentException if the data set is empty, data and target differ in length,
     *     a row has the wrong width, or {@code learningRate} is not positive
     */
    public boolean train(double[][] data, double[][] target, int maxIteration, double threshold,
            double errorRate, double learningRate) {
        if (status == Status.TRAINED) {
            throw new IllegalStateException("network has already been trained");
        }
        validateTrainingSet(data, target);
        if (!(learningRate > 0.0)) { // also rejects NaN
            throw new IllegalArgumentException("learningRate must be positive: " + learningRate);
        }

        int round = 1;
        boolean convergence = false;
        while (round <= maxIteration && !convergence) {
            double delta = 0.0;
            for (int r = 0; r < data.length; r++) {
                delta = Math.max(delta, trainWithOneSample(data[r], target[r], learningRate));
            }
            convergence = delta < threshold;
            System.out.printf(" %d round of train, delta is %f %n", round, delta);
            round++;
        }
        status = Status.TRAINED;
        return convergence;
    }

    /** Checks that the training set is non-empty, paired, and matches the layer sizes. */
    private void validateTrainingSet(double[][] data, double[][] target) {
        int inputSize = neurodes[0].length;
        int outputSize = neurodes[depth - 1].length;
        if (data.length == 0 || data.length != target.length) {
            throw new IllegalArgumentException(
                    "need a non-empty training set with one target per sample: data.length="
                            + data.length + ", target.length=" + target.length);
        }
        for (int r = 0; r < data.length; r++) {
            if (data[r].length != inputSize || target[r].length != outputSize) {
                throw new IllegalArgumentException(
                        "sample " + r + " has input length " + data[r].length
                                + " and target length " + target[r].length + ", expected "
                                + inputSize + " and " + outputSize);
            }
        }
    }

    /**
     * Performs one back-propagation step for a single sample.
     *
     * @param data input vector of the sample
     * @param target desired output vector of the sample
     * @param rate learning rate
     * @return the largest absolute gradient (output * err) over all weights for this sample
     */
    private double trainWithOneSample(double[] data, double[] target, double rate) {
        calculateOutput(data);
        // error term of the output layer (sigmoid derivative times residual)
        for (int j = 0; j < neurodes[depth - 1].length; j++) {
            double output = neurodes[depth - 1][j].output;
            neurodes[depth - 1][j].err = output * (1 - output) * (target[j] - output);
        }
        // propagate errors back through hidden layers depth-2 ... 1 (uses pre-update weights)
        for (int d = depth - 2; d > 0; d--) {
            for (int j = 0; j < neurodes[d].length; j++) {
                double error = 0.0;
                for (int k = 0; k < neurodes[d + 1].length; k++) {
                    error += neurodes[d + 1][k].err * weights[d][j][k];
                }
                double output = neurodes[d][j].output;
                neurodes[d][j].err = output * (1 - output) * error;
            }
        }
        double maxDelta = 0.0;
        // update weights
        for (int d = 0; d < depth - 1; d++) {
            for (int i = 0; i < neurodes[d].length; i++) {
                for (int j = 0; j < neurodes[d + 1].length; j++) {
                    double delta = neurodes[d][i].output * neurodes[d + 1][j].err;
                    weights[d][i][j] += rate * delta;
                    maxDelta = Math.max(maxDelta, Math.abs(delta));
                }
            }
        }
        // update biases of all non-input layers
        for (int d = 1; d < depth; d++) {
            for (int j = 0; j < neurodes[d].length; j++) {
                neurodes[d][j].theta += rate * neurodes[d][j].err;
            }
        }
        return maxDelta;
    }
}

测试:

/**
 * Demo driver. It trains a {@code NeuroNetwork} on a grid of points in [0,1] x [0,1], labeled into
 * three classes by the parabola y = 4(x - 0.5)^2. It then prints the accuracy on a finer grid.
 */
public class TestMain {

    /** Number of classes produced by {@link #calculateLabel}. */
    private static final int NUM_LABELS = 3;

    /**
     * Generates an m-by-m grid of points evenly spaced over [0,1] x [0,1], with one-hot labels.
     *
     * @param m number of grid points per axis; a single point (m == 1) is placed at 0.0
     * @return {@code res[0]}: m*m rows of {x, y}; {@code res[1]}: matching one-hot label rows of
     *     length 3
     * @throws IllegalArgumentException if {@code m} is negative
     */
    public static double[][][] generateData(int m) {
        if (m < 0) {
            throw new IllegalArgumentException("m must be non-negative: " + m);
        }
        double[][] data = new double[m * m][2];
        // freshly allocated arrays are zero-filled, so only the hot entry needs setting
        double[][] label = new double[m * m][NUM_LABELS];
        for (int i = 0; i < m; i++) {
            double x = coordinate(i, m);
            for (int j = 0; j < m; j++) {
                double y = coordinate(j, m);
                int k = i * m + j;
                data[k][0] = x;
                data[k][1] = y;
                label[k][calculateLabel(x, y)] = 1;
            }
        }
        return new double[][][] {data, label};
    }

    /** Maps grid index i in [0, m) to [0, 1]; avoids the 0/0 division when m == 1. */
    private static double coordinate(int i, int m) {
        return (m > 1) ? i / (m - 1.0) : 0.0;
    }

    /**
     * Returns the ground-truth class of a point.
     *
     * @return 0 above the parabola y = 4(x - 0.5)^2, otherwise 1 for x < 0.5 and 2 for x >= 0.5
     */
    public static int calculateLabel(double x, double y) {
        if (y > 4.0 * (x - 0.5) * (x - 0.5)) {
            return 0;
        } else if (x < 0.5) {
            return 1;
        } else {
            return 2;
        }
    }

    /**
     * Trains on a 10x10 grid, evaluates on a 50x50 grid, and prints per-point results and accuracy.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int[] num = {2, 3, 3};
        int m = 10;
        NeuroNetwork inst = new NeuroNetwork(num.length, num);
        double[][][] trainData = generateData(m);
        inst.train(trainData[0], trainData[1], 1000000, 0.001, 0.8);

        int t = 50, success = 0;
        double[][][] testData = generateData(t);
        // dedicated buffer: predict() writes into it, so passing the test labels would clobber them
        double[] output = new double[num[num.length - 1]];
        for (int i = 0; i < t * t; i++) {
            double[] point = testData[0][i];
            int res = inst.predict(point, output);
            int ans = calculateLabel(point[0], point[1]);
            if (res == ans) {
                success++;
            }
            System.out.printf("<%f, %f> : %d %b%n", point[0], point[1], res, res == ans);
        }
        System.out.printf("Accuracy rate is %f%n", (success + 0.0) / (t * t));
    }
}


原创粉丝点击