Notes on tensorflow(六)variable_scope

来源:互联网 发布:学美工设计培训班 编辑:程序博客网 时间:2024/05/22 02:06

https://www.tensorflow.org/programmers_guide/variable_scope

Scope与Share机制

tensorflow 引入了namespace机制, 也就是scope, 可以方便地命名、共享变量. 当需要共享变量时, 创建变量使用tf.get_variable方法而不是tf.Variable.

import tensorflow as tfwith tf.variable_scope('foo'):    v1 = tf.get_variable('v1', [1])    print v1.namewith tf.variable_scope('foo', reuse = True):    v2 = tf.get_variable('v1')    #v3 = tf.get_variable('v3', [3]) 会报错    print v2.nameassert v2 is v1
foo/v1:0foo/v1:0
  • 一个scope对应一个namespace,当在scope里创建任意有name的东西时, 它的name为: scope_name/var_name

  • reuse = True不可少。它是variable_scope的一个属性, 直接决定如何创建变量。

    • reuse = False时, 先检查是否已经存在相同name的Variable, 如果有, 报错。然后以对应name创建一个新的Variable
    • reuse = True时,不会创建新的Variable。直接查找自己name对应的variable, 如果没有, 则报错。
  • reuse属性可继承:在reuse = True的scope里创建子scope时, 子scope的reuse==True

import tensorflow as tfwith tf.variable_scope('foo', reuse = True) as foo:    print foo.reuse    with tf.variable_scope('doo') as doo:        print doo.reuse
TrueTrue

variable_scope与name_scope

它们在效果上的区别是variable_scope会影响它内部创建的所有有name属性的节点, 但name_scope只影响Operator节点的命名。 用处之一是在多gpu训练时在不同的device上, 使用相同的variable_scope, 但使用不同的namescope

import tensorflow as tfwith tf.variable_scope('foo'):    with tf.name_scope('ns'):        a = tf.get_variable('a', [1])        b = a + 1;        print a.name        print b.name        print b.op.name
foo/a:0foo_2/ns/add:0foo_2/ns/add

为variable_scope指定默认的initializer

为variable_scope指定默认initializer的好处是不用在每次调用创建变量的方法时传入初始值了。它也是可以继承的。

import tensorflow as tfdef show(v):    with tf.Session() as sess:        init = tf.global_variables_initializer()        sess.run(init)        print v.eval()with tf.variable_scope('foo', initializer = tf.constant_initializer(0.2)):    cv1 = tf.get_variable('cv1', [1])    show(cv1)    with tf.variable_scope('sub_foo'):        cv2 = tf.get_variable('cv2', [1])        show(cv2)        with tf.variable_scope('sub_sub_foo', initializer = tf.constant_initializer(0.1)):            cv3 = tf.get_variable('cv3', [1])            show(cv3)
[ 0.2][ 0.2][ 0.1]
0 0
原创粉丝点击