关于TensorFlow中的多图(Multiple Graphs)

来源:互联网 发布:hibernate访问数据库 编辑:程序博客网 时间:2024/06/05 16:35
关于TensorFlow中的多图(Multiple Graphs) - CSDN博客

http://blog.csdn.net/aiya_xiazai/article/details/58701092


关于TensorFlow中的多图(Multiple Graphs)

原创 2017年02月28日 22:46:54

一、摘要

   TensorFlow中的图(Graph)是众多操作(Ops)的集合,它描述了具体的操作类型与各操作之间的关联。在实际应用中,我们可以直接把图理解为神经网络(Neural Network)结构的程序化描述。TensorFlow中的会话(Session)则实现图中所有操作,使得数据(Tensor类型)在图中流动(Flow)起来。平常学习或使用TensorFlow中,我们基本是构筑一个图,开启一个会话,然后run。但最近由于工作需要,我探索了多图(Multiple Graphs)方式。本文主要任务是简单记录学习过程,以备复阅。如果能给看官提供一丝帮助,那也是极好。

二、多图实现

(交代一下平台版本:PyCharm Community Edition 2016.3.2、Python3.5.2、tensorflow0.12.1)
片段一:多图的建立与默认图

[python] view plain copy
  1. import tensorflow as tf  
  2. import numpy as np  
  3. g1 = tf.Graph()  #建立图1  
  4. g2 = tf.Graph()  #建立图2  
  5. print('tf.get_default_graph()=',tf.get_default_graph())#获取默认图,并显示基本信息  
  6. print('g1                    =',g1)  
  7. print('g2                    =',g2)  
  8. print('------------------------------------')  
片段一运行结果:
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
g2              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FD68>
------------------------------------
从结果地址可以看出:默认图自动存在,手动建立的图与默认图完全不同


片段二:建立各图下的具体操作(Op)
[python] view plain copy
  1. with g1.as_default():#在with模块中,g1作为默认图  
  2.     x_data = np.random.rand(100).astype(np.float32)#定义图中具体操作  
  3.     y_data = x_data * 0.1 + 0.3  
  4.     W = tf.Variable(tf.random_uniform([1], -1.01.0))  
  5.     b = tf.Variable(tf.zeros([1]))  
  6.     y = W * x_data + b  
  7.     loss = tf.reduce_mean(tf.square(y - y_data))  
  8.     print('num-of-trainable_variables=', len(tf.trainable_variables()), ' num-of-global_variables=',len(tf.global_variables()))#统计变量个数  
  9.     print('g1                    =',g1)  
  10.     print('tf.get_default_graph()=',tf.get_default_graph())  
  11. print('tf.get_default_graph()=',tf.get_default_graph())  
  12. W2 = tf.Variable(tf.random_uniform([1], -1.01.0))#with模块外定义操作,比较模块内外变量个数变化情况  
  13. print('num-of-trainable_variables=',len(tf.trainable_variables()),' num-of-global_variables=',len(tf.global_variables()))  
  14. print('------------------------------------')  
片段二运行结果:
num-of-trainable_variables= 2  num-of-global_variables= 2
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
num-of-trainable_variables= 1  num-of-global_variables= 1
------------------------------------
由运行结果可以看出:
a、如上所述,所建操作依附于默认图 。使用with模块,在模块中让具体的图作为默认图。
b、退出with g.as_default()模块,原始默认图立马恢复(系统用栈来进行管理)
c、类似于tf.trainable_variables()、tf.global_variables()等都只针对此刻默认图里的变量,编写时要小心。

片段三:在各图下建立会话进行计算
[python] view plain copy
  1. with g1.as_default():  
  2.     sess1 = tf.Session(graph=g1)  
  3.     print('sess1',sess1)  
  4.     init = tf.global_variables_initializer()  
  5.     sess1.run(init)  
  6.     train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)  
  7.     for step in range(201):  
  8.         sess1.run(train)  
  9.         if step % 100 == 0:  
  10.             print(step, sess1.run(W), sess1.run(b))  
  11.   
  12.   
  13. with g2.as_default():  
  14.     w = tf.Variable(1.0)  
  15.     b = tf.Variable(1.5)  
  16.     wb=w+b  
  17.     sess2 = tf.Session(graph=g2)  
  18.     sess2.run(tf.global_variables_initializer())  
  19. print(sess2.run(wb))#定义并初始化后,可以在模块外运行  
  20. print('sess2',sess2)  
片段三的运行结果如下:
sess1 <tensorflow.python.client.session.Session object at 0x000001AACCA168D0>
0 [-0.41363323] [ 0.84654742]
100 [ 0.09930082] [ 0.30039138]
200 [ 0.09999922] [ 0.30000046]
2.5
sess2 <tensorflow.python.client.session.Session object at 0x000001AACC9FBE80>

由运行结果可以看出:在各自的图下建立各自的会话进行计算各不干扰

博文就到此结束,看官若有疑问,欢迎留言!


原创粉丝点击