优化算法——梯度下降法

来源:互联网 发布:数据流量 英文 编辑:程序博客网 时间:2024/06/06 17:33

一、优化算法概述

    优化算法所要求解的是一个问题的最优解或者近似最优解。现实生活中有很多的最优化问题,如最短路径问题,如组合优化问题等等,同样,也存在很多求解这些优化问题的方法和思路,如梯度下降方法。

    机器学习在近年来得到了迅速的发展,越来越多的机器学习算法被提出,同样越来越多的问题利用机器学习算法得到解决。优化算法是机器学习算法中使用到的一种求解方法。在机器学习,我们需要寻找输入特征与标签之间的映射关系,在寻找这样的映射关系时,有一条重要的原则就是使得寻找到的映射结果与原始标签之间的误差最小。机器学习问题归纳起来就是把一个学习的问题转化为优化的问题,机器学习算法的本质就是如何对问题抽象建模,使一个学习的问题变为一个可求解的优化问题。

    优化的算法有很多种,从最基本的梯度下降法到现在的一些启发式算法,如遗传算法(GA),差分演化算法(DE),粒子群算法(PSO)和人工蜂群算法(ABC)。

二、梯度下降法

1、基本概念

    梯度下降法又被称为最速下降法(Steepest descend method),其理论基础是梯度的概念。梯度与方向导数的关系为:梯度的方向与取得最大方向导数值的方向一致,而梯度的模就是函数在该点的方向导数的最大值。对于一个无约束的优化问题: ,例如

如图,在处的切线。显然在处函数取得最小值。沿着梯度的方向是下降速度最快的方向。具体的过程为:初始时,任取的值,如取,则对应的。利用梯度下降法,其中为学习率,可以取固定常数。如取,则,对应的,类似的,对应的。算法终止的判断准则是:,其中是一个指定的阈值。梯度的更新公式为:

2、算法流程

梯度下降法的流程:

1、初始化:随机选取取值范围内的任意数

2、循环操作:

       计算梯度;

       修改新的变量;

       判断是否达到终止:如果前后两次的函数值差的绝对值小于阈值,则跳出循环;否则继续;

3、输出最终结果

    与梯度下降法对应的是被称为梯度上升的算法,主要的区别就是在梯度的方向上,一个方向是下降最快的方向,相反的就是梯度上升最快的方法。主要用来求解最大值问题:。梯度的更新公式为:

下面以为例,给出一下的Java程序:

[java] view plain copy
  1. public class SteepestDescend {  
  2.     public static double alpha = 0.5;// 迭代步长  
  3.     public static double e = 0.00001;// 收敛精度  
  4.   
  5.     public double x0;  
  6.     public double y0;  
  7.   
  8.     public double getY(double x) {  
  9.         return (x * x - 3 * x + 2);  
  10.     }  
  11.   
  12.     public double getDerivative(double x) {  
  13.         return (2 * x - 3);  
  14.     }  
  15.   
  16.     public void init() {  
  17.         x0 = 0;  
  18.         y0 = this.getY(x0);  
  19.     }  
  20.   
  21.     public double getSteepestDescend() {  
  22.         double min = 0;  
  23.         double x = x0;  
  24.         double y = y0;  
  25.         double y1;  
  26.         double temp = 0;  
  27.         /* 
  28.          * 做梯度运算 
  29.          */  
  30.         while (true) {  
  31.             temp = this.getDerivative(x);  
  32.             x = x - alpha * temp;  
  33.             y1 = this.getY(x);  
  34.             if (Math.abs(y1 - y) <= e) {  
  35.                 break;  
  36.             }  
  37.             y = y1;  
  38.             min = y;  
  39.         }  
  40.         return min;  
  41.     }  
  42. }  

主函数:

[java] view plain copy
  1. public class TestMain {  
  2.     public static void main(String args[]) {  
  3.         double min;  
  4.         SteepestDescend sd = new SteepestDescend();  
  5.         sd.init();  
  6.         min = sd.getSteepestDescend();  
  7.         System.out.println("最小值:"+ min );  
  8.     }  
  9.   
  10. }  
  11. 结果为 -0.25
原创粉丝点击