神经网络-向后传播java实现

来源:互联网 发布:淘宝类目在线查询 编辑:程序博客网 时间:2024/05/19 17:59

1、Point.java

package com.network;


/**
 * <p>本类描述: </p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-29 上午09:15:55
 */
public class Point {
    
    public Point() {}
    
    public Point(Integer name, Double in) {
        this.name = name;
        this.in = in;
    }
    
    public Point(Integer name, Double o, Double err, Double in, Double out) {
        this.name = name;
        this.o = o;
        this.err = err;
        this.in = in;
        this.out = out;
    }
    
    public Point(Integer name, Double o, Double err, Double in, Double out, Integer classify) {
        this.name = name;
        this.o = o;
        this.err = err;
        this.in = in;
        this.out = out;
        this.classify = classify;
    }


    /** 点的名字 */
    private Integer name;
    
    /** 点的偏倚值 */
    private Double o;
    
    /** 点的初始值 */
    private Double err;
    
    /** 点的净输入 */
    private Double in;
    
    /** 点的输出 */
    private Double out;
    
    /** 点的类别 */
    private Integer classify = 0;


    public Integer getClassify() {
        return classify;
    }


    public void setClassify(Integer classify) {
        this.classify = classify;
    }


    public Double getIn() {
        return in;
    }


    public void setIn(Double in) {
        this.in = in;
    }


    public Double getOut() {
        return out;
    }


    public void setOut(Double out) {
        this.out = out;
    }


    public Integer getName() {
        return name;
    }


    public void setName(Integer name) {
        this.name = name;
    }


    public Double getO() {
        return o;
    }


    public void setO(Double o) {
        this.o = o;
    }


    public Double getErr() {
        return err;
    }


    public void setErr(Double err) {
        this.err = err;
    }


}

2、Edge.java

package com.network;


/**
 * <p>本类描述: </p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-29 上午09:11:42
 */
public class Edge {


    /** 边的起点 */
    private Point start;
    
    /** 边的终点 */
    private Point end;
    
    /** 边的权重 */
    private Double weight;
    
    public Edge() {}


    public Edge(Point start, Point end, Double weight) {
        this.start = start;
        this.end = end;
        this.weight = weight;
    }


    public Point getStart() {
        return start;
    }


    public void setStart(Point start) {
        this.start = start;
    }


    public Point getEnd() {
        return end;
    }


    public void setEnd(Point end) {
        this.end = end;
    }


    public Double getWeight() {
        return weight;
    }


    public void setWeight(Double weight) {
        this.weight = weight;
    }
}

3、NeuralNetwork.java

package com.network;


import java.util.ArrayList;
import java.util.List;


/**
 * <p>
 *      本类描述: 
 *          利用向后传播的神经网络方法学习,产生可预测类别的模型,本类假定隐藏层数为1(两层神经网络)
 *          隐藏层包含的单元可以指定,输出层的单元也可以指定
 *  </p>
 *  <p>
 *      主要步骤: 
 *          步骤1: 初始化网络中的权重和偏倚
 *          步骤2: 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
 *          步骤3: 逐层向后计算输出层和隐藏层的每个单元的误差
 *          步骤4: 更新所有权重和偏倚
 *  </p>
 *  <p>
 *      其他说明:对未知元组X分类
 *          利用训练好的模型,计算每个单元的净输入和输出,如果每个类有一个输出节点,则具有最高输出值的
 *          节点决定X的预测类标号,如果只有一个输出节点,则输出值大于或等于0.5可以视为正类,而值小于0.5
 *          可以视为负类。
 *  </p>
 * @author Wang Haiyang
 * @date 2015-6-26 下午04:10:10
 */
public class NeuralNetwork {
    
    /** 学习率 */
    public static final Double study = 0.9D;
    
    /** 样本集 */
    public static List<ArrayList<Point>> samples = new ArrayList<ArrayList<Point>>();
    
    /** 隐藏层点集 */
    public static List<Point> hideLayers = new ArrayList<Point>();
    
    /** 输出层点集 */
    public static List<Point> outLayers = new ArrayList<Point>();
    
    /** 边集 */
    public static List<Edge> edges = new ArrayList<Edge>();


    public static void main(String[] args) {
       
        // 准备初始化参数
        init();
        
        // 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
        compute();
        
        // 打印
        display();
    }


    /**
     * 方法描述:打印
     */
    private static void display() {
        System.out.println("权重:");
        for (int i = 0; i < edges.size(); i++) {
            Edge edge = edges.get(i);
            System.out.println("w" + edge.getStart().getName() + edge.getEnd().getName() + ": " + edge.getWeight());
        }
        System.out.println("隐藏层偏倚:");
        for (int i = 0; i < hideLayers.size(); i++) {
            Point point = hideLayers.get(i);
            System.out.println("O" + point.getName() + ": " + point.getO());
        }
        System.out.println("输出层偏倚:");
        for (int i = 0; i < outLayers.size(); i++) {
            Point point = outLayers.get(i);
            System.out.println("O" + point.getName() + ": " + point.getO());
        }
    }


    /**
     * 方法描述:训练模型
     */
    private static void compute() {
        for (ArrayList<Point> points : samples) {
           
            // 计算输入层每个单元的输出
            for (Point point1 : points) {
                point1.setOut(point1.getIn());
            }
            
            // 计算隐藏层的每个单元的净输入和输出
            getInOut(hideLayers, points);
            
            // 计算输出层的每个单元的净输入和输出
            getInOut(outLayers, points);
            
            // 计算输出层的误差
            for (Point point2 : outLayers) {
                Double out = point2.getOut();
                Double err = out * (1 - out) * (point2.getClassify() - out);
                point2.setErr(err);
            }
            
            // 计算隐藏层的误差
            for (Point hide : hideLayers) {
                Double sum = 0D;
                for (Point out : outLayers) {
                    sum += out.getErr() * (getWeight(hide, out));
                }
                Double out = hide.getOut();
                Double err = out * (1 - out) * sum;
                hide.setErr(err);
            }
            
            // 更新所有权重
            for (Edge edge : edges) {
                Double weight = edge.getWeight() + study * edge.getEnd().getErr() * edge.getStart().getOut();
                edge.setWeight(weight);
            }
            
            // 更新隐藏层偏倚
            updateO(hideLayers);
            
            // 更新输出层偏倚
            updateO(outLayers);
        }
    }


    /**
     * 方法描述:准备初始化参数
     */
    private static void init() {
        ArrayList<Point> inLayers = new ArrayList<Point>();
        Point p1 = new Point(1, 1D);
        inLayers.add(p1);
        Point p2 = new Point(2, 0D);
        inLayers.add(p2);
        Point p3 = new Point(3, 1D);
        inLayers.add(p3);
        samples.add(inLayers);
        
        Point p4 = new Point(4, -0.4D, 0D, 0D, 0D);
        hideLayers.add(p4);
        Point p5 = new Point(5, 0.2D, 0D, 0D, 0D);
        hideLayers.add(p5);
        
        Point p6 = new Point(6, 0.1D, 0D, 0D, 0D, 1);
        outLayers.add(p6);
        
        Edge edge1 = new Edge(p1, p4, 0.2D);
        Edge edge2 = new Edge(p1, p5, -0.3D);
        Edge edge3 = new Edge(p2, p4, 0.4D);
        Edge edge4 = new Edge(p2, p5, 0.1D);
        Edge edge5 = new Edge(p3, p4, -0.5D);
        Edge edge6 = new Edge(p3, p5, 0.2D);
        Edge edge7 = new Edge(p4, p6, -0.3D);
        Edge edge8 = new Edge(p5, p6, -0.2D);
        edges.add(edge1);
        edges.add(edge2);
        edges.add(edge3);
        edges.add(edge4);
        edges.add(edge5);
        edges.add(edge6);
        edges.add(edge7);
        edges.add(edge8);
    }


    /**
     * 方法描述:计算给定list的净输入和输出
     * @param layers
     * @param edges
     * @param points
     */
    private static void updateO(List<Point> layers) {
        for (Point hide : layers) {
            Double o = hide.getO() + study * hide.getErr();
            hide.setO(o);
        }
    }
    
    /**
     * 方法描述:计算给定list的净输入和输出
     * @param layers
     * @param edges
     * @param points
     */
    private static void getInOut(List<Point> layers, ArrayList<Point> points) {
        for (int i = 0; i< layers.size(); i++) {
            Point hide = layers.get(i);
            Double in = 0D;
            Double out = 0D;
            Double sum = 0D;
            for (Point point3 : points) {
                sum += getWeight(point3, hide) * point3.getOut();
            }
            in = sum + hide.getO();
            hide.setIn(in);
            out = 1.0 / (1 + Math.pow(Math.E, (-in)));
            hide.setOut(out);
        }
    }


    /**
     * 方法描述:根据给定的两个点得到这条边的权重
     * @param point3
     * @param hide
     * @return
     */
    private static Double getWeight(Point point3, Point hide) {
        Double weight = 0D;
        for (Edge edge : edges) {
            if (point3.getName() == edge.getStart().getName() && hide.getName() == edge.getEnd().getName()) {
                weight = edge.getWeight();
                break;
            }
        }
        return weight;
    }
}

0 0