TensorFlow学习--模型的存储与恢复
来源:互联网 发布:2016淘宝客怎么做 编辑:程序博客网 时间:2024/04/30 12:57
TensorFlow模型的存储与恢复
最简单的保存和恢复模型的方法是使用tf.train.Saver对象.
模型的存储
用tf.train.Saver创建一个Saver来存储模型中的所有变量.
#!/usr/bin/python# coding:utf-8import tensorflow as tf# 定义两个常量Variablev1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")# 变量初始化init_op = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess: sess.run(init_op) save_path = saver.save(sess, "model/model.ckpt") print "Model saved in file:", save_path
输出:
Model saved in file: model/model.ckpt
可以在model目录下看到:
变量存储在二进制文件里,主要包含从变量到tensor值的映射关系.
模型的恢复
用同一个Saver对象来恢复变量.
当从文件中恢复变量时,不需要事先对变量进行初始化.
#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")# 当从文件中恢复变量时,不需要事先初始化# init_op = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess: # sess.run(init_op) saver.restore(sess, "model/model.ckpt") print "Model:" print v1.eval() print v2.eval()
输出:
Model:[ 1. 1. 1.][ 2. 2. 2. 2. 2.]
指定变量存储与恢复
如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量.
通过给tf.train.Saver()传入python字典或列表,来保持变量及其对应的名称:键对应使用的名称,值对应被管理的变量.
传入字典
存储
#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()# 如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量saver = tf.train.Saver({"variable_v1":v1})with tf.Session() as sess: sess.run(init_op) save_path = saver.save(sess, "model/model_v1.ckpt") print "Model saved in file:", save_path
输出:
Model saved in file: model/model_v1.ckpt
可以在model目录下看到:
恢复
#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")saver = tf.train.Saver({"variable_v1":v1})with tf.Session() as sess: # sess.run(init_op) saver.restore(sess, "model/model_v1.ckpt") print "Model v1:" print v1.eval() # 或使用sess.run(v1) # print sess.run(v1)
输出:
Model v1:[ 1. 1. 1.]
传入列表
存储
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()saver = tf.train.Saver([v1, v2])with tf.Session() as sess: sess.run(init_op) saver.save(sess, "model/model_v1v2.ckpt")
恢复
import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")saver = tf.train.Saver([v1])with tf.Session() as sess: saver.restore(sess, "model/model_v1v2.ckpt") print sess.run(v1)
输出:
[ 1. 1. 1.]
创建多个saver对象
需要保存和恢复变量的不同子集时可以创建任意多个saver对象.
import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()saver1 = tf.train.Saver({"variable_v1":v1})saver2 = tf.train.Saver({"variable_v2":v2})with tf.Session() as sess: sess.run(init_op) saver1.save(sess, "model/model_v1.ckpt") saver2.save(sess, "model/model_v2.ckpt")
可以在model目录下看到:
同一个变量也可被列入多个saver对象中,只有saver的restore()函数被运行时它的值才会被改变.
阅读全文
1 0
- TensorFlow学习--模型的存储与恢复
- tensorflow 1.0 学习:模型的保存与恢复(Saver)
- 【深度学习】Tensorflow模型保存与恢复
- tensorflow 模型的保存与恢复(Saver)
- Tensorflow学习(6)模型的保存与恢复(saver)
- 16、TensorFLow 模型参数的保存与恢复
- 5.1 Tensorflow:图与模型的加载与存储
- Tensorflow模型持久化与恢复
- TensorFlow下网络模型的存储与加载
- tensorflow 变量简单存储与恢复
- tensorflow中模型的保存和恢复
- Tensorflow 部分恢复模型
- Tensorflow:模型保存/模型恢复?
- tensorflow之inception_v3模型的部分加载及权重的部分恢复(23)---《深度学习》
- Tensorflow学习笔记:模型训练数据的保存和恢复的简单实例
- TensorFlow深度学习笔记 文本与序列的深度模型
- tensorflow 学习笔记10 网络模型的保存与提取
- 5.2 TensorFlow:模型的加载,存储,实例
- CCPC.2017B.K-th Number
- 各种总结
- android retrofit post
- 看看这五种博客推广方法的具体说明
- Linux中重要目录及重要命令
- TensorFlow学习--模型的存储与恢复
- MySQL数据库的一些常用命令
- UI基础编程
- Netty 快速入门
- <转载>java(25):Spring框架简介,总结的很好!
- 动态链编与静态链编
- jQuery之内部插入易混淆方法解析
- Codeforces Round #439 (Div. 2) 869 C. The Intriguing Obsession
- MVP+recyclerview网络请求列表数据