机器学习--线性回归(原理与例子)

来源:互联网 发布:买钢琴 知乎 编辑:程序博客网 时间:2024/06/04 01:30

这里写图片描述
注意:由于x是已知的,y也是已知的,式子看着是关于x的函数,由于thet(希腊字母打不出来就用发音词thet表示,后文一样)未知,所以我们换个视角,可将将式子看成是关于thet 的函数。
而事实上,机器学习中学习的就是关于thet 的模型。
机器学习的目的就是学习一个模型,使得经过这个模型运算产生的值与真实值之间越接近越好。换句话说就是两者之间误差越小越好。如何衡量这个误差呢?我们采用最小二乘法。
即,我们得到我们的目标函数(书中常见为损失函数,为统一叫法,后文均称为损失函数):
这里写图片描述
我们采用梯度下降法,算法步骤:
这里写图片描述
梯度下降法的矩阵方式如下:
这里写图片描述
下面结合tensorflow框架来实现线性回归:

#导入库import matplotlib.pyplot as pltimport numpy as npimport tensorflow as tffrom sklearn import datasetsfrom tensorflow.python.framework import ops

创建计算图、加载数据

ops.reset_default_graph()#启动一个graph sessionsess=tf.Session()
#数据来源于Scikit Learn内建的数据集iris=datasets.load_iris()x_vals=np.array([x[3] for x in iris.data])y_vals=np.array([y[0] for y in iris.data])
#初始化placeholders(占位符)、创建变量()x_data=tf.placeholder(shape=[None,1],dtype=tf.float32)y_target=tf.placeholder(shape=[None,1],dtype=tf.float32)#为了与上文原理中的命名统一(这里用Thet来表示权值参数,thet0表示截距)Thet=tf.Variable(tf.random_normal(shape=[1,1]))thet0=tf.Variable(tf.random_normal(shape=[1,1]))#声明一个线性模型、损失函数、优化器(选择梯度下降法)linear_output=tf.add(tf.matmul(x_data,Thet),thet0)loss=tf.reduce_mean(tf.square(y_target-linear_output))opt=tf.train.GradientDescentOptimizer(0.05)train_step=opt.minimize(loss)#初始化变量init=tf.global_variables_initializer()sess.run(init)#训练批量大小20,迭代100次batch_size=20loss_vec=[]for i in range(100):    random_index=np.random.choice(len(x_vals),size=batch_size)    random_x=np.transpose([x_vals[random_index]])    random_y=np.transpose([y_vals[random_index]])    sess.run(train_step,feed_dict={x_data:random_x,y_target:random_y})    temp_loss=sess.run(loss,feed_dict={x_data:random_x,y_target:random_y})    loss_vec.append(temp_loss)[slope]=sess.run(Thet)[intercept]=sess.run(thet0)fit_line=[]for i in x_vals:    fit_line.append(slope*i+intercept)

展示效果

plt.plot(x_vals,y_vals,'o',label=u'数据点')plt.plot(x_vals,fit_line,'r-',label=u'拟合效果',linewidth=2)plt.legend(loc='upper left')plt.title(u'x与y的关系')plt.xlabel(u'x')plt.ylabel(u'y')plt.show()

这里写图片描述
损失误差

#损失误差plt.plot(loss_vec,'k-')plt.title('L2 loss')plt.xlabel(u'迭代次数')plt.ylabel(u'L2 损失误差')plt.show()

这里写图片描述

参考:
《TensorFlow Machine Learning cookbook》

更多技术干货,请关注下面二维码:
这里写图片描述

原创粉丝点击