TensorFlow: Linear Regression

来源:互联网 发布:数据库类型有哪些 编辑:程序博客网 时间:2024/05/22 00:26

Modified from https://github.com/sjchoi86/tensorflow-101/blob/master/notebooks/logistic_regression_mnist.ipynb (originally a logistic-regression example on MNIST).

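Toy dataset: inputs $x \sim U[0, 1)$ and labels $y = 0.7x - 1 + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 0.001)$.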

import os
import pprint

import numpy as np
import matplotlib.pyplot as plt


def toy_dataset(n, w=0.7, b=-1, noise_var=0.001):
    """Sample ``n`` points from a noisy 1-D line ``y = w * x + b``.

    Args:
        n: number of samples.
        w: slope of the ground-truth line (default 0.7).
        b: intercept of the ground-truth line (default -1).
        noise_var: variance of the additive Gaussian noise on the labels
            (default 0.001).

    Returns:
        ``(x, gt, label)``, each a NumPy array of shape ``(1, n)``:
        inputs drawn uniformly from [0, 1), noise-free targets ``w * x + b``,
        and noisy targets ``gt + N(0, noise_var)``.
    """
    x = np.random.random((1, n))
    gt = w * x + b
    # randn() is unit-variance, so scale by the standard deviation sqrt(noise_var).
    label = gt + np.sqrt(noise_var) * np.random.randn(1, n)
    return x, gt, label


n_samples = 100
data, gt, label = toy_dataset(n_samples)
print(" Type of 'train_X' is ", type(data))
print(" Shape of 'train_X' is %s" % (data.shape,))
print(" Type of 'train_Y' is ", type(label))
print(" Shape of 'train_Y' is %s" % (label.shape,))

# Plot the sampled inputs against both the clean and the noisy targets;
# every array returned by toy_dataset has shape (1, n), so take row 0.
plt.figure(1)
plt.plot(data[0, :], gt[0, :], 'ro', label='Original data')
plt.plot(data[0, :], label[0, :], 'bo', label='Training data')
plt.axis('equal')
plt.legend(loc='lower right')
plt.show()

(Figure: ground-truth points and noisy training points.)

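Linear regression: fit $\hat{y} = wx + b$ by minimizing the mean squared error $\frac{1}{n}\sum_i (\hat{y}_i - y_i)^2$ with the RMSProp optimizer.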

import tensorflow.compat.v1 as tf

# This example uses the TF1 graph API (placeholders + Session). Under
# TensorFlow 2.x that API is reachable only through tf.compat.v1 with
# eager execution turned off.
tf.disable_v2_behavior()

# Model: act = X * w + b with a scalar weight and bias.
# Expects `data`, `label`, `gt` (shape (1, n)) and `np` from the dataset cell.
X = tf.placeholder(tf.float32, name='input')
Y = tf.placeholder(tf.float32, name='output')
w = tf.Variable(np.random.randn(), name='weight')
b = tf.Variable(np.random.randn(), name='bias')
act = tf.add(tf.multiply(X, w), b)

lr = 0.001
# Mean squared error between predictions and the noisy labels.
loss = tf.reduce_mean(tf.pow(act - Y, 2))
# Alternative optimizer:
# optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
optimizer = tf.train.RMSPropOptimizer(lr, decay=0.9).minimize(loss)
init = tf.global_variables_initializer()

training_epochs = 5000
snapshot = 50    # print progress every `snapshot` epochs
loss_cache = []  # one loss value per epoch (full-batch training)

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(training_epochs):
        # Full-batch step: the whole (1, n) dataset is fed at once.
        _, loss_val, w_val, b_val = sess.run(
            [optimizer, loss, w, b], feed_dict={X: data, Y: label})
        loss_cache.append(loss_val)
        if epoch % snapshot == 0:
            print('[Epoch: %d] loss: %.4f, w: %.4f, b: %.4f'
                  % (epoch, loss_val, w_val, b_val))
    w_new, b_new = sess.run([w, b])

# Fitted-line predictions at every training input (shape (n,)).
y_pre = data[0, :] * w_new + b_new
print(y_pre.shape)

plt.figure(2)
plt.plot(data[0, :], gt[0, :], 'ro', label='Ground Truth')
plt.plot(data[0, :], label[0, :], 'bo', label='Training Label')
plt.plot(data[0, :], y_pre, 'k-', label='Fitted Line')
plt.axis('equal')
plt.legend(loc='lower right')
plt.show()

(Figure: ground truth, training labels, and the fitted line.)

# Training curve: one point per recorded loss value. The x-axis is sized
# from loss_cache itself so it always matches the number of recorded steps.
plt.figure(3)
plt.plot(range(len(loss_cache)), loss_cache, 'b-', label='loss')
plt.xlabel('training step')
plt.ylabel('loss')
plt.legend(loc='upper right')
plt.show()

(Figure: training loss with the RMSProp optimizer.)

0 0
原创粉丝点击