tensorflow1.1/线性回归

来源:互联网 发布:百度php面试题 编辑:程序博客网 时间:2024/06/01 08:47

环境:tensorflow1.1 python3 matplotlib2.02

tensorflow 1.1和之前版本有了很大的改动,在构建神经网络方面代码量减少了很多,matplotlib2.02在画图上也比之前好看了许多

#coding:utf-8import tensorflow as tfimport numpy as npimport matplotlib.pyplot as pltx = np.linspace(-1,1,500)[:,np.newaxis] #列向量noise = np.random.normal(0,0.1,x.shape)y = np.power(x,3) + noisexs = tf.placeholder(tf.float32,x.shape)ys = tf.placeholder(tf.float32,y.shape)#构建神经网络#输入,输出神经元个数,激活函数l1 = tf.layers.dense(xs,20,tf.nn.relu) #输出10个神经元的隐藏层,激活函数reluoutput = tf.layers.dense(l1,1) #输入l1,输出神经元个数1#定义均方误差loss#tf.losses.mean_squared_errorloss = tf.losses.mean_squared_error(ys,output) #均方误差#定义优化器optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.4).minimize(loss) #数据量较小调大learning_rate使其学习加快with tf.Session() as sess:    init = tf.global_variables_initializer()    sess.run(init)    plt.ion() #打开交互模式    for step in range(100):        _,c = sess.run([optimizer,loss],feed_dict={xs:x,ys:y})        prediction = sess.run(output,feed_dict={xs:x}) #计算预测值        if step % 5 == 0:            #可以用clf()来清空当前图像,用cla()来清空当前坐标            plt.clf()#清空当前图像            plt.scatter(x,y)            plt.plot(x,prediction,'c-',lw='5')            plt.text(0,0.5,'cost=%.4f' % c,fontdict={'size':15,'color':'red'}) #添加text,位置在坐标轴0,0.5处            plt.pause(0.1) #暂停0.1s    plt.ioff() #关闭交互模式    plt.show()

结果

这里写图片描述

原创粉丝点击