NG 线性回归

来源:互联网 发布:php tools for vs2015 编辑:程序博客网 时间:2024/05/16 05:06
# 线性回归 标准方程法import numpy as npimport pandas as pd A = pd.read_table('LR.txt',header=None,usecols = (0,1,2))X = np.mat(A.iloc[:,0:2])y = np.mat(A.iloc[:,2]).T# 求转置: X.T  求逆:X.ITRAN_X = X.T params = (TRAN_X*X).I*TRAN_X*yprint(params)# 梯度下降法import numpy as npimport pandas as pd import matplotlib.pyplot as pltA = pd.read_table('LR.txt',header=None,usecols = (0,1,2))X = np.mat(A.iloc[:,0:2])y = np.mat(A.iloc[:,2]).Tm = A.index.max()params = np.mat([1.0,1.0]).T #求参数params_1要用到的cost = []time = 50for i in range(time):error_col = X*params-y #误差列cost.append(sum( np.multiply(error_col,error_col) )[0,0]/(2*m)) #cost function 作为监控画图params[0] = params[0] - (0.1/m)*sum(error_col) #0.1为学习率 可以取0.001 0.003 0.01 0.03 。。。0.1params[1] = params[1] - (0.1/m)*sum(np.multiply(error_col,np.mat(A.iloc[:,1]).T))print(params)#可视化# 逼近plt.subplot(221),plt.plot(range(time),cost,'r--*'),plt.title('desence'),plt.xlabel('Itera Times'),plt.ylabel('Cost function')# 拟合函数x=np.linspace(0,1,200)plt.subplot(222),plt.plot(A.iloc[:,1],A.iloc[:,2],'b*'),plt.title('Hypothesis function'),plt.plot(x,params[1,0]*x+params[0,0],'r-')plt.show()



原创粉丝点击