python的matplotlib模块实现tensorflow结果可视化

来源:互联网 发布:淘宝人气值是什么 编辑:程序博客网 时间:2024/06/07 19:20

python的matplotlib模块实现tensorflow结果可视化。

'''# tensorflow输出结果可视化(三层全连接神经网络)# python3.6.1/tensorflow1.2.1'''# coding=utf-8# 导入模块import tensorflow as tfimport numpy as np# import matplotlib.pyplot as plt # python数据图形化工具# 添加神经网络层的函数def add_layer(inputs, in_size, out_size, activation_function=None):    # 权值    Weights = tf.Variable(tf.random_normal([in_size, out_size]))    # 偏置项    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)    # 加权和    Wx_plus_b = tf.matmul(inputs, Weights) + biases    # 激活函数    if activation_function is None:        outputs = Wx_plus_b    else:        outputs = activation_function(Wx_plus_b)    return outputs#=== 生成真实数据# 输入数据x_data = np.linspace(-1, 1, 300)[:, np.newaxis]# 300个从-1~1之间等差的行向量 # 噪点noise = np.random.normal(0, 0.05, x_data.shape)# 输出数据y_data = np.square(x_data) - 0.5 + noise#plt.scatter(x_data, y_data)#plt.show()#=== 生成占位符xs = tf.placeholder(tf.float32, [None, 1])ys = tf.placeholder(tf.float32, [None, 1])#=== add hidden layer 隐含层l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)#=== add output layer 输出层prediction = add_layer(l1, 10, 1, activation_function=None)#=== 损失函数和梯度下降loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)#=== 生成会话和初始化变量  init = tf.global_variables_initializer() sess = tf.Session()sess.run(init)#tensorflow >= 0.12 的版本不支持# tf.initialize_all_variables()初始化向量}#=== plot真是数据fig = plt.figure()# 生成图片框ax = fig.add_subplot(1,1,1)# 图片框分割成一块,编号为111ax.scatter(x_data, y_data)# 以点的形式显示出真实数据#plt.ion()# 连续的打印#plt.show()# 输出图片#=== 训练神经挖网络for i in range(1000):    sess.run(train_step, feed_dict={xs: x_data, ys: y_data})    if i % 50 == 0:# 每50步输出一次        try:            ax.lines.remove(lines[0])# 删除lines中第一条数据        except Exception:            pass        prediction_value =             sess.run(prediction, feed_dict={xs: x_data})        lines =             ax.plot(x_data, prediction_value, 'r-', lw=5)            # 以线的形式输出预测数据,线的颜色为红色,线宽为5
原创粉丝点击