variables_to_restore函数的用法

来源:互联网 发布:cad看图软件 mac 免费 编辑:程序博客网 时间:2024/05/22 01:32

variables_to_restore是为了在保持模型的时候方便使用滑动平均的参数,如果不使用这个保存,那模型就会保存所以参数,除非你提前设定,就是在保存的时候指定保存变量也是可以的,比如saver = tf.train.Saver([v])这样就可以指定保存变量v,在模型导入的时候只有这个变量会被导入。

比如:

import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')ema = tf.train.ExponentialMovingAverage(0.99)maintain_average_op = ema.apply(tf.all_variables())saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.initialize_all_variables())sess.run(tf.assign(v, 10.0))sess.run(maintain_average_op)saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')
模型导入:

import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')ema = tf.train.ExponentialMovingAverage(0.99)maintain_average_op = ema.apply(tf.all_variables())saver = tf.train.Saver()with tf.Session() as sess:# sess.run(tf.initialize_all_variables())# sess.run(tf.assign(v, 10.0))# sess.run(maintain_average_op)# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')print sess.run(ema.average(v))print sess.run(v)
输出:

0.0999999
10.0

这样不是很方便,因为我再次导入模型,变量v的值我不用,并且想要用计算后的值替代v,这样在模型被导入就方便就算

下面代码显示如何使用:

import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')ema = tf.train.ExponentialMovingAverage(0.99)maintain_average_op = ema.apply(tf.all_variables())saver = tf.train.Saver()with tf.Session() as sess:sess.run(tf.initialize_all_variables())sess.run(tf.assign(v, 10.0))sess.run(maintain_average_op)saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')print sess.run(v)print sess.run(ema.average(v))# saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')# print sess.run(v)
输出:

10.0
0.0999999


导入模型的时候tf.train.Saver函数要变化一下,变为tf.train.Saver(ema.variables_to_restore()),代码如下:

import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  v = tf.Variable(tf.constant(0.0, dtype=tf.float32), name='v')ema = tf.train.ExponentialMovingAverage(0.99)maintain_average_op = ema.apply(tf.all_variables())saver = tf.train.Saver(ema.variables_to_restore())with tf.Session() as sess:# sess.run(tf.initialize_all_variables())# sess.run(tf.assign(v, 10.0))# sess.run(maintain_average_op)# saver.save(sess, '/home/penglu/Desktop/lp/model.ckpt')# print sess.run(v)# print sess.run(ema.average(v))saver.restore(sess, '/home/penglu/Desktop/lp/model.ckpt')print sess.run(v)
输出:

0.0999999


注意:如果不变的话,那么输出就会是10!




原创粉丝点击