tensorflow variable_scope共享变量

来源:互联网 发布:unity3d 美女 模型 编辑:程序博客网 时间:2024/06/03 17:14

参考文档:http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/variable_scope.html


tf.Variable(<variable_name>)用于创建一个新变量,可以创建相同名字变量,底层会自动引入别名机制,给新创建的变量名加数字,两个变量是不相同的
tf.get_variable(<variable_name>)获取一个变量,当变量已经存在,会报错,显示变量已经存在;不存在会创建一个新变量。无视name_scope

tf.get_variable(name, shape, dtype, initializer)

initializer初始化器:

tf.constant_initializer(value) 初始化一切所提供的值,
tf.random_uniform_initializer(a, b)从a到b均匀初始化,
tf.random_normal_initializer(mean, stddev) 用所给平均值和标准差初始化正态分布
tf.truncated_normal_initializer(mean, stddev) :截取的正态分布
tf.zeros_initializer():全部是0
tf.ones_initializer():全是1



tf.name_scope(<scope_name>)

主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。

import tensorflow as tfwith tf.name_scope("a_name_scope"):    initializer = tf.constant_initializer(value=1)    var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32, initializer=initializer)    # tf.get_variable()定义的变量不会被tf.name_scope()当中的名字所影响。    var2 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)    var21 = tf.Variable(name='var2', initial_value=[2.1], dtype=tf.float32)    var22 = tf.Variable(name='var2', initial_value=[2.2], dtype=tf.float32)with tf.Session() as sess:    sess.run(tf.initialize_all_variables())    print(var1.name)        # var1:0    print(sess.run(var1))   # [ 1.]    print(var2.name)        # a_name_scope/var2:0    print(sess.run(var2))   # [ 2.]    print(var21.name)       # a_name_scope/var2_1:0    print(sess.run(var21))  # [ 2.0999999]    print(var22.name)       # a_name_scope/var2_2:0    print(sess.run(var22))  # [ 2.20000005]




共享变量:

tf.variable_scope(<scope_name>)
管理一个图中的变量名,避免变量之间的命名冲突,允许在一个variable_scope下面共享变量。

如果想要达到重复利用变量的效果,我们就要使用tf.variable_scope(),并搭配tf.get_variable()这种方式产生和提取变量。不像tf.Variable()每次都会产生新的变量,tf.get_variable()如果遇到了同样名字的变量时,需要在该变量定义的后面强调scope.reuse_variables(),表示该变量可以重复使用,否则会报错

import tensorflow as tfwith tf.variable_scope("a_variable_scope") as scope:    initializer = tf.constant_initializer(value=3)    var3 = tf.get_variable(name='var3', shape=[1], dtype=tf.float32, initializer=initializer)    scope.reuse_variables()  # 声明前面出现的变量为共享变量    var3_reuse = tf.get_variable(name='var3', )    var4 = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32)    var4_reuse = tf.Variable(name='var4', initial_value=[4], dtype=tf.float32)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print(var3.name)  # a_variable_scope/var3:0    print(sess.run(var3))  # [ 3.]    print(var3_reuse.name)  # a_variable_scope/var3:0    print(sess.run(var3_reuse))  # [ 3.]    print(var4.name)  # a_variable_scope/var4:0    print(sess.run(var4))  # [ 4.]    print(var4_reuse.name)  # a_variable_scope/var4_1:0    print(sess.run(var4_reuse))  # [ 4.]


共享变量在class中的使用(tf.get_variable() + tf.variable_scope)

共享类中变量的两种方式:

方式一:

class A(object):    def __init__(self):        self.a = tf.get_variable('a', [1], dtype=tf.float32)        self.b = tf.assign(self.a, [10])with tf.variable_scope('AA', reuse=False):    x = A()with tf.variable_scope('AA', reuse=True):    y = A()sess = tf.Session()print(sess.run(x.b))  # [10.]print(sess.run(y.a))  # [10.]

方式二:

class A(object):    def __init__(self):        self.a = tf.get_variable('a', [1], dtype=tf.float32)        self.b = tf.assign(self.a, [10])with tf.variable_scope('AA'):    x = A()    tf.get_variable_scope().reuse_variables()    y = A()sess = tf.Session()print(sess.run(x.b))  # [10.]print(sess.run(y.a))  # [10.]



原创粉丝点击