Saving and Loading TensorFlow Variables


TensorFlow provides the tf.train.Saver class for saving and loading variables. Its save method stores all or some of the variables in the computation graph to a ckpt file, and its restore method loads all or some of the variables from a ckpt file back into the computation graph. According to the official documentation, the purpose of a ckpt file is to map variable names to tensor values.
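One quick way to see this name-to-value mapping is to read a checkpoint directly with tf.train.NewCheckpointReader. A minimal sketch follows; it assumes the checkpoint "./saved_model/model.ckpt" produced by saver_test.py in the experiments below.

import tensorflow as tf

# Minimal sketch: print the variable-name -> tensor mapping stored in a ckpt.
# The path is assumed to be the checkpoint written by saver_test.py below.
reader = tf.train.NewCheckpointReader("./saved_model/model.ckpt")
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
    print(reader.get_tensor(name))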

A set of experiments on saving and loading variables:

saver_test.py

import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)
dec_v3 = v3.assign(v3 - 2)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  dec_v3.op.run()
  print(sess.run(v1))
  print(sess.run(v2))
  print(sess.run(v3))
  # Save the variables to disk.
  save_path = saver.save(sess, "./saved_model/model.ckpt")
  print("Model saved in file: %s" % save_path)

restore_test.py

import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
v3 = tf.get_variable("v3_x", shape=[4])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "./saved_model/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
  print("v3 : %s" % v3.eval())

Result:

Error: Key v3_x not found in checkpoint.

The modified restore_test.py:

import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
v3_x = tf.get_variable("v3_x", shape=[4], initializer=tf.zeros_initializer)

# Add ops to save and restore only v1 and v2.
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  v3_x.initializer.run()
  # Restore variables from disk.
  saver.restore(sess, "./saved_model/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
  print("v3 : %s" % v3_x.eval())

Result:

Model restored.
v1 : [ 1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
v3 : [ 0. 0. 0. 0.]

restore_test.py with v3_x left uninitialized:

import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
v3_x = tf.get_variable("v3_x", shape=[4], initializer=tf.zeros_initializer)

# Add ops to save and restore only v1 and v2.
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # v3_x.initializer.run()
  # Restore variables from disk.
  saver.restore(sess, "./saved_model/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
  print("v3 : %s" % v3_x.eval())

Result:

Output: error: Attempting to use uninitialized value v3_x.

Analysis of the results:

  1. Variables are restored by matching their names.
  2. If you only want to restore some of the variables, you must specify them when constructing the Saver; otherwise the default Saver tries to restore every variable in the graph and raises an error when one of them is missing from the checkpoint.
  3. Variables restored by saver.restore do not need to be initialized; all other variables must be initialized before use (a small sketch follows this list).
  4. The graph in restore_test.py only defines the variables; the other ops from saver_test.py are not defined, yet the checkpoint can still be restored. The Saver therefore restores only the variables in the graph and does not care whether the graph structure has changed. This is useful in tasks such as exporting a frozen pb, where the graph input must change from batch_size examples to a single example: you can modify the input dimensions of the graph, leave the weight parameters untouched, and re-save the modified graph.
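As a minimal sketch of point 3, the snippet below (reusing the same v1/v2/v3_x variables as above) restores v1 and v2 from the checkpoint and then initializes only the variables that were not restored; tf.report_uninitialized_variables is used here simply to confirm what still needs initializing.

import tensorflow as tf

tf.reset_default_graph()

v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
v3_x = tf.get_variable("v3_x", shape=[4], initializer=tf.zeros_initializer)

saver = tf.train.Saver({'v1': v1, 'v2': v2})

with tf.Session() as sess:
  # Restore v1 and v2; they need no further initialization.
  saver.restore(sess, "./saved_model/model.ckpt")
  # Report which variables are still uninitialized (here: v3_x).
  print(sess.run(tf.report_uninitialized_variables()))
  # Initialize only the variables that were not restored.
  sess.run(tf.variables_initializer([v3_x]))
  print("v3 : %s" % v3_x.eval())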

Loading only some of the variables in a graph

In some transfer-learning tasks we may only need to restore part of the weights into a new network as initialization for fine-tuning; TensorFlow supports this kind of sub-graph restore.
The Saver constructor accepts a list of variables, and during restore only the variables in that list are restored. You can obtain the list of variables under a particular name scope with tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name_scope_of_variable) and pass that list to the tf.train.Saver constructor to build the Saver object.
For example, suppose you have a pretrained InceptionV1 model and only want to train the fc layer:

  1. Build a brand-new InceptionV1 model, with the scope of the final fc layer set to "last".
  2. Restore the pretrained InceptionV1 weights:
    saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
    saver1.restore(session, 'inception_model_from_google.ckpt')
  3. Pass only the fc-layer variables to the optimizer, e.g.:
    tf.train.GradientDescentOptimizer(0.0001).minimize(
        loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='last'))
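The three steps above can be combined into a minimal runnable sketch. The tiny conv/dense layers below merely stand in for the real InceptionV1 backbone and fc head, and 'inception_model_from_google.ckpt' is an assumed checkpoint path; what matters here is how the scope names drive the Saver and the optimizer's var_list.

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.int64, [None])

with tf.variable_scope('InceptionV1'):
  # Stand-in for the pretrained backbone; the real checkpoint must contain
  # variables with matching names under this scope.
  net = tf.layers.conv2d(images, 64, 3, activation=tf.nn.relu)
  net = tf.reduce_mean(net, axis=[1, 2])  # global average pooling

with tf.variable_scope('last'):
  # New fc layer, trained from scratch.
  logits = tf.layers.dense(net, 10)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Restore only the backbone weights from the pretrained checkpoint.
backbone_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1')
saver1 = tf.train.Saver(var_list=backbone_vars)

# Train only the variables under the 'last' scope.
fc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='last')
train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss, var_list=fc_vars)

with tf.Session() as session:
  # Initialize everything, then overwrite the backbone with pretrained weights.
  session.run(tf.global_variables_initializer())
  saver1.restore(session, 'inception_model_from_google.ckpt')
  # ... feed images/labels and run train_op ...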

If you need to map a variable stored in the ckpt file to a variable with a different name in the new graph:

import tensorflow as tf

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
v3_x = tf.get_variable("v3_x", shape=[4], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)
dec_v3_x = v3_x.assign(v3_x - 2)

# Restore v3 in ckpt to v3_x in new graph.
saver = tf.train.Saver({'v1': v1, 'v2': v2, 'v3': v3_x})

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "./saved_model/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())
  print("v3 : %s" % v3_x.eval())

Output:

Model restored.
v1 : [ 1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
v3 : [-2. -2. -2. -2.]

Using name_scope/variable_scope sensibly when building a network lays a good foundation for reusing the model.
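A small sketch of why this helps: variables created under a variable_scope get the scope as a name prefix, and those prefixed names are exactly what the Saver name map and the get_collection scope filter match against. The 'encoder'/'decoder' scope names below are just illustrative placeholders.

import tensorflow as tf

tf.reset_default_graph()

# Variables defined under a variable_scope carry the scope as a name prefix.
with tf.variable_scope('encoder'):
  w = tf.get_variable('w', shape=[3])
with tf.variable_scope('decoder'):
  b = tf.get_variable('b', shape=[3])

print(w.name)  # encoder/w:0
print(b.name)  # decoder/b:0
# Select only the encoder variables, e.g. to build a partial Saver.
print([v.name for v in
       tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder')])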

Renaming the variables in a ckpt

Sometimes we modify the graph and change the names of some variables, which makes a previously saved ckpt file unusable. In that case we can rename the variables inside the ckpt file so the old model can still be used.

import tensorflow as tf

checkpoint_dir = './saved_model/'
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)

with tf.Session() as sess:
    for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir):
        # Load the variable's value from the checkpoint.
        var = tf.contrib.framework.load_variable(checkpoint_dir, var_name)
        print(var_name)
        print(var)
        new_name = var_name + '/test'
        # Rename the variable and place the new variable into the default graph.
        tf.Variable(var, name=new_name)
    # Save all renamed variables in the default graph back to the checkpoint.
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, checkpoint.model_checkpoint_path)
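To confirm the rename took effect, you can list the variables in the rewritten checkpoint again; each name should now carry the '/test' suffix added by the script above.

import tensorflow as tf

# List the variable names and shapes in the rewritten checkpoint.
for var_name, shape in tf.contrib.framework.list_variables('./saved_model/'):
    print(var_name, shape)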

References:
1. https://www.tensorflow.org/programmers_guide/saved_model