tensorflow variable的保存和加载
来源:互联网 发布:java框架电子书 编辑:程序博客网 时间:2024/06/05 02:31
tensorflow提了供tf.train.saver类已完成variable的保存和加载。其中save方法可以用来将计算图中的variable全部或者部分存储到ckpt文件,restore方法可以将ckpt文件中的全部或者部分变量导入计算图中。按照官方定义ckpt文件的作用是: map variable names to tensor values 。
variable存储和加载的一组实验:
saver_test.py
import tensorflow as tf# Create some variables.v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)v3 = tf.get_variable("v3", shape=[4], initializer = tf.zeros_initializer)inc_v1 = v1.assign(v1+1)dec_v2 = v2.assign(v2-1)dec_v3 = v3.assign(v3-2)# Add an op to initialize the variables.init_op = tf.global_variables_initializer()# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, and save the# variables to disk.with tf.Session() as sess: sess.run(init_op) # Do some work with the model. inc_v1.op.run() dec_v2.op.run() dec_v3.op.run() print sess.run(v1) print sess.run(v2) print sess.run(v3) # Save the variables to disk.^ save_path = saver.save(sess, "./saved_model/model.ckpt") print("Model saved in file: %s" % save_path)
restore_test.py
import tensorflow as tftf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])v3 = tf.get_variable('v3_x',shape=[4])# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "./saved_model/model.ckpt") print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval()) print("v3 : %s" % v3.eval())
结果:
报错:Key v3_x not found in checkpoint.
修改后的restore_test.py
import tensorflow as tftf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])v3_x = tf.get_variable('v3_x',shape=[4],initializer = tf.zeros_initializer)# Add ops to save and restore all the variables.saver = tf.train.Saver({'v1':v1,'v2':v2})# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: v3_x.initializer.run() # Restore variables from disk. saver.restore(sess, "./saved_model/model.ckpt") print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval()) print("v3 : %s" % v3_x.eval())
结果:
Model restored.
v1 : [ 1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
v3 : [ 0. 0. 0. 0.]
v3_x未做初始化的 restore_test.py
import tensorflow as tftf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])v3_x = tf.get_variable('v3_x',shape=[4],initializer = tf.zeros_initializer)# Add ops to save and restore all the variables.saver = tf.train.Saver({'v1':v1,'v2':v2})# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: # v3_x.initializer.run() # Restore variables from disk. saver.restore(sess, "./saved_model/model.ckpt") print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval()) print("v3 : %s" % v3_x.eval())
结果:
输出:报错 Attempting to use uninitialized value v3_x.
结果分析:
- 变量的恢复按照名字匹配。
- 如果restore部分变量需要在saver中指明,否则默认或restore ckpt中所有的变量,如果不存在则报错。
- saver restore恢复的变量不需要初始化,其他变量使用之前需要初始化。
- saver_test.py中的graph仅仅定义了variable,其他的op没有定义,但是仍然可以恢复,因此saver恢复时仅恢复graph中的varible,不关心graph的结构是否变化。在一些应用如:导出pb freeze网络结构时,需要将graph的输入从batchsize个改成1个,此时可以修改graph中数据的输入维度,不修改graph中的权重参数,将修改后的graph重新save。
加载 graph中的部分变量
有时候在一些迁移学习的任务中我们可能只需要restore部分权重到新的网络作为初始化参数进行fine-tune,tensorflow支持restore sub-graph操作。
saver的构造函数支持指定一个变量列表,在restore的时候仅仅restore列表中的变量。通过 tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=name_scope_of_variable) 得到特定name_scope下的variable_list,将该list放入train.saver的初始化方法构建saver对象
例如:
有一个预训练的模型InceptionV1,如果想只训练fc层的话可以:
- 训练一个全新的InceptionV1模型,最后fc的scope设置为”last”
- restore 训练好的InceptionV1:saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=’InceptionV1’))
saver1.restore(session, ‘inception_model_from_google.ckpt’)- tf.train.Optimizer(0.0001).minimize(
loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=’Inceptionretrained’))
如果需要将ckpt文件中的变量映射到新graph中新的变量名:
import tensorflow as tftf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])v3_x = tf.get_variable('v3_x',shape=[4],initializer = tf.zeros_initializer)inc_v1 = v1.assign(v1+1)dec_v2 = v2.assign(v2-1)dec_v3_x = v3_x.assign(v3_x-2)# Restore v3 in ckpt to v3_x in new graph.saver = tf.train.Saver({'v1':v1,'v2':v2,'v3':v3_x})# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, "./saved_model/model.ckpt") print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval()) print("v3 : %s" % v3_x.eval())
输出
Model restored.
v1 : [ 1. 1. 1.]
v2 : [-1. -1. -1. -1. -1.]
v3 : [-2. -2. -2. -2.]
在构建网络时合理利用namescope/variablescope会为模型的复用打下一个良好的基础。
将ckpt中的变量进行重新命名
有时候我们对graph进行了修改,改变了部分变量的名字,这样会导致以前保存的ckpt文件无法使用,这时我们可以将ckpt文件中的变量名进行修改,这样就可以继续使用以前的模型了。
import tensorflow as tfcheckpoint_dir='./saved_model/model.ckpt'checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)with tf.Session() as sess: for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir): # Load the variable var = tf.contrib.framework.load_variable(checkpoint_dir, var_name) print var_name print var new_name=var_name+'/test' # Rename the variable and place the new variable into default graph var = tf.Variable(var, name=new_name) # Save the variables in default graph saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) saver.save(sess, checkpoint.model_checkpoint_path)
参考资料:
1. https://www.tensorflow.org/programmers_guide/saved_model
- tensorflow variable的保存和加载
- tensorflow中变量的保存和加载
- tensorflow保存 和 加载模型
- Tensorflow 保存和加载模型
- tensorflow保存和加载模型
- tensorflow中Variable和Placeholder的学习
- tensorflow 保存和加载模型 -2
- tensorflow模型参数保存和加载问题
- TensorFlow保存和加载训练模型
- tensorflow-模型保存和加载(二)
- TensorFlow保存和加载训练模型
- 使用tensorflow保存、加载和使用模型
- TensorFlow保存和加载训练模型
- tensorflow的简单使用、保存、加载
- tensorflow的基本用法(十)——保存神经网络参数和加载神经网络参数
- TensorFlow模型参数的保存和加载(含演示代码)
- TensorFlow模型op的保存和加载(含演示代码)
- Tensorflow中关于Tensor和Variable的理解
- 获取元素的偏移量offset
- Linux入门4
- 爬虫技术(05)神箭手爬虫回调函数
- sqlServer存储过程查询语句
- Java实现二叉树,以及先序、中序、后序遍历算法的实现
- tensorflow variable的保存和加载
- mysql性能优化
- Web项目中 .classpath、.mymetadata、.project文件的作用
- Banner 图片无限轮播
- 《Machine Learning》第二讲 线性回归与梯度下降
- 遍历文件夹
- Java对List的常用操作
- Linux下安装及简单使用nmap
- Redis 讲解系列之 与Spring集成(二)