机器学习之梯度下降法

来源:互联网 发布:人工智能 风口 编辑:程序博客网 时间:2024/05/16 05:51

方向导数

如图,对于函数f(x,y),函数的增量与pp’两点距离之比在p’沿l趋于p时,则为函数在点p沿l方向的方向导数。记为fl=limρ0f(x+Δx,y+Δy)f(x,y)ρ,其中ρ=(Δx)2+(Δy)2。方向导数为函数f沿某方向的变化速率。

这里写图片描述

而且有如下定理:
fl=fxcosΘ+fysinΘ

梯度

梯度是一个向量,它的方向与取得最大方向导数的方向一致,梯度的模为方向导数的最大值。某点的梯度记为
gradf(x,y)=fxi+fyj

梯度的方向就是函数f在此点增长最快的方向,梯度的模为方向导数的最大值。

梯度下降

同样还是在线性回归中,假设函数为
h(x)=θ0+θ1x
那么损失函数为
J(θ)=12ni=1(h(xi)yi)2
要求最小损失,分别对θ0θ1求偏导,
J(θ)θj=θj12ni=1(h(xi)yi)2
=ni=1(h(xi)yi)θj(nj=0(θjxj)iyi)
=ni=1(h(xi)yi)xij
那么不断通过下面方式更新θ即可以逼近最低点。
θj:=θjαni=1(h(xi)yi)xij

其中α为learning rate,表现为下降的步伐。它不能太大也不能太小,太大会overshoot,太小则下降慢。通常可以尝试0.001、0.003、0.01、0.03、0.1、0.3。

这就好比站在一座山的某个位置上,往周围各个方向跨出相同步幅的一步,能够最快下降的方向就是梯度。这个方向是梯度的反方向。
这里写图片描述

这里写图片描述

这里写图片描述

另外,初始点的不同可能会出现局部最优解的情况,如下图:
这里写图片描述

伪代码

repeat until convergence{

θj:=θjαni=1(h(xi)yi)xij

for every j

}

随机梯度下降

样本太大时,每次更新都需要遍历整个样本,效率较低,这是就引入了随机梯度下降。

它可以每次只用一个样本来更新,免去了遍历整个样本。

伪代码如下

repeat until convergence{

i=random(1,n)

θj:=θjα(h(xi)yi)xij

for every j

}

另外与随机梯度下降类似的还有小批量梯度下降,它是折中的方式,取了所有样本中的一小部分。

代码实现

import numpy as npimport matplotlib.pyplot as pltlearning_rate = 0.0005theta = [0.7, 0.8, 0.9]loss = 100times = 100ite = 0expectation = 0.0001x_train = [[1, 2], [2, 1], [2, 3], [3, 5], [1, 3], [4, 2], [7, 3], [4, 5], [11, 3], [8, 7]]y_train = [7, 8, 10, 14, 8, 13, 20, 16, 28, 26]loss_array = np.zeros(times)def h(x):    return theta[0]*x[0]+theta[1]*x[1]+theta[2]while loss > expectation and ite < times:    loss = 0    sum_theta0 = 0    sum_theta1 = 0    sum_theta2 = 0    for x, y in zip(x_train, y_train):        sum_theta0 += (h(x) - y) * x[0]        sum_theta1 += (h(x) - y) * x[1]        sum_theta2 += (h(x) - y)    theta[0] -= learning_rate * sum_theta0    theta[1] -= learning_rate * sum_theta1    theta[2] -= learning_rate * sum_theta2    loss = 0    for x, y in zip(x_train, y_train):        loss += pow((h(x) - y), 2)    loss_array[ite] = loss    ite += 1plt.plot(loss_array)plt.show()

这里写图片描述

========广告时间========

鄙人的新书《Tomcat内核设计剖析》已经在京东销售了,有需要的朋友可以到 https://item.jd.com/12185360.html 进行预定。感谢各位朋友。

为什么写《Tomcat内核设计剖析》

=========================

欢迎关注:
这里写图片描述

2 0
原创粉丝点击