关于TensorFlow中的多图(Multiple Graphs)

来源:互联网 发布:win10禁止自动更新软件 编辑:程序博客网 时间:2024/06/15 02:38

一、摘要

   TensorFlow中的图(Graph)是众多操作(Ops)的集合,它描述了具体的操作类型与各操作之间的关联。在实际应用中,我们可以直接把图理解为神经网络(Neural Network)结构的程序化描述。TensorFlow中的会话(Session)则实现图中所有操作,使得数据(Tensor类型)在图中流动(Flow)起来。平常学习或使用TensorFlow中,我们基本是构筑一个图,开启一个会话,然后run。但最近由于工作需要,我探索了多图(Multiple Graphs)方式。本文主要任务是简单记录学习过程,以备复阅。如果能给看官提供一丝帮助,那也是极好。

二、多图实现

(交代一下平台版本:PyCharm Community Edition 2016.3.2、Python3.5.2、tensorflow0.12.1)
片段一:多图的建立与默认图

import tensorflow as tfimport numpy as npg1 = tf.Graph()  #建立图1g2 = tf.Graph()  #建立图2print('tf.get_default_graph()=',tf.get_default_graph())#获取默认图,并显示基本信息print('g1                    =',g1)print('g2                    =',g2)print('------------------------------------')
片段一运行结果:
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
g2              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FD68>
------------------------------------
从结果地址可以看出:默认图自动存在,手动建立的图与默认图完全不同


片段二:建立各图下的具体操作(Op)
with g1.as_default():#在with模块中,g1作为默认图    x_data = np.random.rand(100).astype(np.float32)#定义图中具体操作    y_data = x_data * 0.1 + 0.3    W = tf.Variable(tf.random_uniform([1], -1.0, 1.0))    b = tf.Variable(tf.zeros([1]))    y = W * x_data + b    loss = tf.reduce_mean(tf.square(y - y_data))    print('num-of-trainable_variables=', len(tf.trainable_variables()), ' num-of-global_variables=',len(tf.global_variables()))#统计变量个数    print('g1                    =',g1)    print('tf.get_default_graph()=',tf.get_default_graph())print('tf.get_default_graph()=',tf.get_default_graph())W2 = tf.Variable(tf.random_uniform([1], -1.0, 1.0))#with模块外定义操作,比较模块内外变量个数变化情况print('num-of-trainable_variables=',len(tf.trainable_variables()),' num-of-global_variables=',len(tf.global_variables()))print('------------------------------------')
片段二运行结果:
num-of-trainable_variables= 2  num-of-global_variables= 2
g1              = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AAC7C5FC50>
tf.get_default_graph() = <tensorflow.python.framework.ops.Graph object at 0x000001AACC965550>
num-of-trainable_variables= 1  num-of-global_variables= 1
------------------------------------
由运行结果可以看出:
a、如上所述,所建操作依附于默认图 。使用with模块,在模块中让具体的图作为默认图。
b、退出with g.as_default()模块,原始默认图立马恢复(系统用栈来进行管理)
c、类似于tf.trainable_variables()、tf.global_variables()等都只针对此刻默认图里的变量,编写时要小心。

片段三:在各图下建立会话进行计算
with g1.as_default():    sess1 = tf.Session(graph=g1)    print('sess1',sess1)    init = tf.global_variables_initializer()    sess1.run(init)    train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)    for step in range(201):        sess1.run(train)        if step % 100 == 0:            print(step, sess1.run(W), sess1.run(b))with g2.as_default():    w = tf.Variable(1.0)    b = tf.Variable(1.5)    wb=w+b    sess2 = tf.Session(graph=g2)    sess2.run(tf.global_variables_initializer())print(sess2.run(wb))#定义并初始化后,可以在模块外运行print('sess2',sess2)
片段三的运行结果如下:
sess1 <tensorflow.python.client.session.Session object at 0x000001AACCA168D0>
0 [-0.41363323] [ 0.84654742]
100 [ 0.09930082] [ 0.30039138]
200 [ 0.09999922] [ 0.30000046]
2.5
sess2 <tensorflow.python.client.session.Session object at 0x000001AACC9FBE80>

由运行结果可以看出:在各自的图下建立各自的会话进行计算各不干扰

博文就到此结束,看官若有疑问,欢迎留言!

2 0
原创粉丝点击