初试tensorflow

来源:互联网 发布:剑指offer java版 编辑:程序博客网 时间:2024/06/05 00:39

最近刚接触tensorflow,入门什么的也很快,什么计算框架,计算图也有所了解。但对于自己写一个神经网络,尤其需要搞清楚一个概念:命名空间 tf.variable_scope。这个关系到写一个复杂网络时,梯度的更新,以及权值的调用。记录如下:

1/初始化一个变量有两种方式: tf.Variable() 和 tf.get_variable() ,主要区别是:

     tf.get_variable() 创建的变量名不受 name_scope 的影响;

     tf.get_variable() 创建的变量,name 属性值不可以相同;

     tf.Variable() 创建变量时,name 属性值允许重复(底层实现时,会自动引入别名制);

其具体意义是tf.get_variable() 创建一个变量时,会先自动查找所有变量,看将创建的变量名name是否已经被创建,如果已经被创建,那么将会报错。而tf.Variable()则不会,它采取的机制是自动引入一个别名机制,如下:

import tensorflow as tfh = tf.Variable(2.0, name='h')w = tf.Variable(2.0, name='h')tf.trainable_variables()[<tf.Variable 'h:0' shape=() dtype=float32_ref>, <tf.Variable 'h_1:0' shape=() dtype=float32_ref>]
自动别名为:h_1
但如果遇到一种情况,我们不想进行别名机制,而是变量共享,应该如何处理。这时我们就要用到tf.get_variable()创建变量,并且要利用tf.variable_scope这个管理空间,实现如下:

import tensorflow as tfwith tf.variable_scope("discriminator") as scope:        weights = tf.get_variable("weights", [1],initializer=tf.random_normal_initializer())        scope.reuse_variables()        h = tf.get_variable("weights", [1],initializer=tf.random_normal_initializer())tf.trainable_variables()[<tf.Variable 'discriminator/weights:0' shape=(1,) dtype=float32_ref>]
我们可以看到虽然创建了两次变量,但是由于变量名相同,第一次变量共享给了第二次创建的变量,最终还是只有一个变量。前面我们说tf.get_variabls()遇到变量名相同便会报错,但这不但没有报错,反而还共享。其中主要原因是 scope.reuse_variables()这个函数起了作用,因此想要变量共享,只需要在共享之前加入这个函数即可。
另外最重要的一点是tf.variable_scope(),我们创建的变量名是“weights”,但是实际得到的是“discriminator/weight”,这里就像一个文件夹的路径 一样,tf.variable_scope()的作用就是如此,创建一个类似“discriminator”的空间,在这个空间之下创建的所有变量都将带上这个变量名,就像是一个归类,便于管理。
其中需要说明的是tf.trainable_variables()的作用是 获得当前代码中创建的所有变量。

2/模型的保存和读取。模型保存函数很简单:

saver= tf.train.Saver(max_to_keep=3)saver.save(sess,'path/model.ckpt')
其中max_to_keep参数是指保存最近模型个数的最大数量。模型读取也很简单:
saver.restore(sess,'path/model.ckpt')
但是这种方式模型的读取,需要事先写好被读取的网络模型结构。
另一种读取是,先读取保存模型的grah图,再读取其中的变量,不需要 写好原网络模型的结构,因为这种结构图已经保存在grah中了,但是这种读取的方式,只能读取模型中的参数,并不能对原网络进行接力训练。读取方式如下:

import tensorflow as tfnew_saver = tf.train.import_meta_graph('C:\\Users\\sk\\Desktop\\D\\model\\model.ckpt.meta')sess = tf.Session() new_saver.restore(sess, 'C:\\Users\\sk\\Desktop\\D\\model\\model.ckpt')

注意第一个读取的文件是.meta指的是grah,第二个restore的是其中的参数权值。这样得到的权值参数并不能直接用于计算。我们可以对其进行提取,复制给新的网络:
d_pre_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
从该图中读取含有命名为‘generator’的权值并存入d_pre_params.

这里设置 新网络对应的权值‘new_generator’存于d_param,则copy权值方式如下:

d_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='new_generator')pre_weightD = sess.run(d_pre_params) #得到真实的值,上面只是建立一个运算模块for i, v in enumerate(d_params):     #d_params是一个列表,i得到标号,v是该标号中的权值矩阵    sess.run(v.assign(pre_weightD[i]))  # assign的作用是将pre_weightD[i]赋值给v
这就完成了权值的copy。

3/权值参数的几种获取方式

t_vars = tf.trainable_variables() #获得所有变量G_vars=[var for var in t_vars if 'generator' in var.name] #从所有变量中获得含有'generator'变量名的变量

上面2中的tf.get_collection(),以及上面这种获取的变量都是含有该变量名,就会被获取。若想获得准确的某一个变量可以用:

print(sess.run(tf.get_default_graph().get_tensor_by_name('discriminator/d_h0_conv/w:0')))
上面print是将其打印出来了,其主要函数还是:tf.get_default_graph().get_tensor_by_name().

总结:tensorflow是一个模块的化计算的方式,先定义过程,再进行初始化计算。以上大部分函数都是一种运算,得到 只是一个变量模块,而要获得具体的数值,都需要运行:

init = tf.global_variables_initializer() #初始化变量sess = tf.Session()sess.run(init)







原创粉丝点击