Linear Regression and Gradient Descent


1. Regression

 

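Regression is a common problem in machine learning applications. When we build a model to solve a problem, we generally try to infer the outputs {Y1, Y2, ..., Yn} from a series of input vectors {X1, X2, ..., Xn}. This process of using the mapping between X and Y to compute the value of Y from a given value of the independent variable X is called regression when Y is continuous and classification when Y is discrete. For example, if we take many pieces of information about a person and predict that person's monthly salary, the process can be regarded as regression, because monthly salary can be seen as a set of fairly continuous numbers. If we predict the person's occupation instead, the process is classification, because occupations are discrete categories. This article mainly discusses regression.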

 

2. Linear Regression

 

Linear regression is the most popular and most basic form of regression. Most of us first met this function in junior high school:

$$y = kx + b$$

This is the most basic type of linear function. For example, human height and weight may follow a relationship of this kind. If the height is $H$ cm and the weight is $W$ kg, then:

$$H = kW + b$$

To fit a data set of $(W, H)$ pairs, we need to find a pair $(k, b)$ with which this model maps $W$ to $H$ accurately. This is linear regression.

 

In general, however, no $(k, b)$ satisfies all of $(W_1, H_1), (W_2, H_2), \dots, (W_n, H_n)$ exactly. We have to find the $(k, b)$ that brings the model as close to the data as possible, so that for any $W$ in the data the predicted $H$ is very near the true value. In this example, the difference between the true height $y$ and the $H$ we calculate is the error of the linear regression model, and we want the total error over all the data to be as small as possible. This total is measured by a loss function, typically the sum of squared errors:

$$L(k, b) = \sum_{i=1}^{n} \left( y_i - (k W_i + b) \right)^2$$
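
As a minimal sketch of this idea, the snippet below evaluates a squared-error loss and computes the closed-form least-squares estimate of $(k, b)$ in plain Python. The sample data, the function names `squared_error_loss` and `fit_line`, and the choice of squared error as the loss are assumptions made for illustration; they do not come from the original article.

```python
# Hypothetical (weight in kg, height in cm) samples, made up for illustration.
weights = [50.0, 60.0, 70.0, 80.0]
heights = [160.0, 168.0, 172.0, 180.0]


def squared_error_loss(k, b, ws, hs):
    """Sum of squared differences between true and predicted heights."""
    return sum((h - (k * w + b)) ** 2 for w, h in zip(ws, hs))


def fit_line(ws, hs):
    """Closed-form least-squares estimate of (k, b) for H = k * W + b."""
    n = len(ws)
    mean_w = sum(ws) / n
    mean_h = sum(hs) / n
    # Slope = covariance(W, H) / variance(W); intercept follows from the means.
    cov = sum((w - mean_w) * (h - mean_h) for w, h in zip(ws, hs))
    var = sum((w - mean_w) ** 2 for w in ws)
    k = cov / var
    b = mean_h - k * mean_w
    return k, b


k, b = fit_line(weights, heights)
best_loss = squared_error_loss(k, b, weights, heights)
print(k, b, best_loss)
```

No line passes through all four sample points, so the loss stays above zero even at the fitted $(k, b)$; any other pair gives a larger loss on this data.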

In practice a model usually takes many features into account, so linear regression is generally written as:

$$h_\theta(x) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \dots + \theta_n x_n$$


Because both $\theta$ and $x$ are vectors (with $x_0 = 1$ for the intercept term), we have:

$$h_\theta(x) = \theta^T x$$

Or, using $y$ for the true value:

$$y^{(i)} = \theta^T x^{(i)} + \epsilon^{(i)}$$

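Here $\epsilon^{(i)}$ is the difference between the true value and the predicted value. It is assumed to follow a normal distribution with mean 0 and variance $\sigma^2$, which gives the pdf:

$$p(y^{(i)} \mid x^{(i)}; \theta) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left( -\frac{\left( y^{(i)} - \theta^T x^{(i)} \right)^2}{2\sigma^2} \right)$$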

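Because the errors are iid, the $y^{(i)}$ are independent and their joint pdf is the product of the marginal pdfs. Viewed as a function of $\theta$, this product is the likelihood function:

$$L(\theta) = \prod_{i=1}^{m} p(y^{(i)} \mid x^{(i)}; \theta) = \prod_{i=1}^{m} \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left( -\frac{\left( y^{(i)} - \theta^T x^{(i)} \right)^2}{2\sigma^2} \right)$$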


3. Gradient Descent

 

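Gradient descent is a popular method for fitting model parameters in machine learning. Taking the logarithm of the likelihood function obtained above gives:

$$\ell(\theta) = \log L(\theta) = m \log \frac{1}{\sqrt{2\pi}\,\sigma} - \frac{1}{\sigma^2} \cdot \frac{1}{2} \sum_{i=1}^{m} \left( y^{(i)} - \theta^T x^{(i)} \right)^2$$

Maximizing it can equally be understood as minimizing:

$$J(\theta) = \frac{1}{2} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2$$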
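
The link between the likelihood and the squared-error cost can be checked numerically. The sketch below assumes Gaussian noise with standard deviation `sigma` and a cost $J(\theta)$ equal to half the sum of squared errors; the data and the function names `cost` and `log_likelihood` are hypothetical.

```python
import math


def predict(theta, x):
    """Linear model h_theta(x) = theta^T x."""
    return sum(t * xi for t, xi in zip(theta, x))


def cost(theta, xs, ys):
    """J(theta) = 1/2 * sum of squared prediction errors."""
    return 0.5 * sum((predict(theta, x) - y) ** 2 for x, y in zip(xs, ys))


def log_likelihood(theta, xs, ys, sigma=1.0):
    """Gaussian log-likelihood: a constant minus J(theta) / sigma^2."""
    m = len(ys)
    const = m * math.log(1.0 / (math.sqrt(2.0 * math.pi) * sigma))
    return const - cost(theta, xs, ys) / sigma ** 2


# Toy data generated by y = 1 + 2 * x; the first feature is the constant 1.
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
ys = [1.0, 3.0, 5.0]

theta_good = [1.0, 2.0]
theta_bad = [0.0, 2.0]
print(cost(theta_good, xs, ys), log_likelihood(theta_good, xs, ys))
print(cost(theta_bad, xs, ys), log_likelihood(theta_bad, xs, ys))
```

The parameter vector with the lower cost has the higher log-likelihood, and with `sigma = 1` the two gaps are equal in size.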

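To minimize $J$, we need to find its stationary point, which is reached when:

$$\theta = (X^T X)^{-1} X^T y$$

This is the result of setting the partial derivatives of $J$ with respect to $\theta$ to zero, where $X$ is the matrix of inputs (one row per sample) and $y$ is the vector of true values. Sometimes a $\lambda$ term is added to avoid overfitting:

$$\theta = (X^T X + \lambda I)^{-1} X^T y$$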
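
A minimal pure-Python sketch of this closed-form solution follows. It assumes the stationary point has the form $\theta = (X^T X + \lambda I)^{-1} X^T y$ (with $\lambda = 0$ for the unregularized case) and, as a simplification, penalizes every component of $\theta$ including the intercept. The helper names `solve_linear_system` and `normal_equation` and the data are hypothetical; in practice a linear algebra library would replace the hand-written solver.

```python
def solve_linear_system(a, b):
    """Solve a @ x = b by Gauss-Jordan elimination with partial pivoting."""
    n = len(b)
    # Augmented matrix [a | b], copied so the inputs are not modified.
    aug = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        # Swap in the row with the largest pivot for numerical stability.
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        for r in range(n):
            if r != col:
                factor = aug[r][col] / aug[col][col]
                for c in range(col, n + 1):
                    aug[r][c] -= factor * aug[col][c]
    return [aug[i][n] / aug[i][i] for i in range(n)]


def normal_equation(xs, ys, lam=0.0):
    """Return theta solving (X^T X + lam * I) theta = X^T y."""
    n = len(xs[0])
    xtx = [[sum(x[i] * x[j] for x in xs) + (lam if i == j else 0.0)
            for j in range(n)] for i in range(n)]
    xty = [sum(x[i] * y for x, y in zip(xs, ys)) for i in range(n)]
    return solve_linear_system(xtx, xty)


# Toy data generated by y = 1 + 2 * x; the first feature is the constant 1.
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
ys = [1.0, 3.0, 5.0]

theta = normal_equation(xs, ys)
theta_ridge = normal_equation(xs, ys, lam=1.0)
print(theta, theta_ridge)
```

With `lam=0.0` the solver recovers the generating coefficients; with `lam=1.0` the solution is pulled toward zero, which is the shrinkage effect the $\lambda$ term is meant to provide.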

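Gradient descent updates $\theta$ step by step so as to minimize the error. The update rule is generally:

$$\theta := \theta - \alpha \frac{\partial J(\theta)}{\partial \theta}$$

In this equation, $\theta$ is moved against the gradient, and $\alpha$ is the learning rate, which controls the size of each update.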
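
To make the role of the learning rate concrete, here is a small sketch of the update on a one-dimensional toy objective $J(\theta) = (\theta - 3)^2$. The objective, the starting point, and the two step sizes are assumptions chosen for illustration only.

```python
def gradient_descent_1d(grad, theta0, alpha, steps):
    """Repeatedly apply theta := theta - alpha * dJ/dtheta."""
    theta = theta0
    for _ in range(steps):
        theta = theta - alpha * grad(theta)
    return theta


def grad_j(theta):
    """Gradient of the toy objective J(theta) = (theta - 3)^2."""
    return 2.0 * (theta - 3.0)


# A small step size shrinks the distance to the minimum at every step.
theta_small_step = gradient_descent_1d(grad_j, theta0=0.0, alpha=0.1, steps=100)

# A step size that is too large overshoots further each time and diverges.
theta_large_step = gradient_descent_1d(grad_j, theta0=0.0, alpha=1.1, steps=100)

print(theta_small_step, theta_large_step)
```

With `alpha=0.1` each step multiplies the distance to the minimum by 0.8, so the iterate settles at 3. With `alpha=1.1` the factor is -1.2, so the iterate oscillates with growing amplitude.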

 

4. Regularization

 

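To avoid overfitting, a regularization term is often added to the loss function:

$$J(\theta) = \frac{1}{2} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2 + \frac{\lambda}{2} \sum_{j=1}^{n} \theta_j^2$$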

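This is the l2-norm penalty, also known as ridge regression, which corresponds to assuming a Gaussian prior on $\theta$. Regularization aims to shrink the weights, for example those of high-order terms. Another popular choice is the l1-norm penalty, also known as LASSO. Because it produces sparse solutions, LASSO is widely used for feature selection. LASSO can be written as:

$$J(\theta) = \frac{1}{2} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2 + \lambda \sum_{j=1}^{n} |\theta_j|$$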

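LASSO can select features, but ridge often predicts better, so the two are sometimes combined as Elastic Net:

$$J(\theta) = \frac{1}{2} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2 + \lambda \left( \rho \sum_{j=1}^{n} |\theta_j| + \frac{1 - \rho}{2} \sum_{j=1}^{n} \theta_j^2 \right)$$

where $\rho \in [0, 1]$ balances the two penalties.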
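
The three penalty terms can be compared side by side in a few lines of Python. This sketch assumes the ridge penalty is written as $\frac{\lambda}{2} \sum_j \theta_j^2$, the LASSO penalty as $\lambda \sum_j |\theta_j|$, and the Elastic Net penalty as a mix of the two controlled by a ratio `rho`; the function names and the example vector are hypothetical.

```python
def l2_penalty(theta, lam):
    """Ridge penalty: (lam / 2) * sum of squared weights."""
    return 0.5 * lam * sum(t * t for t in theta)


def l1_penalty(theta, lam):
    """LASSO penalty: lam * sum of absolute weights."""
    return lam * sum(abs(t) for t in theta)


def elastic_net_penalty(theta, lam, rho):
    """Mix of both penalties; rho = 1 is pure l1, rho = 0 is pure l2."""
    l1_part = rho * sum(abs(t) for t in theta)
    l2_part = 0.5 * (1.0 - rho) * sum(t * t for t in theta)
    return lam * (l1_part + l2_part)


theta = [3.0, -4.0]
ridge = l2_penalty(theta, lam=0.5)
lasso = l1_penalty(theta, lam=0.5)
mixed = elastic_net_penalty(theta, lam=0.5, rho=0.5)
print(ridge, lasso, mixed)
```

Setting `rho` to 1 or 0 recovers the LASSO or ridge penalty respectively, which is why Elastic Net is usually described as interpolating between the two.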

5. Batch/Stochastic Gradient Descent

 

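Since $\theta$ is updated by gradient descent as:

$$\theta_j := \theta_j - \alpha \frac{\partial J(\theta)}{\partial \theta_j}$$

and we have:

$$J(\theta) = \frac{1}{2} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2$$

the partial derivative with respect to each component $\theta_j$ is:

$$\frac{\partial J(\theta)}{\partial \theta_j} = \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$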

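To train $\theta$, batch gradient descent (BGD) is a natural choice. It updates $\theta$ using the whole batch of input data at every step:

Repeat until $\theta$ converges {

    $\theta_j := \theta_j - \alpha \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$  (for every $j$)

}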
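
The loop can be sketched in plain Python as follows. The convergence test (largest parameter change below `tol`), the iteration cap, the step size, and the toy data are assumptions added for illustration; the update itself sums the per-sample gradients of the squared-error cost over the whole data set.

```python
def predict(theta, x):
    """Linear model h_theta(x) = theta^T x."""
    return sum(t * xi for t, xi in zip(theta, x))


def batch_gradient_descent(xs, ys, alpha=0.05, tol=1e-10, max_iters=100000):
    """Update theta from the gradient summed over the whole data set."""
    theta = [0.0] * len(xs[0])
    for _ in range(max_iters):
        # Accumulate the gradient over every sample before touching theta.
        grad = [0.0] * len(theta)
        for x, y in zip(xs, ys):
            err = predict(theta, x) - y
            for j, xj in enumerate(x):
                grad[j] += err * xj
        new_theta = [t - alpha * g for t, g in zip(theta, grad)]
        # Stop once no parameter moves by more than tol in one step.
        if max(abs(a - b) for a, b in zip(new_theta, theta)) < tol:
            return new_theta
        theta = new_theta
    return theta


# Toy data generated by y = 1 + 2 * x; the first feature is the constant 1.
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
ys = [1.0, 3.0, 5.0]

theta = batch_gradient_descent(xs, ys)
print(theta)
```

Every update reads all samples, so the cost of one step grows linearly with the size of the data set.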

 

Clearly, BGD needs the whole batch of data for every update, so it cannot be used for online learning. Stochastic gradient descent (SGD) handles this case better. SGD does not need to scan the entire data set before each update, so when $m$ is very large it can save time compared with BGD. And because SGD takes a gradient step for each individual sample, it can handle online learning. SGD can be written as:

For i = 1 to m {

    $\theta_j := \theta_j - \alpha \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$  (for every $j$)

}
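
A minimal Python sketch of the per-sample update is shown below. Shuffling the samples each epoch, the fixed random seed, the number of epochs, and the toy data are assumptions for illustration; an online learner would apply the same inner update to each sample as it arrives.

```python
import random


def predict(theta, x):
    """Linear model h_theta(x) = theta^T x."""
    return sum(t * xi for t, xi in zip(theta, x))


def stochastic_gradient_descent(xs, ys, alpha=0.05, epochs=2000, seed=0):
    """Update theta after every single sample instead of after a full pass."""
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    theta = [0.0] * len(xs[0])
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)  # visit the samples in a new order each epoch
        for i in order:
            err = predict(theta, xs[i]) - ys[i]
            # Gradient step that uses sample i only.
            theta = [t - alpha * err * xj for t, xj in zip(theta, xs[i])]
    return theta


# Toy data generated by y = 1 + 2 * x; the first feature is the constant 1.
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
ys = [1.0, 3.0, 5.0]

theta = stochastic_gradient_descent(xs, ys)
print(theta)
```

The toy data here is noise-free, so the iterates settle on the exact coefficients. On noisy data a constant step size leaves the parameters fluctuating around the minimum, which is why a decaying step size is a common choice.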

 

In practice we often use mini-batch SGD, which updates $\theta$ once for every small batch of input samples.
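
A mini-batch variant can be sketched with a single `batch_size` parameter. The function name, the default values, and the toy data are hypothetical, and the gradient is summed (not averaged) over each batch to match a squared-error cost without a $1/m$ factor.

```python
import random


def predict(theta, x):
    """Linear model h_theta(x) = theta^T x."""
    return sum(t * xi for t, xi in zip(theta, x))


def minibatch_gradient_descent(xs, ys, batch_size=2, alpha=0.05,
                               epochs=2000, seed=0):
    """Update theta once per mini-batch of batch_size samples."""
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    theta = [0.0] * len(xs[0])
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            # Sum the per-sample gradients inside this batch only.
            grad = [0.0] * len(theta)
            for i in batch:
                err = predict(theta, xs[i]) - ys[i]
                for j, xj in enumerate(xs[i]):
                    grad[j] += err * xj
            theta = [t - alpha * g for t, g in zip(theta, grad)]
    return theta


# Toy data generated by y = 1 + 2 * x; the first feature is the constant 1.
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
ys = [1.0, 3.0, 5.0, 7.0]

theta = minibatch_gradient_descent(xs, ys, batch_size=2)
print(theta)
```

Setting `batch_size=1` reduces this to per-sample SGD, and setting it to the number of samples reduces it to BGD, so the batch size acts as a dial between the two.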


 

 

 

 

 

 

 

 
