tensorflow训练线性函数

来源:互联网 发布:mysql redis 特点区别 编辑:程序博客网 时间:2024/06/06 09:04

tensorflow一个简单的例子,训练一个 y = 0.1x+0.3的线性函数。

主要学习目的是了解tensorflow训练模型的过程和模式,对以后训练其他模型有个大概的了解。

tensorflow训练模型大致有以下几个步骤:

1、数据处理(加载数据、创建数据)

2、构建训练模型或网络

3、损失函数计算与参数优化

4、创建Session

5、迭代训练

环境: win7 64位 tensorflow-gpu1.3 python3.5

"""this code is only for python 3+"""import tensorflow as tfimport numpy as np#create datax_data = np.random.rand(100).astype(np.float32)y_data = x_data*0.1+0.3### create tensorflow structure start ####产生-1到1之间的均匀分布的类型为float32的一维张量Weights = tf.Variable(tf.random_uniform([1],minval=-1.0,maxval=1.0,dtype=np.float32))biases = tf.Variable(tf.zeros([1]))y = Weights*x_data+biases#损失函数  对平方和求均值loss = tf.reduce_mean(tf.square(y - y_data))#采用梯度下降算法对模型参数进行优化参数为0-1optimizer = tf.train.GradientDescentOptimizer(0.5)#计算损失方程的最小值  可以和上一步链式表示train = optimizer.minimize(loss)#初始化所有的变量init = tf.global_variables_initializer()### create tensorflow structure end #### 创建一个Sessionsess = tf.Session()# 执行初始化sess.run(init)#迭代训练training_steps = 801for step in range(training_steps):    sess.run(train)    #打印训练过程的递减情况    if step % 20 == 0:        print(step,sess.run(Weights),sess.run(biases))
0 [ 0.80305016] [-0.14744911]20 [ 0.28564903] [ 0.19576499]40 [ 0.1475025] [ 0.27332914]60 [ 0.11215457] [ 0.29317567]80 [ 0.10311003] [ 0.29825383]100 [ 0.10079577] [ 0.29955322]120 [ 0.1002036] [ 0.29988569]140 [ 0.1000521] [ 0.29997078]160 [ 0.10001334] [ 0.29999253]180 [ 0.10000343] [ 0.2999981]200 [ 0.10000085] [ 0.29999954]220 [ 0.10000022] [ 0.29999989]240 [ 0.1000001] [ 0.29999995]260 [ 0.1000001] [ 0.29999995]280 [ 0.1000001] [ 0.29999995]300 [ 0.1000001] [ 0.29999995]320 [ 0.1000001] [ 0.29999995]340 [ 0.1000001] [ 0.29999995]360 [ 0.1000001] [ 0.29999995]380 [ 0.1000001] [ 0.29999995]400 [ 0.1000001] [ 0.29999995]420 [ 0.1000001] [ 0.29999995]440 [ 0.1000001] [ 0.29999995]460 [ 0.1000001] [ 0.29999995]480 [ 0.1000001] [ 0.29999995]500 [ 0.1000001] [ 0.29999995]520 [ 0.1000001] [ 0.29999995]540 [ 0.1000001] [ 0.29999995]560 [ 0.1000001] [ 0.29999995]580 [ 0.1000001] [ 0.29999995]600 [ 0.1000001] [ 0.29999995]620 [ 0.1000001] [ 0.29999995]640 [ 0.1000001] [ 0.29999995]660 [ 0.1000001] [ 0.29999995]680 [ 0.1000001] [ 0.29999995]700 [ 0.1000001] [ 0.29999995]720 [ 0.1000001] [ 0.29999995]740 [ 0.1000001] [ 0.29999995]760 [ 0.1000001] [ 0.29999995]780 [ 0.1000001] [ 0.29999995]800 [ 0.1000001] [ 0.29999995]

最后发现训练到240步时,参数已经获得了很好的结果。

原创粉丝点击