线性回归
来源:互联网 发布:网络流行组合字 编辑:程序博客网 时间:2024/06/06 02:55
为了深入理解Gradient Descent算法,写了如下代码。
在y = 2x直线上生成随机高斯噪声
# -*- coding: utf-8 -*-"""Created on Wed Aug 30 15:27:51 2017@author: liuxy"""import numpy as npimport matplotlib.pyplot as pltdef gen_data(size): x = np.arange(0, size, 1) e = np.random.normal(0, 3, size) y = 2*x + e return [x, y]def compute_gradient_full(data, w): X = data[0] Y = data[1] N = len(X) g = np.sum(2*X*(X*w - Y))/N return g def compute_gradient_SGD(data, w): X = data[0] Y = data[1] idx = np.random.randint(0, len(X)-1) d = X[idx] t = Y[idx] g = 2*d*(d*w - t) return gdef compute_gradient_miniBatch(data, w): X = data[0] Y = data[1] N = 16 X_b = [] Y_b = [] for i in range(N): idx = np.random.randint(0, len(X)-1) X_b.append(X[idx]) Y_b.append(Y[idx]) X_ba = np.array(X_b) Y_ba = np.array(Y_b) g = np.sum(2*X_ba*(X_ba*w - Y_ba))/N return gdef Optimizer(data, w, learning_rate, num_iterator, method, Wts): for i in range(num_iterator): g = 0 if ('full' == method): g = compute_gradient_full(data, w) if ('mini' == method): g = compute_gradient_miniBatch(data, w) if ('sgd' == method): g = compute_gradient_SGD(data, w) w = w - learning_rate * g Wts.append(w) data = gen_data(100)#plt.scatter(data[0], data[1])lr = 0.000020w = 6num = 100Weights_full = []Weights_mini = []Weights_sgd = []Weights_full.append(w)Weights_mini.append(w)Weights_sgd.append(w)Optimizer(data, w, lr, num, 'full', Weights_full)Optimizer(data, w, lr, num, 'mini', Weights_mini)Optimizer(data, w, lr, num, 'sgd', Weights_sgd)plt.plot(np.arange(0,num+1), Weights_full) plt.plot(np.arange(0,num+1), Weights_mini) plt.plot(np.arange(0,num+1), Weights_sgd)
权重变化, full, mini batch, sgd
阅读全文
0 0
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 线性回归
- 双飞翼布局
- c程序的执行过程
- 一款比较好的java和scala开发工具,界面和eclipse一样
- 七 mysql连接池
- 【IntelliJ IDEA java-web 初学之容易遇到的问题及解决办法】
- 线性回归
- 数据库执行sql的大致流程——如何优化
- 剑指Offer-56
- SDUTOJ 3443 找老乡
- 微信小程序周报(第十三期)-极乐商店(store.dreawer.com)出品
- ios常见错误—— -[_NSString absoluteURL](请求网络图片)
- Spring各个版本新特性
- ORA-16714: the value of property ArchiveLagTarget is inconsistent with the database setting
- java.lang.IllegalStateException: Failed to load ApplicationContext