简单的Tensorflow(4):线性模型分析

来源:互联网 发布:淘宝店铺一键发布宝贝 编辑:程序博客网 时间:2024/06/06 17:31

首先明确我们要做的事:产生一个y = k * x + b的模型,然后模拟得到k和b的值。


使用numpy产生100个随机点x_data = np.random.rand(100)

使用numpy产生随机点的波动值y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100))


构造这个线性模型

b = tf.Variable(0.)

k = tf.Variable(0.)

y = k * x_data + b


构造一个损失函数,使用二次代价函数,loss = tf.reduce_mean(tf.square(y_data - y)),理论的东西来源一个公式:


然后对loss函数求偏导,就可以找到临界点。


定义一个梯度下降法来训练的优化器,这里选择梯度下降optimizer = tf.train.GradientDescentOptimizer(0.2)

梯度下降的问题如果不理解也可以用求导数类比,就好像是二次函数求导可以找出函数递减到最低点的地方。


优化最小代价函数train = optimizer.minimize(loss),就是对损失函数最小化,不理解的话也可以类比为物理中一个有阻尼效应的弹簧振子的振幅,慢慢变小。


全部代码:(为了方便观察我打印出x_data和y_data的值

import tensorflow as tfimport numpy as np#使用numpy产生100个随机点x_data = np.random.rand(100)print(x_data)#使用numpy产生随机点的波动值y_data = x_data * 0.8 + (0.2 + 0.2*np.random.rand(100))print(y_data)print(y_data - (x_data * 0.8))#构造这个线性模型b = tf.Variable(0.)k = tf.Variable(0.)y = k * x_data + b#构造一个损失函数,使用二次代价函数loss = tf.reduce_mean(tf.square(y_data - y))#定义一个梯度下降法来训练的优化器,这里选择梯度下降optimizer = tf.train.GradientDescentOptimizer(0.2)#优化最小代价函数train = optimizer.minimize(loss)#初始化变量init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    for step in range(1001):        sess.run(train)        if step % 20 == 0:            print(step,sess.run([k,b]))


打印x_data的结果是:

[  7.44110800e-01   5.43004562e-01   1.40464610e-01   1.81464847e-01   6.39729239e-02   8.55839985e-01   8.57288253e-01   5.97521668e-01   2.78126750e-01   6.33273527e-01   7.36156567e-01   1.04051817e-01   8.88657006e-01   6.59313727e-01   5.90586641e-01   2.20599954e-01   4.35010023e-02   5.94651847e-01   1.34688295e-03   7.61364799e-01   2.40110638e-01   4.30085960e-01   9.29408073e-01   2.83296015e-01   9.66464568e-01   5.64793967e-01   2.89583567e-02   4.78895281e-01   6.55705166e-01   9.39931048e-01   9.22667199e-02   4.17453020e-01   8.18261854e-01   7.51253933e-01   8.34488361e-01   5.47085942e-01   2.20812245e-01   3.64683210e-01   2.72265898e-01   6.99546767e-01   5.17783069e-01   6.86348141e-01   8.26539294e-02   7.21142415e-01   5.25883873e-01   7.95120312e-01   5.53413821e-02   7.85643413e-01   6.00917521e-01   1.00047819e-01   2.25730457e-01   3.51933714e-01   3.02575025e-01   2.47871839e-02   7.68447228e-01   5.07069502e-01   2.80167969e-01   3.69932599e-01   2.12162593e-01   9.94378802e-01   1.15214455e-01   6.23622264e-02   7.36749874e-01   9.28610319e-01   5.70055090e-01   1.53292088e-01   3.41733503e-02   1.61350956e-01   1.39131378e-01   5.95169150e-01   6.73270412e-01   3.49607503e-01   6.76886118e-01   6.85209452e-01   2.25242789e-01   2.17403176e-01   4.76698614e-01   3.66161202e-01   8.71111576e-01   8.51866555e-01   2.42937403e-02   7.98989890e-01   1.56342752e-01   2.19081202e-01   2.68967638e-01   3.66510600e-01   5.63911914e-01   3.51834652e-01   3.53287522e-01   6.38729594e-01   8.64768242e-01   6.91163778e-01   4.01258574e-01   4.18343511e-01   1.04330897e-04   9.33082029e-02   3.62988647e-01   1.90686375e-01   1.37897024e-01   6.80795678e-01]


打印y_data的结果是:

[ 0.89765157  0.78318914  0.51019381  0.37265281  0.4484892   0.89407551  1.05607533  0.81662998  0.60433588  0.84483753  0.81586059  0.28435128  1.05519893  0.82464034  0.79314064  0.40467223  0.27353902  0.83144114  0.22773844  0.81602484  0.49884444  0.6495945   1.09878211  0.55309785  1.16746735  0.82575331  0.26850707  0.73932869  0.91451223  0.95979475  0.42978122  0.64909303  0.95775075  0.82245154  0.99671646  0.71105713  0.37859923  0.59687127  0.45301511  0.91201781  0.70358415  0.81903768  0.33778258  0.96624705  0.70700762  1.02552349  0.32549882  0.83025993  0.79845534  0.36227279  0.56073625  0.61983059  0.62606028  0.24095358  0.96142484  0.65170733  0.53331155  0.52622143  0.44052906  1.1054362  0.40243692  0.30179491  0.91216599  1.13082724  0.80231909  0.50341774  0.23161131  0.35105119  0.32236377  0.82267127  0.91461553  0.56597076  0.76393095  0.9008792   0.41224933  0.39897747  0.58644554  0.51947146  1.06847105  0.91237405  0.34874716  0.83959344  0.51669053  0.46544819  0.56874802  0.6768135   0.76004207  0.63262638  0.62986198  0.84711002  1.05019422  0.86545662  0.61937746  0.56008459  0.39665551  0.39971446  0.51636406  0.46776067  0.36391908  0.83299141]

打印(y_data - (x_data * 0.8))的结果是:

[ 0.30236293  0.34878549  0.39782212  0.22748093  0.39731086  0.20940352  0.37024473  0.33861265  0.38183448  0.33821871  0.22693534  0.20110982  0.34427332  0.29718936  0.32067133  0.22819227  0.23873822  0.35571967  0.22666093  0.206933    0.30675593  0.30552573  0.35525566  0.32646104  0.3942957   0.37391813  0.24534038  0.35621247  0.3899481   0.20784991  0.35596785  0.31513061  0.30314127  0.22144839  0.32912577  0.27338838  0.20194944  0.3051247   0.23520239  0.35238039  0.28935769  0.26995916  0.27165944  0.38933312  0.28630053  0.38942724  0.28122571  0.2017452  0.31772132  0.28223454  0.38015189  0.33828362  0.38400026  0.22112384  0.34666705  0.24605172  0.30917718  0.23027535  0.27079898  0.30993316  0.31026536  0.25190513  0.32276609  0.38793899  0.34627502  0.38078407  0.20427263  0.22197043  0.21105867  0.34653595  0.3759992   0.28628476  0.22242205  0.35271164  0.2320551   0.22505493  0.20508665  0.22654249  0.37158179  0.23088081  0.32931217  0.20040152  0.39161633  0.29018323  0.3535739   0.38360502  0.30891254  0.35115866  0.34723196  0.33612635  0.35837963  0.3125256   0.2983706   0.22540979  0.39657205  0.32506789  0.22597314  0.31521157  0.25360146  0.28835487]

最后k和b的结果是:

0 [0.14778559, 0.26584759]20 [0.51506805, 0.441048]40 [0.6482842, 0.37625444]60 [0.7260837, 0.33841407]80 [0.77151942, 0.31631491]100 [0.7980544, 0.30340874]120 [0.81355113, 0.29587138]140 [0.82260132, 0.29146951]160 [0.82788676, 0.28889877]180 [0.83097351, 0.28739744]200 [0.83277613, 0.28652063]220 [0.83382899, 0.28600857]240 [0.83444381, 0.28570953]260 [0.83480293, 0.28553486]280 [0.83501256, 0.2854329]300 [0.8351351, 0.2853733]320 [0.83520657, 0.28533852]340 [0.83524829, 0.28531826]360 [0.83527273, 0.28530636]380 [0.83528692, 0.28529945]400 [0.8352952, 0.2852954]420 [0.83530003, 0.28529307]440 [0.83530289, 0.28529167]460 [0.8353045, 0.2852909]480 [0.83530569, 0.28529033]500 [0.83530593, 0.28529021]520 [0.83530593, 0.28529021]540 [0.83530593, 0.28529021]560 [0.83530593, 0.28529021]580 [0.83530593, 0.28529021]600 [0.83530593, 0.28529021]620 [0.83530593, 0.28529021]640 [0.83530593, 0.28529021]660 [0.83530593, 0.28529021]680 [0.83530593, 0.28529021]700 [0.83530593, 0.28529021]720 [0.83530593, 0.28529021]740 [0.83530593, 0.28529021]760 [0.83530593, 0.28529021]780 [0.83530593, 0.28529021]800 [0.83530593, 0.28529021]820 [0.83530593, 0.28529021]840 [0.83530593, 0.28529021]860 [0.83530593, 0.28529021]880 [0.83530593, 0.28529021]900 [0.83530593, 0.28529021]920 [0.83530593, 0.28529021]940 [0.83530593, 0.28529021]960 [0.83530593, 0.28529021]980 [0.83530593, 0.28529021]1000 [0.83530593, 0.28529021]

其实从结果看,500次迭代之后就没有意义了。机器学习研究的一个重点就是,不要白白浪费计算机,浪费时间。