JAVA使用JOONE实现神经网络的官网例子

来源:互联网 发布:2016淘宝店好做吗 编辑:程序博客网 时间:2024/05/22 05:14
JAVA使用JOONE实现神经网络的官网例子
import org.joone.engine.*;import org.joone.engine.learning.*;import org.joone.io.*;import org.joone.net.*;/** * JOONE神经网络的测试学习类 * */public class XOR_using_NeuralNet implements NeuralNetListener{private NeuralNet nnet = null;private MemoryInputSynapse inputSynapse, desiredOutputSynapse;LinearLayer input;SigmoidLayer hidden, output;boolean singleThreadMode = true;/** * XOR input */private double[][] inputArray = new double[][]{{ 0.0, 0.0 },{ 0.0, 1.0 },{ 1.0, 0.0 },{ 1.0, 1.0 }};/** * XOR desired output */private double[][] desiredOutputArray = new double[][]{{ 0.0 },{ 1.0 },{ 1.0 },{ 1.0 }};/** * @param args the command line arguments */public static void main(String args[]){XOR_using_NeuralNet xor = new XOR_using_NeuralNet();xor.initNeuralNet();xor.train();xor.interrogate();}/** * Method declaration */public void train(){// set the inputsinputSynapse.setInputArray(inputArray);inputSynapse.setAdvancedColumnSelector(" 1,2 ");// set the desired outputsdesiredOutputSynapse.setInputArray(desiredOutputArray);desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");// get the monitor object to train or feed forwardMonitor monitor = nnet.getMonitor();// set the monitor parameters创建监视器对象并且设置学习参数monitor.setLearningRate(0.8);monitor.setMomentum(0.3);monitor.setTrainingPatterns(inputArray.length);monitor.setTotCicles(5000);monitor.setLearning(true);long initms = System.currentTimeMillis();// Run the network in single-thread, synchronized modennet.getMonitor().setSingleThreadMode(singleThreadMode);nnet.go(true);System.out.println(" Total time=  "+ (System.currentTimeMillis() - initms) + "  ms ");}private void interrogate(){double[][] inputArray = new double[][]{{ 1.0, 1.0 }};// set the inputsinputSynapse.setInputArray(inputArray);inputSynapse.setAdvancedColumnSelector(" 1,2 ");Monitor monitor = nnet.getMonitor();monitor.setTrainingPatterns(4);monitor.setTotCicles(1);monitor.setLearning(false);MemoryOutputSynapse memOut = new MemoryOutputSynapse();// set the output synapse to write the output of the netif (nnet != null){nnet.addOutputSynapse(memOut);System.out.println(nnet.check());nnet.getMonitor().setSingleThreadMode(singleThreadMode);nnet.go();for (int i = 0; i < 4; i++){double[] pattern = memOut.getNextPattern();System.out.println(" Output pattern # " + (i + 1) + " = " + pattern[0]);}System.out.println(" Interrogating Finished ");}}/** * Method declaration */protected void initNeuralNet(){// First create the three layers首先,创造这三个层input = new LinearLayer();hidden = new SigmoidLayer();output = new SigmoidLayer();// set the dimensions of the layers指定在每一层中的"行"号。该"行"号对应于这一层中的神经原的数目。input.setRows(2);hidden.setRows(3);output.setRows(1);//每一层被赋于一个名字input.setLayerName(" L.input ");hidden.setLayerName(" L.hidden ");output.setLayerName(" L.output ");// Now create the two SynapsesFullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn.输入-> 隐蔽的连接 */FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn.隐蔽-> 输出连接 */// Connect the input layer whit the hidden layer联接输入层到隐蔽层input.addOutputSynapse(synapse_IH);hidden.addInputSynapse(synapse_IH);// Connect the hidden layer whit the output layer联接隐蔽层到输出层hidden.addOutputSynapse(synapse_HO);output.addInputSynapse(synapse_HO);// the input to the neural netinputSynapse = new MemoryInputSynapse();input.addInputSynapse(inputSynapse);// The Trainer and its desired outputdesiredOutputSynapse = new MemoryInputSynapse();TeachingSynapse trainer = new TeachingSynapse();trainer.setDesired(desiredOutputSynapse);// Now we add this structure to a NeuralNet objectnnet = new NeuralNet();nnet.addLayer(input, NeuralNet.INPUT_LAYER);nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);nnet.setTeacher(trainer);output.addOutputSynapse(trainer);nnet.addNeuralNetListener(this);}public void cicleTerminated(NeuralNetEvent e){}public void errorChanged(NeuralNetEvent e){Monitor mon = (Monitor) e.getSource();if (mon.getCurrentCicle() % 100 == 0)System.out.println(" Epoch:  "+ (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "+ mon.getGlobalError());}public void netStarted(NeuralNetEvent e){Monitor mon = (Monitor) e.getSource();System.out.print(" Network started for  ");if (mon.isLearning())System.out.println(" training. ");elseSystem.out.println(" interrogation. ");}public void netStopped(NeuralNetEvent e){Monitor mon = (Monitor) e.getSource();System.out.println(" Network stopped. Last RMSE= "+ mon.getGlobalError());}public void netStoppedError(NeuralNetEvent e, String error){System.out.println(" Network stopped due the following error:  "+ error);}}
SEO外链

1 0
原创粉丝点击