输出TensorFlow中checkpoint内变量的几种方法
来源:互联网 发布:西安交通网络教育大学 编辑:程序博客网 时间:2024/05/22 07:48
在上一篇关于MDM模型的文章中,作者给出的是基于TensorFlow的实现。由于一些原因,需要将在TF上训练好的模型转换为Caffe,经过一番简化,现在的要需求是只要将TF保存在checkpoint中的变量值输出到txt或npy中即可。这里列了几种简单的可行的方法.
1,最简单的方法,是在有model 的情况下,直接用tf.train.saver进行restore,就像 cifar10_eval.py 中那样。然后,在sess中直接run变量的名字就可以得到变量保存的值。
在这里以cifar10_eval.py为例。首先,在Graph中穿件model。
with tf.Graph().as_default() as g: images, labels = cifar10.inputs(eval_data=eval_data) logits = cifar10.inference(images) top_k_op = tf.nn.in_top_k(logits, labels, 1)
然后,通过tf.train.ExponentialMovingAverage.variable_to_restore确定需要restore的变量,默认情况下是model中所有trainable变量的movingaverge名字。并建立saver 对象
variable_averages = tf.train.ExponentialMovingAverage( cifar10.MOVING_AVERAGE_DECAY) variables_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variables_to_restore)
variables_to_restore中是变量的movingaverage名字到变量的mapping(就是个字典)。我们可以打印尝试打印里面的变量名,
for name in variables_to_restore: print(name)输出结果为
softmax_linear/biases/ExponentialMovingAverageconv2/biases/ExponentialMovingAveragelocal4/biases/ExponentialMovingAveragelocal3/biases/ExponentialMovingAveragesoftmax_linear/weights/ExponentialMovingAverageconv1/biases/ExponentialMovingAveragelocal4/weights/ExponentialMovingAveragelocal3/weights/ExponentialMovingAverageconv2/weights/ExponentialMovingAverageconv1/weights/ExponentialMovingAverage
然后在中通过run 变量名的方式就可以得到保存在checkpoint中的值,引文sess.run方法得到的是numpy形式的数据,就可以通过np.save或np.savetxt来保存了。
with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: # Restores from checkpoint saver.restore(sess, ckpt.model_checkpoint_path) conv1_w=sess.run('conv1/weights/ExponentialMovingAverage')此时conv1_w就是conv1/weights的MovingAverage的值,并且是numpy array的形式。
这种方法不需要model,只要有checkpoint文件就行。
首先用tf.train.NewCheckpointReader读取checkpoint文件
<span style="font-size:14px;">reader = tf.train.NewCheckpointReader(file_name)</span>如果没有指定需要输出的变量,怎全部输出,如果指定了,则可以输出相应的变量
<span style="font-size:14px;">if not tensor_name: print(reader.debug_string().decode("utf-8"))else: print("tensor_name: ", tensor_name) print(reader.get_tensor(tensor_name))</span>可以根据自己的需要进行操作。
3,第三种方法也是TF官方在tool里面给的,称为freeze_graph, 在官方的这个tutorials中有介绍。
一般情况下TF在训练过程中会保存两种文件,一种是保存了变量值的checkpoint文件,另一种是保存了模型的Graph(GraphDef)等其他信息的MetaDef文件,
以.meta结尾Meta,但是其中没有保存变量的值。freeze_graph.py的主要功能就是将chenkpoint中的变量值保存到模型的GraphDef中,使得在一个文件中既
包含了模型的Graph,又有各个变量的值,便于后续操作。当然变量值的保存是可以有选择性的。
在freeze_graph.py中,首先是导入GraphDef (如果有GraphDef则可之间导入,如果没有,则可以从MetaDef中导入). 然后是从GraphDef中的所有nodes中
抽取主模型的nodes(比如各个变量,激活层等)。再用saver从checkpoint中恢复变量的值,以constant的形式保存到抽取的Grap的nodes中,并输出此GraphDef.
GraphDef 和MetaDef都是 基于Google Protocol Buffer 定义的。在GraphDef 中主要以node(NodeDef) 来保存模型。具体的下次有机会在聊。
0 0
- 输出TensorFlow中checkpoint内变量的几种方法
- 输出TensorFlow中checkpoint内变量的几种方法
- php变量输出的几种方式
- PHP调试中常用的几种输出方法
- PHP调试中常用的几种输出方法
- PHP调试中常用的几种输出方法
- Java中格式化输出的几种方法
- awk引用shell中变量的几种方法
- PHP中判断变量为空的几种方法
- PHP中判断变量为空的几种方法
- PHP中判断变量为空的几种方法
- PHP中判断变量为空的几种方法
- awk引用shell中变量的几种方法
- python 在字符串中使用变量的几种方法
- matlab中变量输出的方法
- Tensorflow: 从checkpoint文件中读取tensor
- 查看TensorFlow checkpoint文件中的变量名和对应值
- 倒序输出的几种简单方法
- POJ 2393 Yogurt factory
- linux 如何改变文件属性与权限
- 2016年美团校招笔试题
- leetcode 148. Sort List 解题报告
- SQL盲注
- 输出TensorFlow中checkpoint内变量的几种方法
- 2016年360校招笔试题
- Oracle存储过程、存储函数
- 导航界面的搭建(完整demo)
- Java垃圾收集算法
- Quartz.NET教程_Lesson 11&Lesson 12(完)
- 修桥问题
- Field requires API level 5 (current min is 1): android.util.Pair#first
- 【安卓学习之常见问题】 ScrollView与其他组件的冲突问题