Tensorflow Save
来源:互联网 发布:买隐形眼镜必知常识 编辑:程序博客网 时间:2024/06/06 09:54
保存为四个文件:
my-model.ckpt.meta 保存整个计算图的结构
my-model.ckpt.data-* 保存模型中每个变量的取值
my-model.ckpt.index
checkpoint 记录目录下所有模型文件列表
.ckpt模型 图结构.meta与变量值.ckpt分离
from __future__ import print_functionimport tensorflow as tfimport numpy as np'''*********************自定义图运算******************''''''*********************自定义图运算******************''''''*********************自定义图运算******************''''''#**********************************************在一张图、会话中存入 再载入 变量************************************tf.reset_default_graph()W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess: sess.run(init) print("weights:", sess.run(W)) print("biases:", sess.run(b)) print(W.name,b.name) save_path = saver.save(sess, "my_net/save_net.ckpt") print("Save to path: ", save_path) saver.restore(sess, "my_net/save_net.ckpt") print("weights:", sess.run(W)) print("biases:", sess.run(b)) print(W.name,b.name)tf.reset_default_graph() #不然以下的w名称为 weight_1 ''''''#***********************************************在不同图、会话中存入 载入变量**************************************#----------------------------------save------------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights') #导出张量名为weights:0 计算节点名weights b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')print(W.name)init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "my_net/save_net.ckpt") print("Save to path: ", save_path)#---------------------------------reload-----------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!#1 WW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights2")W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")#print(W.name)# 自己定义图上运算 参数无需初始化 而将值直接按名称加载进来saver = tf.train.Saver()#2 saver = tf.train.Saver([W]) # **************************[对应名称的变量的张量] 列表形式只获取部分变量#1 saver = tf.train.Saver({'weights':WW}) #***********************{‘原名’: }形式 重命名变量 将原名weight的值放入WW中 %名字无需加上 :0 部分 --计算节点with tf.Session() as sess: # 提取变量 saver.restore(sess, "my_net/save_net.ckpt") #print("weights:", sess.run(W)) print("biases:", sess.run(W)) ''' '''*********************非自定义图运算******************''''''*********************非自定义图运算******************''''''*********************非自定义图运算******************'''#**********************************************直接加载图而无需重复定义图上的运算***********************************tf.reset_default_graph() init = tf.global_variables_initializer()saver = tf.train.import_meta_graph("C:/Users/Administrator/Desktop/my_net/save_net.ckpt.meta") #加载图 同时 下面导入变量值with tf.Session() as sess: sess.run(init) # 提取变量 saver.restore(sess, "my_net/save_net.ckpt") #通过张量的名称来获取张量 )#Tensor names must be of the form "<op_name>:<output_index>". print(sess.run(tf.get_default_graph().get_tensor_by_name('weights:0'))) #%名字需加上 :0 部分 因为是获取张量
.pb模型 freeze的模型,该模型已经是包含图和相应的参数了
import tensorflow as tffrom tensorflow.python.framework import graph_util '''*********************非自定义图运算******************''''''*********************非自定义图运算******************''''''*********************非自定义图运算******************''' ''' #***********************************************在不同图、会话中存入 载入变量**************************************#----------------------------------save------------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights') #导出张量名为weights:0 计算节点名weights b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')print(W.name)init = tf.global_variables_initializer()with tf.Session() as sess: sess.run(init) graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,==输入层到输出层的计算过程 output_graph_def = graph_util.convert_variables_to_constants(sess, #计算图中的变量及其取值通过常量的方式保存于一个文件中 graph_def, ['weights']) ##需要保存【计算节点】的名字 %舍去无用节点 保存该节点下子图及变量值 with tf.gfile.GFile("model/w.pb", 'wb') as f: #通过 tf.gfile.GFile 进行模型持久化 f.write(output_graph_def.SerializeToString()) # 序列化输出'''#---------------------------------reload-----------------from tensorflow.python.platform import gfile tf.reset_default_graph() #!!!!!!!!!!!!!!!!! with tf.Session() as sess: model_filename = "Model/combined_model.pb" with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) result = tf.import_graph_def(graph_def, return_elements=['weights:0']) #得输出节点的值--【张量】 print(sess.run(result)) # [array([ 3.], dtype=float32)]
参考:
http://blog.csdn.net/marsjhao/article/details/72829635 书译
http://blog.csdn.net/michael_yt/article/details/74737489
阅读全文
0 0
- Tensorflow Save
- Tensorflow saver(save weight)
- Tensorflow-save model
- tensorflow 变量 Save类
- Tensorflow的save和load
- tensorflow模型save和restore
- tensorflow保存变量出现错误(提示不能save)
- tensorflow中tfrecords文件的save和read
- A quick complete tutorial to save and restore Tensorflow models
- A functional example for save and load model from Tensorflow
- Tensorflow: 保存和复原模型(save and restore)
- A quick complete tutorial to save and restore Tensorflow models
- Tensorflow: 保存和复原模型(save and restore)
- save
- Save
- save
- save
- save+
- 使用GDB在ARM上进行开发调试
- 算法系列——Perfect Squares
- GDOI7.6~7.15模拟总结
- Android常用控件(Widget)
- [bzoj3064]Tyvj 1518 CPU监控 线段树&排行榜垫底留念
- Tensorflow Save
- ContentProvider工作机制
- BZOJ 1145: [CTSC2008]图腾totem 数据结构维护,思维题
- 【in_array和array_search】PHP中的in_array和array_search【原创】
- HDU 1001
- HTML(进阶)
- 服务器内部跳转(请求转发)和请求重定向的区别
- 我的文章被推荐到CSDN首页
- java单例模式