优化算法--牛顿法

来源:互联网 发布:约瑟夫环c语言循环链式 编辑:程序博客网 时间:2024/05/17 17:38

转载地址:http://blog.csdn.net/google19890102/article/details/41087931

一、牛顿法概述

    除了前面说的梯度下降法,牛顿法也是机器学习中用的比较多的一种优化算法。牛顿法的基本思想是利用迭代点处的一阶导数(梯度)和二阶导数(Hessen矩阵)对目标函数进行二次函数近似,然后把二次模型的极小点作为新的迭代点,并不断重复这一过程,直至求得满足精度的近似极小值。牛顿法的速度相当快,而且能高度逼近最优值。牛顿法分为基本的牛顿法和全局牛顿法。

二、基本牛顿法

1、基本牛顿法的原理

    基本牛顿法是一种是用导数的算法,它每一步的迭代方向都是沿着当前点函数值下降的方向。
    我们主要集中讨论在一维的情形,对于一个需要求解的优化函数,求函数的极值的问题可以转化为求导函数。对函数进行泰勒展开到二阶,得到

对上式求导并令其为0,则为

即得到

这就是牛顿法的更新公式。

2、基本牛顿法的流程

  1. 给定终止误差值,初始点,令
  2. 计算,若,则停止,输出
  3. 计算,并求解线性方程组得解
  4. ,并转2。

三、全局牛顿法

    牛顿法最突出的优点是收敛速度快,具有局部二阶收敛性,但是,基本牛顿法初始点需要足够“靠近”极小点,否则,有可能导致算法不收敛。这样就引入了全局牛顿法。

1、全局牛顿法的流程

  1. 给定终止误差值,初始点,令
  2. 计算,若,则停止,输出
  3. 计算,并求解线性方程组得解
  4. 是不满足下列不等式的最小非负整数
  5. ,并转2。

2、Armijo搜索

    全局牛顿法是基于Armijo的搜索,满足Armijo准则:
给定,令步长因子,其中是满足下列不等式的最小非负整数:

四、算法实现

    实验部分使用Java实现,需要优化的函数,最小值为

1、基本牛顿法Java实现

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. package org.algorithm.newtonmethod;  
  2.   
  3. /** 
  4.  * Newton法 
  5.  *  
  6.  * @author dell 
  7.  *  
  8.  */  
  9. public class NewtonMethod {  
  10.     private double originalX;// 初始点  
  11.     private double e;// 误差阈值  
  12.     private double maxCycle;// 最大循环次数  
  13.   
  14.     /** 
  15.      * 构造方法 
  16.      *  
  17.      * @param originalX初始值 
  18.      * @param e误差阈值 
  19.      * @param maxCycle最大循环次数 
  20.      */  
  21.     public NewtonMethod(double originalX, double e, double maxCycle) {  
  22.         this.setOriginalX(originalX);  
  23.         this.setE(e);  
  24.         this.setMaxCycle(maxCycle);  
  25.     }  
  26.   
  27.     // 一系列get和set方法  
  28.     public double getOriginalX() {  
  29.         return originalX;  
  30.     }  
  31.   
  32.     public void setOriginalX(double originalX) {  
  33.         this.originalX = originalX;  
  34.     }  
  35.   
  36.     public double getE() {  
  37.         return e;  
  38.     }  
  39.   
  40.     public void setE(double e) {  
  41.         this.e = e;  
  42.     }  
  43.   
  44.     public double getMaxCycle() {  
  45.         return maxCycle;  
  46.     }  
  47.   
  48.     public void setMaxCycle(double maxCycle) {  
  49.         this.maxCycle = maxCycle;  
  50.     }  
  51.   
  52.     /** 
  53.      * 原始函数 
  54.      *  
  55.      * @param x变量 
  56.      * @return 原始函数的值 
  57.      */  
  58.     public double getOriginal(double x) {  
  59.         return x * x - 3 * x + 2;  
  60.     }  
  61.   
  62.     /** 
  63.      * 一次导函数 
  64.      *  
  65.      * @param x变量 
  66.      * @return 一次导函数的值 
  67.      */  
  68.     public double getOneDerivative(double x) {  
  69.         return 2 * x - 3;  
  70.     }  
  71.   
  72.     /** 
  73.      * 二次导函数 
  74.      *  
  75.      * @param x变量 
  76.      * @return 二次导函数的值 
  77.      */  
  78.     public double getTwoDerivative(double x) {  
  79.         return 2;  
  80.     }  
  81.   
  82.     /** 
  83.      * 利用牛顿法求解 
  84.      *  
  85.      * @return 
  86.      */  
  87.     public double getNewtonMin() {  
  88.         double x = this.getOriginalX();  
  89.         double y = 0;  
  90.         double k = 1;  
  91.         // 更新公式  
  92.         while (k <= this.getMaxCycle()) {  
  93.             y = this.getOriginal(x);  
  94.             double one = this.getOneDerivative(x);  
  95.             if (Math.abs(one) <= e) {  
  96.                 break;  
  97.             }  
  98.             double two = this.getTwoDerivative(x);  
  99.             x = x - one / two;  
  100.             k++;  
  101.         }  
  102.         return y;  
  103.     }  
  104.   
  105. }  

2、全局牛顿法Java实现

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. package org.algorithm.newtonmethod;  
  2.   
  3. /** 
  4.  * 全局牛顿法 
  5.  *  
  6.  * @author dell 
  7.  *  
  8.  */  
  9. public class GlobalNewtonMethod {  
  10.     private double originalX;  
  11.     private double delta;  
  12.     private double sigma;  
  13.     private double e;  
  14.     private double maxCycle;  
  15.   
  16.     public GlobalNewtonMethod(double originalX, double delta, double sigma,  
  17.             double e, double maxCycle) {  
  18.         this.setOriginalX(originalX);  
  19.         this.setDelta(delta);  
  20.         this.setSigma(sigma);  
  21.         this.setE(e);  
  22.         this.setMaxCycle(maxCycle);  
  23.     }  
  24.   
  25.     public double getOriginalX() {  
  26.         return originalX;  
  27.     }  
  28.   
  29.     public void setOriginalX(double originalX) {  
  30.         this.originalX = originalX;  
  31.     }  
  32.   
  33.     public double getDelta() {  
  34.         return delta;  
  35.     }  
  36.   
  37.     public void setDelta(double delta) {  
  38.         this.delta = delta;  
  39.     }  
  40.   
  41.     public double getSigma() {  
  42.         return sigma;  
  43.     }  
  44.   
  45.     public void setSigma(double sigma) {  
  46.         this.sigma = sigma;  
  47.     }  
  48.   
  49.     public double getE() {  
  50.         return e;  
  51.     }  
  52.   
  53.     public void setE(double e) {  
  54.         this.e = e;  
  55.     }  
  56.   
  57.     public double getMaxCycle() {  
  58.         return maxCycle;  
  59.     }  
  60.   
  61.     public void setMaxCycle(double maxCycle) {  
  62.         this.maxCycle = maxCycle;  
  63.     }  
  64.   
  65.     /** 
  66.      * 原始函数 
  67.      *  
  68.      * @param x变量 
  69.      * @return 原始函数的值 
  70.      */  
  71.     public double getOriginal(double x) {  
  72.         return x * x - 3 * x + 2;  
  73.     }  
  74.   
  75.     /** 
  76.      * 一次导函数 
  77.      *  
  78.      * @param x变量 
  79.      * @return 一次导函数的值 
  80.      */  
  81.     public double getOneDerivative(double x) {  
  82.         return 2 * x - 3;  
  83.     }  
  84.   
  85.     /** 
  86.      * 二次导函数 
  87.      *  
  88.      * @param x变量 
  89.      * @return 二次导函数的值 
  90.      */  
  91.     public double getTwoDerivative(double x) {  
  92.         return 2;  
  93.     }  
  94.   
  95.     /** 
  96.      * 利用牛顿法求解 
  97.      *  
  98.      * @return 
  99.      */  
  100.     public double getGlobalNewtonMin() {  
  101.         double x = this.getOriginalX();  
  102.         double y = 0;  
  103.         double k = 1;  
  104.         // 更新公式  
  105.         while (k <= this.getMaxCycle()) {  
  106.             y = this.getOriginal(x);  
  107.             double one = this.getOneDerivative(x);  
  108.             if (Math.abs(one) <= e) {  
  109.                 break;  
  110.             }  
  111.             double two = this.getTwoDerivative(x);  
  112.             double dk = -one / two;// 搜索的方向  
  113.             double m = 0;  
  114.             double mk = 0;  
  115.             while (m < 20) {  
  116.                 double left = this.getOriginal(x + Math.pow(this.getDelta(), m)  
  117.                         * dk);  
  118.                 double right = this.getOriginal(x) + this.getSigma()  
  119.                         * Math.pow(this.getDelta(), m)  
  120.                         * this.getOneDerivative(x) * dk;  
  121.                 if (left <= right) {  
  122.                     mk = m;  
  123.                     break;  
  124.                 }  
  125.                 m++;  
  126.             }  
  127.             x = x + Math.pow(this.getDelta(), mk)*dk;  
  128.             k++;  
  129.         }  
  130.         return y;  
  131.     }  
  132. }  

3、主函数

[java] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. package org.algorithm.newtonmethod;  
  2.   
  3. /** 
  4.  * 测试函数 
  5.  * @author dell 
  6.  * 
  7.  */  
  8. public class TestNewton {  
  9.     public static void main(String args[]) {  
  10.         NewtonMethod newton = new NewtonMethod(00.00001100);  
  11.         System.out.println("基本牛顿法求解:" + newton.getNewtonMin());  
  12.   
  13.         GlobalNewtonMethod gNewton = new GlobalNewtonMethod(00.550.4,  
  14.                 0.00001100);  
  15.         System.out.println("全局牛顿法求解:" + gNewton.getGlobalNewtonMin());  
  16.     }  
  17. }  
0 0
原创粉丝点击