[Tensorflow]Sharing Variables 共享权值【tf.get_variable 和 tf.variable_scope】

来源:互联网 发布:成都cnc编程最新招聘 编辑:程序博客网 时间:2024/05/19 13:55

参考Sharing Variables 

一、tf.get_variable

之前就觉得 tf.Variable(tf.random_normal(xxx))这类写法太丑了,果然Tensorflow 提供了更加一体化的API。


tf.get_variable(  name,               #以后老老实实每个变量取个名吧。restore也方便。  shape=None,         #shape,[None,28,28,1]  dtype=None,         #如tf.float32   initializer=None,   #改进了tf.Vairable蹩脚的写法。  regularizer=None,   #用于L1/L2正则化  trainable=True,     #If True also add the variable to the graph collection GraphKeys.  collections=None,   #默认为 [GraphKeys.GLOBAL_VARIABLES],即 ["varibles"],包含collection名的列表。  caching_device=None,  partitioner=None,  validate_shape=True,  custom_getter=None)  """参数:initializer  (1)默认值None,即使用uniform_unit_scaling_initializer。    (文档看不太懂,猜测是均匀分布获取参数W,且对于输入x,使得y=x*W中y的scale intact)  (2)Tensor,那么会复制此Tensor  (3)常数:        tf.constant_initializer(value=0, dtype=tf.float32)  (4)正太分布:    tf.random_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)  (5)截断正太分布: tf.truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None, dtype=tf.float32)参数:regularizer  regularizer: A (Tensor -> Tensor or None) function;   the result of applying it on a newly created variable will be added to the collection GraphKeys.REGULARIZATION_LOSSES and can be used for regularization."""
例子:
import tensorflow as tfsess=tf.Session()a=tf.get_variable("a",[3,3,32,64],initializer=tf.random_normal_initializer())b=tf.get_variable("b",[64],initializer=tf.random_normal_initializer())gv= tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)for var in gv:   print(var is a)  print(var.get_shape())


二、tf.variable_scope 和 共享变量

tf.variable_scope(   name,                 #variable的namespace   reuse=False,          #False:新建Tensor,重名会产生异常;True:重用Tensor,不存在会产生异常。   regularizer=None      #正则化   #其他参数略去)
tf.variable_scope其实就是对在其内定义的variable设置namespace + 用于变量共享

例子:
import tensorflow as tfsess=tf.Session()def run(a):  sess.run(tf.global_variables_initializer())  return sess.run(a)#**************以下函数获取scope_name命名空间下变量名为var_name的变量,不存在创建,存在则返回已存在的变量***********def get_scope_variable(scope_name,var_name,shape=None):  with tf.variable_scope(scope_name) as scope:            #reuse设置为true不存在会异常,设置为False,存在重名会异常。故我们捕获异常来判断是否存在。    try:                  var=tf.get_variable(var_name,shape)    except ValueError:      scope.reuse_variables()      var=tf.get_variable(var_name)  return var var_1 = get_scope_variable("cur_scope","my_var",[100])var_2 = get_scope_variable("cur_scope","my_var",[100])print(var_1 is var_2)print(var_1.name)                                        #此时变量名为  cur_scope/my_var


阅读全文
0 0
原创粉丝点击