Tensorflow: 保存和复原模型(save and restore)
来源:互联网 发布:淘宝客网站要备案吗 编辑:程序博客网 时间:2024/05/18 00:40
目前我主要看到了两种方法来保存和复原tensorflow model,先总结一下:
MetaGraph
这种就是我们经常看到的 tf.train.Saver
对应的东西。使用这种方法保存模型,会产生两种文件。
- meta: 里面存储的是整个graph的定义
- checkpoint: 这里保存的是
variable
的状态。
这里通过如下的方式保存一个模型
checkpoint_dir = "mysaver"# first creat a simple graphgraph = tf.Graph()#define a simple graphwith graph.as_default(): x = tf.placeholder(tf.float32,shape=[],name='input') y = tf.Variable(initial_value=0,dtype=tf.float32,name="y_variable") update_y = y.assign(x) saver = tf.train.Saver(max_to_keep=3) init_op = tf.global_variables_initializer()# train the model and save the model every 4000 iterations.sess = tf.Session(graph=graph)sess.run(init_op)for i in range(1,10000): y_result = sess.run(update_y,feed_dict={x:i}) if i %4000 == 0: saver.save(sess,checkpoint_dir,global_step=i)
这些是产生的文件
checkpointmysaver-4000.data-00000-of-00001mysaver-4000.indexmysaver-4000.metamysaver-8000.data-00000-of-00001mysaver-8000.indexmysaver-8000.meta
稍后我们可以复原model
tf.reset_default_graph()restore_graph = tf.Graph()with tf.Session(graph=restore_graph) as restore_sess: restore_saver = tf.train.import_meta_graph('mysaver-8000.meta') restore_saver.restore(restore_sess,tf.train.latest_checkpoint('./')) print(restore_sess.run("y_variable:0"))
上面这段python代码的输出如下:
INFO:tensorflow:Restoring parameters from ./mysaver-80008000.0
因为最新的checkpoint文件是在 8000th iterations保存的,所以当model复原后 y_variable的值是 80000
SavedModel
还有一种保存模型的方法就是 SavedModel
。
这种方法我是在看tensorflow servicing的时候看到的,个人的感觉,这是一种更适合部署的方法。暂时没有去研究tensorflow servicing。但是我看很多代码都使用到了通过这种方式保存的文件。比如imagenet example。所以这里着重介绍怎么使用从别的地方拿到的SavedModel文件。
建立 SavedModel
主要分为三部
* 建立一个 tf.saved_model.builder.SavedModelBuilder
.
* 使用刚刚建立的 builder把当前的graph和variable添加进去:SavedModelBuilder.add_meta_graph_and_variables(...)
* 可以使用 SavedModelBuilder.add_meta_graph
添加多个meta graph
复原 SavedModel
这个需要通过这个 model 来完成的:tf.saved_model.loader
通过命令来查看和执行SavedModel
上面的通过编程的方式来建立和复原SavedModel
, 我现在基本上不需要发布模型给别人用,但是经常想使用一下别人已经训练好的模型。当拿到别人的模型的时候,需要知道怎么使用。官方提供了一个工具:saved_model_cli
,这个工具包含了 show 和 run 两类命令
感兴趣的同学可以查看官方文档 或者这篇博客对应的 jupyter notebook
可视化 SavedModel
我们知道google提供 TensorBoard给我们可视化的调试tensorflow, tensorboard一个最基本的功能就是把graph展示出来。但是有时候我们拿到别人 SavedModel
, 我们需要把这个model跑一遍,产生summary文件才能在tensorboard里面看。google deepdream 参考代码里面提供了一个很方便的代码可以让我们快速的把graph展示出来。代码如下, 这个代码是我也放到我的github了,大家也可以直接去看google deepdram 参考代码
# these function is copied from google deepdream example codeimport numpy as npfrom IPython.display import clear_output, Image, display, HTMLdef strip_consts(graph_def, max_const_size=32): """Strip large constant values from graph_def.""" strip_def = tf.GraphDef() for n0 in graph_def.node: n = strip_def.node.add() n.MergeFrom(n0) if n.op == 'Const': tensor = n.attr['value'].tensor size = len(tensor.tensor_content) if size > max_const_size: tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>"%size) return strip_defdef rename_nodes(graph_def, rename_func): res_def = tf.GraphDef() for n0 in graph_def.node: n = res_def.node.add() n.MergeFrom(n0) n.name = rename_func(n.name) for i, s in enumerate(n.input): n.input[i] = rename_func(s) if s[0]!='^' else '^'+rename_func(s[1:]) return res_defdef show_graph(graph_def, max_const_size=32): """Visualize TensorFlow graph.""" if hasattr(graph_def, 'as_graph_def'): graph_def = graph_def.as_graph_def() strip_def = strip_consts(graph_def, max_const_size=max_const_size) code = """ <script> function load() {{ document.getElementById("{id}").pbtxt = {data}; }} </script> <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()> <div style="height:600px"> <tf-graph-basic id="{id}"></tf-graph-basic> </div> """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand())) iframe = """ <iframe seamless style="width:800px;height:620px;border:0" srcdoc="{}"></iframe> """.format(code.replace('"', '"')) display(HTML(iframe))
- Tensorflow: 保存和复原模型(save and restore)
- Tensorflow: 保存和复原模型(save and restore)
- tensorflow模型save和restore
- tensorflow训练模型保存saver和恢复restore
- tensorflow学习(4):保存模型Saver.save()的参数命名机制以及restore并创建手写字体识别引擎
- A quick complete tutorial to save and restore Tensorflow models
- A quick complete tutorial to save and restore Tensorflow models
- canvas save and restore
- tensorflow之Graph save and restore in python and c++(C++ 中使用tensorflow)
- kvm save和restore
- Canvas:save()和restore()
- vim save and restore session
- tensorflow保存 和 加载模型
- Tensorflow保存和读取模型
- Tensorflow 保存和加载模型
- 保存和读取 TensorFlow 模型
- tensorflow保存和加载模型
- 保存自定义对象的数组 save and restore an array of custom objects
- Centos下数据写入MySQL数据库汉字是????
- MySql教程Link
- 第五章——视图控制器
- fasttext初步使用
- JavaScript的一些基本方法总结
- Tensorflow: 保存和复原模型(save and restore)
- AJAX 加载数据
- linux 命令执行java文件
- linux 磁盘UUID
- 谈谈hibernate的延迟加载和openSessionInView
- osg开发预览
- 深入理解Java虚拟机笔记1: OOM实战
- Poj --1751 highways (最小生成树,kruskal算法)
- CentOS 6下快速搭建ftp服务器