Tensorflow: Linear Regression
Source: Internet | Editor: 程序博客网 | Time: 2024/05/22 00:26
modified from https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/logistic_regression_mnist.ipynb
toy dataset
```python
import numpy as np
import matplotlib.pyplot as plt

def toy_dataset(n):
    """Sample n points from y = 0.7 * x - 1 plus Gaussian noise."""
    w, b = 0.7, -1
    noise_var = 0.001
    x = np.random.random((1, n))          # inputs, uniform on [0, 1), shape (1, n)
    gt = w * x + b                        # noise-free ground truth
    label = gt + np.sqrt(noise_var) * np.random.randn(1, n)   # noisy training targets
    return x, gt, label

n_samples = 100
data, gt, label = toy_dataset(n_samples)
print(" Type of 'train_X' is ", type(data))
print(" Shape of 'train_X' is %s" % (data.shape,))
print(" Type of 'train_Y' is ", type(label))
print(" Shape of 'train_Y' is %s" % (label.shape,))

plt.figure(1)
plt.plot(data[0, :], gt[0, :], 'ro', label='Original data')
plt.plot(data[0, :], label[0, :], 'bo', label='Training data')
plt.axis('equal')
plt.legend(loc='lower right')
plt.show()
```
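The generator draws x uniformly from [0, 1) and adds Gaussian noise with variance noise_var = 0.001 (standard deviation sqrt(0.001), about 0.0316) to the line 0.7x - 1. A quick sanity check of those properties, assuming a seeded `np.random.default_rng` generator and a sample size of 100,000 (neither is in the original; the larger sample just makes the estimate stable):

```python
import numpy as np

def toy_dataset(n, rng):
    """Same model as above, but with an explicit random generator (assumption)."""
    w, b = 0.7, -1
    noise_var = 0.001
    x = rng.random((1, n))               # uniform on [0, 1), shape (1, n)
    gt = w * x + b                       # noise-free line
    label = gt + np.sqrt(noise_var) * rng.standard_normal((1, n))
    return x, gt, label

rng = np.random.default_rng(0)           # seed chosen arbitrarily (assumption)
x, gt, label = toy_dataset(100_000, rng)
noise = label - gt                       # isolate the added noise
print(x.shape, label.shape)              # (1, 100000) (1, 100000)
print(noise.std())                       # should be close to sqrt(0.001)
```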
linear regression
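```python
# TF 1.x-style graph code; tf.compat.v1 lets it run under TensorFlow 2.x as well
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

X = tf.placeholder(tf.float32, name='input')
Y = tf.placeholder(tf.float32, name='output')
w = tf.Variable(np.random.randn(), dtype=tf.float32, name='weight')
b = tf.Variable(np.random.randn(), dtype=tf.float32, name='bias')

# model: act = X * w + b  (tf.mul was renamed tf.multiply in TF 1.0)
act = tf.add(tf.multiply(X, w), b)

lr = 0.001
# mean squared error between prediction and label
loss = tf.reduce_mean(tf.pow(act - Y, 2))
# optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
optimizer = tf.train.RMSPropOptimizer(lr, 0.9).minimize(loss)   # 0.9 is the decay rate

init = tf.global_variables_initializer()   # replaces tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

training_epochs = 5000
snapshot = 50
loss_cache = []
for epoch in range(training_epochs):
    # data and label have shape (1, n), so this inner loop runs once per epoch (full-batch training)
    for x, y in zip(data, label):
        _, loss_val, w_val, b_val = sess.run([optimizer, loss, w, b], feed_dict={X: x, Y: y})
        loss_cache.append(loss_val)
    if epoch % snapshot == 0:
        print('[Epoch: %d] loss: %.4f, w: %.4f, b: %.4f' % (epoch, loss_val, w_val, b_val))

w_new, b_new = sess.run([w, b])
y_pre = data[0, :] * w_new + b_new     # predictions on the training inputs
print(y_pre.shape)                      # (100,)

plt.figure(2)
plt.plot(data[0, :], gt[0, :], 'ro', label='Ground Truth')
plt.plot(data[0, :], label[0, :], 'bo', label='Training Label')
plt.plot(data[0, :], y_pre, 'k-', label='Fitted Line')
plt.axis('equal')
plt.legend(loc='lower right')
plt.show()
```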
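The graph minimizes the mean squared error L(w, b) = mean((w*x + b - y)^2), whose gradients are dL/dw = 2*mean((w*x + b - y)*x) and dL/db = 2*mean(w*x + b - y). The sketch below applies those updates by hand in NumPy, i.e. plain gradient descent as in the commented-out `GradientDescentOptimizer` line, not the RMSProp optimizer the original actually runs. The seed, sample count, step size 0.5 and iteration count are assumptions chosen for this toy data; 0.5 is stable because it is below 2 divided by the largest Hessian eigenvalue (roughly 2.5 for x uniform on [0, 1)). `np.polyfit` gives the closed-form least-squares answer for comparison.

```python
import numpy as np

rng = np.random.default_rng(42)            # seeded for reproducibility (assumption)
x = rng.random(1000)                        # 1-D arrays instead of shape (1, n)
y = 0.7 * x - 1 + np.sqrt(0.001) * rng.standard_normal(1000)

w, b = rng.standard_normal(), rng.standard_normal()   # random init, as in the TF graph
lr = 0.5                                    # assumed step size, below the stability limit
for step in range(2000):
    err = w * x + b - y                     # residuals of the current fit
    grad_w = 2 * np.mean(err * x)           # dL/dw
    grad_b = 2 * np.mean(err)               # dL/db
    w -= lr * grad_w
    b -= lr * grad_b

loss = np.mean((w * x + b - y) ** 2)
w_ls, b_ls = np.polyfit(x, y, 1)            # closed-form least squares: [slope, intercept]
print(f"GD:      w={w:.4f}, b={b:.4f}, loss={loss:.6f}")
print(f"polyfit: w={w_ls:.4f}, b={b_ls:.4f}")
```

Gradient descent converges to the same (w, b) as the closed-form fit; the iterative version is shown only because it mirrors what the TensorFlow optimizer does step by step.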
```python
plt.figure(3)
plt.plot(range(training_epochs), loss_cache, 'b-', label='loss')
plt.legend(loc='upper right')
plt.show()
```
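Because the labels carry noise of variance noise_var = 0.001, no choice of (w, b) can push the training loss far below that value, so if training converges the loss curve should level off near 0.001 rather than reach zero. For least squares with two parameters the expected training MSE is slightly below the noise variance, by a factor of (n - 2)/n. A minimal NumPy check of that floor, using `np.polyfit` as the exact least-squares minimizer (the seed is an assumption):

```python
import numpy as np

noise_var = 0.001
rng = np.random.default_rng(1)              # seed is an assumption
n = 100                                      # same sample count as n_samples above
x = rng.random(n)
y = 0.7 * x - 1 + np.sqrt(noise_var) * rng.standard_normal(n)

w_ls, b_ls = np.polyfit(x, y, 1)
mse_fit = np.mean((w_ls * x + b_ls - y) ** 2)   # lowest training loss any (w, b) can reach
mse_true = np.mean((0.7 * x - 1 - y) ** 2)      # loss at the true parameters = mean squared noise
print(f"best-fit MSE: {mse_fit:.6f}, MSE at true (w, b): {mse_true:.6f}, noise_var: {noise_var}")
```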