TensorFlow学习--模型的存储与恢复

来源:互联网 发布:2016淘宝客怎么做 编辑:程序博客网 时间:2024/04/30 12:57

TensorFlow模型的存储与恢复
最简单的保存和恢复模型的方法是使用tf.train.Saver对象.

模型的存储

用tf.train.Saver创建一个Saver来存储模型中的所有变量.

#!/usr/bin/python# coding:utf-8import tensorflow as tf# 定义两个常量Variablev1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")# 变量初始化init_op = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess:    sess.run(init_op)    save_path = saver.save(sess, "model/model.ckpt")    print "Model saved in file:", save_path

输出:

Model saved in file: model/model.ckpt

可以在model目录下看到:

这里写图片描述
变量存储在二进制文件里,主要包含从变量到tensor值的映射关系.

模型的恢复

用同一个Saver对象来恢复变量.
当从文件中恢复变量时,不需要事先对变量进行初始化.

#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")# 当从文件中恢复变量时,不需要事先初始化# init_op = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess:    # sess.run(init_op)    saver.restore(sess, "model/model.ckpt")    print "Model:"    print v1.eval()    print v2.eval()

输出:

Model:[ 1.  1.  1.][ 2.  2.  2.  2.  2.]

指定变量存储与恢复

如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量.
通过给tf.train.Saver()传入python字典或列表,来保持变量及其对应的名称:键对应使用的名称,值对应被管理的变量.

传入字典

存储

#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()# 如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量saver = tf.train.Saver({"variable_v1":v1})with tf.Session() as sess:    sess.run(init_op)    save_path = saver.save(sess, "model/model_v1.ckpt")    print "Model saved in file:", save_path

输出:

Model saved in file: model/model_v1.ckpt

可以在model目录下看到:

这里写图片描述

恢复

#!/usr/bin/python# coding:utf-8import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")saver = tf.train.Saver({"variable_v1":v1})with tf.Session() as sess:    # sess.run(init_op)    saver.restore(sess, "model/model_v1.ckpt")    print "Model v1:"    print v1.eval()    # 或使用sess.run(v1)    # print sess.run(v1)

输出:

Model v1:[ 1.  1.  1.]

传入列表

存储

import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()saver = tf.train.Saver([v1, v2])with tf.Session() as sess:    sess.run(init_op)    saver.save(sess, "model/model_v1v2.ckpt")

恢复

import tensorflow as tfv1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")saver = tf.train.Saver([v1])with tf.Session() as sess:    saver.restore(sess, "model/model_v1v2.ckpt")    print sess.run(v1)

输出:

[ 1.  1.  1.]

创建多个saver对象

需要保存和恢复变量的不同子集时可以创建任意多个saver对象.

import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")init_op = tf.initialize_all_variables()saver1 = tf.train.Saver({"variable_v1":v1})saver2 = tf.train.Saver({"variable_v2":v2})with tf.Session() as sess:    sess.run(init_op)    saver1.save(sess, "model/model_v1.ckpt")    saver2.save(sess, "model/model_v2.ckpt")

可以在model目录下看到:

这里写图片描述

同一个变量也可被列入多个saver对象中,只有saver的restore()函数被运行时它的值才会被改变.

原创粉丝点击