[JAVA] Batch Gradient Descent

/**
 * Batch gradient descent for a two-feature linear model.
 */
public class BatchGradient {

    public void batchGradientDescent() {
        double[][] inputDataMatrix = { { 1, 4 }, { 2, 5 }, { 5, 1 }, { 4, 2 } }; // input X: 4 samples, 2 features
        double[] expectResult = { 19, 26, 19, 20 }; // expected outputs y
        double[] w = { 2, 6 };       // weights, one per input column of X
        double learningRate = 0.01;
        double loss = 100;           // current squared-error loss

        for (int i = 0; i < 100 && loss > 0.0001; i++) {
            // accumulate the gradient over the whole batch before touching w
            double[] gradient = new double[2];
            for (int j = 0; j < 4; j++) {
                double h = 0;
                for (int k = 0; k < 2; k++) {
                    h += inputDataMatrix[j][k] * w[k]; // prediction h = x_j . w
                }
                double err = expectResult[j] - h;      // residual for sample j
                for (int k = 0; k < 2; k++) {
                    gradient[k] += err * inputDataMatrix[j][k];
                }
            }
            // one weight update per pass over all samples; the step size follows
            // from the gradient of the squared-error cost
            for (int k = 0; k < 2; k++) {
                w[k] += learningRate * gradient[k];
            }
            System.out.println("current weights: w0=" + w[0] + ", w1=" + w[1]);

            // recompute the squared-error loss over the whole data set
            double lossSum = 0;
            for (int j = 0; j < 4; j++) {
                double sum = 0;
                for (int k = 0; k < 2; k++) {
                    sum += inputDataMatrix[j][k] * w[k];
                }
                lossSum += (expectResult[j] - sum) * (expectResult[j] - sum);
            }
            loss = lossSum; // update loss so the stopping condition can take effect
            System.out.println("loss:" + loss);
        }
    }

    public static void main(String[] args) {
        BatchGradient bg = new BatchGradient();
        bg.batchGradientDescent();
    }
}
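A minimal sketch of where the update formula inside the loop comes from, using the code's own quantities (y_j for expectResult[j], x_{j,k} for inputDataMatrix[j][k], and \alpha for learningRate): minimizing the squared-error cost

J(w) = \frac{1}{2} \sum_{j=1}^{4} \Big( y_j - \sum_{k} w_k x_{j,k} \Big)^2

by gradient descent means stepping each weight against its partial derivative:

\frac{\partial J}{\partial w_k} = -\sum_{j=1}^{4} \Big( y_j - \sum_{m} w_m x_{j,m} \Big) x_{j,k}, \qquad
w_k \leftarrow w_k - \alpha \frac{\partial J}{\partial w_k} = w_k + \alpha \sum_{j=1}^{4} \Big( y_j - \sum_{m} w_m x_{j,m} \Big) x_{j,k}

which is exactly the accumulated-error update applied to w[k] in the code. As a sanity check, the four training samples all satisfy y = 3*x1 + 4*x2 exactly, so with this learning rate the weights should move from the initial (2, 6) toward (3, 4) and the printed loss should shrink toward zero.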

Refer to:
http://blog.csdn.net/abcjennifer/article/details/7716281
http://www.xatarena.cn/javajswz/20130402/1313.html
