TensorFlow 莫烦 二次函数拟合(四)

来源:互联网 发布:淘宝闲置物品能退货吗 编辑:程序博客网 时间:2024/04/28 04:07
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 19 22:24:49 2017

@author: user

Fit the quadratic y = x**2 - 0.5 (plus Gaussian noise) with a small network:
1 input -> 10 ReLU units -> 1 linear output, trained by plain gradient descent
on the mean squared error. Uses the TensorFlow 1.x graph API
(tf.placeholder / tf.Session).
"""
import numpy as np
import tensorflow as tf


def add_layer(inputs, in_size, out_size, activation_function=None):
    """Append a fully connected layer to the graph and return its output tensor.

    Args:
        inputs: 2-D tensor of shape [batch, in_size].
        in_size: number of input features (rows of the weight matrix).
        out_size: number of output units (columns of the weight matrix).
        activation_function: callable applied to ``inputs @ W + b``;
            ``None`` leaves the layer linear.

    Returns:
        Tensor of shape [batch, out_size].
    """
    weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # Biases start at 0.1 instead of 0 (presumably so ReLU units begin active).
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    wx_plus_b = tf.matmul(inputs, weights) + biases
    if activation_function is None:
        return wx_plus_b
    return activation_function(wx_plus_b)


# 300 evenly spaced samples in [-1, 1], reshaped to (300, 1) so they match
# the [None, 1] placeholders below.
x_data = np.linspace(-1, 1, 300, dtype=np.float32)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape).astype(np.float32)
y_data = np.square(x_data) - 0.5 + noise

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# Build the network on the placeholder, not on x_data directly; otherwise the
# graph is hard-wired to the numpy constant and the fed `xs` is ignored.
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
prediction = add_layer(l1, 10, 1, activation_function=None)

# Squared error summed over the output axis per sample, then averaged (MSE).
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=[1]))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# Replaces the deprecated tf.initialize_all_variables().
init = tf.global_variables_initializer()

# The same full batch is fed on every step, so build the feed dict once.
feed = {xs: x_data, ys: y_data}

# The context manager closes the session when training finishes.
with tf.Session() as sess:
    sess.run(init)
    for i in range(1000):
        sess.run(train_step, feed_dict=feed)
        if i % 50 == 0:
            print(sess.run(loss, feed_dict=feed))
0 0
原创粉丝点击