Coordinate Descent (Coordinate Ascent) in MATLAB

Source: Internet. Editor: 程序博客网. Date: 2024/05/21 11:36

Motivation

SMO, a widely used algorithm for training SVMs, follows the idea of coordinate descent, so it is worth studying that method first.

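Method

When a function of several variables has to be maximized or minimized, coordinate descent (for a minimum) or coordinate ascent (for a maximum) can be applied.

The procedure: take the partial derivative with respect to each variable, then apply a gradient descent (or ascent) step to one variable at a time, cycling through the variables in turn.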

Example

We use the following function as the example:

$$z = f(x, y) = x\,e^{-(x^2 + y^2)}$$

Its surface plot:
[Figure: surface plot of z = f(x, y)]

As the plot shows, the function has one maximum and one minimum.

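Take the partial derivatives with respect to x and y.

$$\frac{\partial z}{\partial x} = e^{-(x^2+y^2)} + x\,e^{-(x^2+y^2)}(-2x) = e^{-(x^2+y^2)}\,(1 - 2x^2)$$

$$\frac{\partial z}{\partial y} = x\,e^{-(x^2+y^2)}(-2y)$$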
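The two formulas above can be sanity-checked numerically. The snippet below is only a sketch: the test point (0.2, 0.7), the step h, and the names fx, fy, fx_num, fy_num are assumptions chosen for illustration, and it compares the analytic partial derivatives against central finite differences.

```matlab
% Finite-difference check of the analytic partial derivatives (illustrative sketch).
f  = @(x, y) x .* exp(-x.^2 - y.^2);               % the example function
fx = @(x, y) exp(-x.^2 - y.^2) .* (1 - 2*x.^2);    % analytic dz/dx
fy = @(x, y) -2 .* x .* y .* exp(-x.^2 - y.^2);    % analytic dz/dy

x0 = 0.2; y0 = 0.7;   % arbitrary test point (assumption)
h  = 1e-6;            % finite-difference step (assumption)

% Central differences approximate each partial derivative
fx_num = (f(x0 + h, y0) - f(x0 - h, y0)) / (2*h);
fy_num = (f(x0, y0 + h) - f(x0, y0 - h)) / (2*h);

err_x = abs(fx(x0, y0) - fx_num);
err_y = abs(fy(x0, y0) - fy_num);
fprintf('dz/dx: analytic = %.6f, numeric = %.6f\n', fx(x0, y0), fx_num);
fprintf('dz/dy: analytic = %.6f, numeric = %.6f\n', fy(x0, y0), fy_num);
```

If the derivation is right, both errors should be tiny compared with the derivative values themselves.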

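This problem is simple enough to solve directly: setting both derivatives to zero gives $x = \pm\frac{\sqrt{2}}{2}$, $y = 0$. The maximum is at $x = +\frac{\sqrt{2}}{2}$ and the minimum at $x = -\frac{\sqrt{2}}{2}$.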
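For completeness, the stationary points can be worked out step by step. This derivation is an added sketch based only on the function and partial derivatives defined in this post:

```latex
\begin{aligned}
\frac{\partial z}{\partial y} = -2xy\,e^{-(x^2+y^2)} = 0 &\;\Rightarrow\; x = 0 \ \text{or}\ y = 0,\\
x = 0 &\;\Rightarrow\; \frac{\partial z}{\partial x} = e^{-y^2} \neq 0 \quad\text{(rejected, so } y = 0\text{)},\\
y = 0 &\;\Rightarrow\; \frac{\partial z}{\partial x} = e^{-x^2}(1 - 2x^2) = 0 \;\Rightarrow\; x = \pm\frac{\sqrt{2}}{2},\\
f\!\left(\pm\tfrac{\sqrt{2}}{2},\, 0\right) &= \pm\tfrac{\sqrt{2}}{2}\,e^{-1/2} \approx \pm 0.4289.
\end{aligned}
```

The positive value is the maximum and the negative value is the minimum.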

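Many practical problems, however, cannot be solved in closed form, and an iterative algorithm is needed instead.
Here that means applying a gradient step to each parameter in turn.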

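Solution procedure

1 Initialize x and y with random values.

2 Apply a gradient ascent step to x (we are looking for the maximum), where a is the learning rate:

$$x = x + a\,\frac{\partial z}{\partial x}$$

3 Then apply a gradient ascent step to y:
$$y = y + a\,\frac{\partial z}{\partial y}$$

4 Repeat steps 2-3 until convergence.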
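The steps above can be sketched as a strictly alternating update, in which the y step uses the x value that was just updated. This is an illustrative variant, not the post's own listing: the names fx, fy, tol, and maxIter are assumptions, while the starting point, learning rate, and stopping threshold are borrowed from the post's script.

```matlab
% Alternating coordinate ascent on z = x*exp(-(x^2+y^2)) (illustrative sketch).
f  = @(x, y) x .* exp(-x.^2 - y.^2);
fx = @(x, y) exp(-x.^2 - y.^2) .* (1 - 2*x.^2);    % dz/dx
fy = @(x, y) -2 .* x .* y .* exp(-x.^2 - y.^2);    % dz/dy

x = 0.2; y = 0.7;     % initial values (same as the post's script)
a = 0.2;              % learning rate
tol = 1e-7;           % stop when the total change falls below this
maxIter = 10000;      % safety cap on iterations (assumption)

for k = 1:maxIter
    oldx = x; oldy = y;
    x = x + a * fx(x, y);   % step 2: move along x only
    y = y + a * fy(x, y);   % step 3: move along y, using the new x
    if abs(oldx - x) + abs(oldy - y) <= tol
        break;              % step 4: converged
    end
end
fprintf('converged after %d iterations: x = %f, y = %f, z = %f\n', k, x, y, f(x, y));
```

The only difference from a simultaneous gradient step is where the y derivative is evaluated: at the updated x rather than at the previous point.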

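Code

% ---- f.m: the objective function (save as its own file) ----
function z = f(x,y)
    z = x.*exp(-x.^2 - y.^2);
end

% ---- main script (separate file) ----
clc; clear;
x = 0.2; y = 0.7;            % initial values
a = 0.2;                     % learning rate
xa = []; ya = []; za = [];   % iteration history, used for plotting
oldx = 1; oldy = 1;
while abs(oldx-x) + abs(oldy-y) > 1e-7
    z = f(x,y);
    % Note: both partial derivatives are evaluated at the current (x, y)
    dx = exp(-x.^2 - y.^2) - 2*x*z;   % dz/dx
    dy = z.*(-2*y);                   % dz/dy
    oldx = x;
    oldy = y;
    x = x + a*dx;
    y = y + a*dy;
    xa = [xa x];
    ya = [ya y];
    za = [za f(x,y)];
    % cha = total change of x and y in this iteration
    fprintf('x = %f ,y = %f , cha = %f\n', x, y, abs(oldx-x)+abs(oldy-y));
end

% Plot the surface, then overlay the iteration path
fw = -2:0.1:2;
[x,y] = meshgrid(fw,fw);
z = f(x,y);
hold off;
mesh(x,y,z);
xlabel('x'); ylabel('y'); zlabel('z');
pause      % wait for a key press before drawing the path
hold on;
plot3(xa,ya,za,'LineWidth',2);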

Results

MATLAB output

x = 0.308303 ,y = 0.667038 , cha = 0.141265
x = 0.402698 ,y = 0.619101 , cha = 0.142332
x = 0.481018 ,y = 0.561303 , cha = 0.136119
x = 0.543232 ,y = 0.498771 , cha = 0.124746
x = 0.590809 ,y = 0.435857 , cha = 0.110491
x = 0.626028 ,y = 0.375773 , cha = 0.095303
x = 0.651398 ,y = 0.320559 , cha = 0.080583
x = 0.669268 ,y = 0.271252 , cha = 0.067178
x = 0.681635 ,y = 0.228145 , cha = 0.055474
x = 0.690075 ,y = 0.191040 , cha = 0.045545
x = 0.695776 ,y = 0.159460 , cha = 0.037281
x = 0.699596 ,y = 0.132798 , cha = 0.030482
x = 0.702141 ,y = 0.110417 , cha = 0.024926
x = 0.703830 ,y = 0.091705 , cha = 0.020401
x = 0.704947 ,y = 0.076105 , cha = 0.016718
x = 0.705685 ,y = 0.063124 , cha = 0.013718
x = 0.706171 ,y = 0.052338 , cha = 0.011272
x = 0.706492 ,y = 0.043384 , cha = 0.009274
x = 0.706702 ,y = 0.035955 , cha = 0.007639
x = 0.706841 ,y = 0.029795 , cha = 0.006299
x = 0.706932 ,y = 0.024688 , cha = 0.005198
x = 0.706992 ,y = 0.020455 , cha = 0.004293
x = 0.707031 ,y = 0.016948 , cha = 0.003547
x = 0.707057 ,y = 0.014041 , cha = 0.002932
x = 0.707074 ,y = 0.011633 , cha = 0.002425
x = 0.707085 ,y = 0.009637 , cha = 0.002007
x = 0.707093 ,y = 0.007984 , cha = 0.001661
x = 0.707098 ,y = 0.006615 , cha = 0.001374
x = 0.707101 ,y = 0.005480 , cha = 0.001138
x = 0.707103 ,y = 0.004540 , cha = 0.000942
x = 0.707104 ,y = 0.003761 , cha = 0.000780
x = 0.707105 ,y = 0.003116 , cha = 0.000646
x = 0.707106 ,y = 0.002581 , cha = 0.000535
x = 0.707106 ,y = 0.002138 , cha = 0.000443
x = 0.707106 ,y = 0.001772 , cha = 0.000367
x = 0.707106 ,y = 0.001468 , cha = 0.000304
x = 0.707107 ,y = 0.001216 , cha = 0.000252
x = 0.707107 ,y = 0.001007 , cha = 0.000209
x = 0.707107 ,y = 0.000835 , cha = 0.000173
x = 0.707107 ,y = 0.000691 , cha = 0.000143
x = 0.707107 ,y = 0.000573 , cha = 0.000119
x = 0.707107 ,y = 0.000474 , cha = 0.000098
x = 0.707107 ,y = 0.000393 , cha = 0.000081
x = 0.707107 ,y = 0.000326 , cha = 0.000067
x = 0.707107 ,y = 0.000270 , cha = 0.000056
x = 0.707107 ,y = 0.000224 , cha = 0.000046
x = 0.707107 ,y = 0.000185 , cha = 0.000038
x = 0.707107 ,y = 0.000153 , cha = 0.000032
x = 0.707107 ,y = 0.000127 , cha = 0.000026
x = 0.707107 ,y = 0.000105 , cha = 0.000022
x = 0.707107 ,y = 0.000087 , cha = 0.000018
x = 0.707107 ,y = 0.000072 , cha = 0.000015
x = 0.707107 ,y = 0.000060 , cha = 0.000012
x = 0.707107 ,y = 0.000050 , cha = 0.000010
x = 0.707107 ,y = 0.000041 , cha = 0.000009
x = 0.707107 ,y = 0.000034 , cha = 0.000007
x = 0.707107 ,y = 0.000028 , cha = 0.000006
x = 0.707107 ,y = 0.000023 , cha = 0.000005
x = 0.707107 ,y = 0.000019 , cha = 0.000004
x = 0.707107 ,y = 0.000016 , cha = 0.000003
x = 0.707107 ,y = 0.000013 , cha = 0.000003
x = 0.707107 ,y = 0.000011 , cha = 0.000002
x = 0.707107 ,y = 0.000009 , cha = 0.000002
x = 0.707107 ,y = 0.000008 , cha = 0.000002
x = 0.707107 ,y = 0.000006 , cha = 0.000001
x = 0.707107 ,y = 0.000005 , cha = 0.000001
x = 0.707107 ,y = 0.000004 , cha = 0.000001
x = 0.707107 ,y = 0.000004 , cha = 0.000001
x = 0.707107 ,y = 0.000003 , cha = 0.000001
x = 0.707107 ,y = 0.000002 , cha = 0.000001
x = 0.707107 ,y = 0.000002 , cha = 0.000000
x = 0.707107 ,y = 0.000002 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000001 , cha = 0.000000
x = 0.707107 ,y = 0.000000 , cha = 0.000000

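Plot

[Figure: mesh plot of z = f(x, y) with the iteration path overlaid]

The blue line traces how x and y change over the iterations.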