tf.train.NewCheckpointReader实现保存变量的提取

来源:互联网 发布:阿里云速度 编辑:程序博客网 时间:2024/05/18 22:14

tf.train.NewCheckpointReader('path'):path是保存的路径,这个函数可以得到保存的所有变量

例如:

先保存一个模型,参数为v,v1.
import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  v = tf.Variable(0, dtype=tf.float32, name='v')v1 = tf.Variable(0, dtype=tf.float32, name='v1')result = v + v1x = tf.placeholder(tf.float32, shape=[1], name='x')test = result + xinit = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)saver.save(sess, "/home/penglu/Desktop/lp/model.ckpt") 
利用tf.train.NewCheckpointReader导出所有变量

import tensorflow as tf;  import numpy as np;  import matplotlib.pyplot as plt;  reader = tf.train.NewCheckpointReader("/home/penglu/Desktop/lp/model.ckpt")variables = reader.get_variable_to_shape_map()for ele in variables:print ele
输出:

v1
v

原创粉丝点击