TensorFlow学习笔记(十一)读取自己的数据进行训练

来源:互联网 发布:mysql索引类型 编辑:程序博客网 时间:2024/04/29 09:19

1. 线性关系

读取的数据文件（CSV 格式，network.csv）内容如下：

x,y
1,2
4,5
6,11
3,6
4,7
5,12
7,13
10,21
11,23
24,50
45,89
50,101
55,111
60,123
70,139
80,164
85,171
90,192
95,190
100,199
200,401
1000,2000

代码:

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 15:43:41 2017

@author: ESRI
"""

# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 14:59:10 2017

@author: ESRI
"""

import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt



# Read the data set.
# NOTE(review): hard-coded Windows path; adjust for your environment.
dataset = pd.read_csv('E:\\testData\\network.csv')

# Summary statistics of each column
print(dataset.describe())
# First 5 rows
print(dataset.head())
# Shape of the data: (rows, columns)
print(dataset.shape)

# Extract the x and y columns as (n, 1) column vectors for the network.
# `Series.as_matrix` was deprecated in pandas 0.23 and removed in 1.0;
# `.values` returns the same ndarray and works on every pandas version.
X_data = dataset['x'].values.reshape(-1, 1)
Y_data = dataset['y'].values.reshape(-1, 1)


# Add one fully-connected layer to the graph
def add_layer(inputs, in_size, out_size, activation_function=None):
    """Build one fully-connected layer and return its output tensor.

    Computes ``matmul(inputs, W) + b``. ``W`` is a ``[in_size, out_size]``
    variable initialised with ``tf.random_normal``. ``b`` is a
    ``[1, out_size]`` variable initialised to 0.1.

    Args:
        inputs: input tensor; its last dimension should equal ``in_size``.
        in_size: number of input features.
        out_size: number of output units.
        activation_function: optional callable applied to the affine
            result; when None the layer is linear.

    Returns:
        The (optionally activated) output tensor of the layer.
    """
    weight_matrix = tf.Variable(tf.random_normal([in_size, out_size]))
    bias_row = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    affine = tf.matmul(inputs, weight_matrix) + bias_row
    if activation_function is not None:
        return activation_function(affine)
    return affine


# Normalization (z-score standardization)
def normalize(train):
    """Standardize *train* to zero mean and unit standard deviation.

    Uses the object's own ``.mean()`` / ``.std()``. For numpy arrays
    ``.std()`` has ddof=0; for pandas objects it has ddof=1.

    Args:
        train: numeric numpy array (or pandas object) to standardize.

    Returns:
        ``(train - mean) / std``, with the same shape as *train*. If the
        standard deviation is zero (all values identical), the data is only
        mean-centred, giving zeros instead of NaN/inf from dividing by zero.
    """
    mean, std = train.mean(), train.std()
    if std == 0:
        # Constant input: avoid 0/0 -> NaN; centred data is all zeros.
        std = 1
    return (train - mean) / std

# Placeholders for the (normalized) inputs and targets. Declaring the shape
# [None, 1] (any number of rows, one feature) lets TensorFlow check the
# matmul shapes in add_layer while building the graph; X and Y are (n, 1)
# because X_data / Y_data were reshaped with reshape(-1, 1).
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# Standardize inputs and targets separately (z-score).
# NOTE: the mean/std are not kept, so predictions stay in normalized units.
X = normalize(X_data)
Y = normalize(Y_data)

# Network: 1 input -> 10 ReLU hidden units -> 1 linear output
# (two weight layers; input + hidden + output = "3 layers").
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# Output layer (linear)
prediction = add_layer(l1, 10, 1, activation_function=None)

# Loss: sum the squared error over the output dimension (axis 1), then
# average over samples. `reduction_indices` is the deprecated alias of `axis`.
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=1))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# Initialize all variables (tf.initialize_all_variables is the deprecated name).
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# Visualization: scatter the real (normalized) data
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.scatter(X, Y)
plt.ion()  # interactive mode so the training loop can redraw the figure
plt.show()


# Plot the prediction curve in increasing-x order; the raw data is not
# sorted (e.g. 1, 4, 6, 3, ...), so an 'r-' line would otherwise zigzag.
plot_order = X.ravel().argsort()
lines = None  # handle to the most recently drawn prediction line, if any
for i in range(8000):
    # One full-batch gradient-descent step
    sess.run(train_step, feed_dict={xs: X, ys: Y})
    if i % 50 == 0:
        print(sess.run(loss, feed_dict={xs: X, ys: Y}))
        # Remove the previous prediction curve before drawing the new one.
        # Artist.remove() works across matplotlib versions. The old
        # `ax.lines.remove(...)` fails on recent matplotlib, and its bare
        # `except Exception: pass` silently hid that (and the first-pass
        # NameError), letting stale lines pile up.
        if lines:
            lines[0].remove()
        prediction_value = sess.run(prediction, feed_dict={xs: X})
        # Plot the prediction in red
        lines = ax.plot(X[plot_order], prediction_value[plot_order], 'r-', lw=5)
        plt.pause(0.1)

结果：
                    x            y
count    22.000000    22.000000
mean     91.136364   183.181818
std     208.740051   417.486314
min       1.000000     2.000000
25%       6.250000    12.250000
50%      47.500000    95.000000
75%      83.750000   169.250000
max    1000.000000  2000.000000
      x    y
0     1    2
1     4    5
2     6   11
3     3    6
4     4    7
5     5   12
6     7   13
7    10   21
8    11   23
9    24   50
10   45   89
11   50  101
12   55  111
13   60  123
14   70  139
15   80  164
16   85  171
17   90  192
18   95  190
19  100  199
(22, 2)
[[-0.4419732 ]
 [-0.42726305]
 [-0.41745628]
 ...,
 [ 0.04346181]
 [ 0.53380021]
 [ 4.45650743]]


11.6106
0.00154685
0.00107705
0.000622838
0.000468779
0.000346998
0.000274857
0.00016539
9.39608e-05
6.02521e-05
4.41742e-05
3.47886e-05
3.02667e-05
2.81042e-05
2.73301e-05
2.69677e-05
2.67462e-05
2.66131e-05
2.6452e-05
2.63586e-05
2.63102e-05
2.61975e-05
2.61691e-05
2.61784e-05
2.61712e-05
2.61596e-05
2.61267e-05
2.61323e-05
2.61504e-05
2.61072e-05
2.61337e-05
2.61305e-05
2.60892e-05
2.60815e-05
2.6096e-05
2.60919e-05
2.60685e-05
2.60606e-05
2.60774e-05
2.61023e-05
2.60717e-05
2.60601e-05
2.60832e-05
2.60474e-05
2.60752e-05
2.60568e-05
2.60328e-05
2.60716e-05
2.60527e-05
2.60288e-05
2.60224e-05
2.60488e-05
2.60549e-05
2.60573e-05
2.60576e-05
2.60556e-05
2.60509e-05
2.60434e-05
2.60333e-05
2.60186e-05
2.60025e-05
2.60154e-05
2.60487e-05
2.60329e-05
2.59924e-05
2.60066e-05
2.60364e-05
2.60053e-05
2.60045e-05
2.60256e-05
2.5987e-05
2.60303e-05
2.59782e-05
2.603e-05
2.59753e-05
2.60242e-05
2.59781e-05
2.60142e-05
2.59865e-05
2.59966e-05
2.6021e-05
2.59726e-05
2.59987e-05
2.6012e-05
2.59699e-05
2.59885e-05
2.60072e-05
2.59776e-05
2.59591e-05
2.59867e-05
2.59993e-05
2.59841e-05
2.59637e-05
2.59506e-05
2.59757e-05
2.59872e-05
2.59941e-05
2.5992e-05
2.59636e-05
2.59547e-05
2.59475e-05
2.59412e-05
2.59377e-05
2.59612e-05
2.59653e-05
2.59678e-05
2.59692e-05
2.59695e-05
2.59691e-05
2.59679e-05
2.59662e-05
2.59643e-05
2.59615e-05
2.59585e-05
2.59546e-05
2.59497e-05
2.5937e-05
2.59175e-05
2.59209e-05
2.59248e-05
2.59291e-05
2.5935e-05
2.59458e-05
2.59611e-05
2.59533e-05
2.59444e-05
2.59316e-05
2.59077e-05
2.59154e-05
2.59242e-05
2.59549e-05
2.59445e-05
2.59325e-05
2.58975e-05
2.59079e-05
2.59221e-05
2.59423e-05
2.59288e-05
2.58909e-05
2.5903e-05
2.59232e-05
2.59314e-05
2.59115e-05
2.58939e-05
2.59115e-05
2.59273e-05
2.59065e-05
2.589e-05
2.59133e-05
2.59185e-05
2.58759e-05
2.58929e-05
2.59227e-05
2.59028e-05
2.58816e-05
2.59242e-05
2.59048e-05
2.58738e-05
2.59005e-05
2.59029e-05     
  

阅读全文
0 0
原创粉丝点击