Tensorflow Save

来源:互联网 发布:买隐形眼镜必知常识 编辑:程序博客网 时间:2024/06/06 09:54

保存为四个文件:

my-model.ckpt.meta          保存整个计算图的结构

my-model.ckpt.data-*        保存模型中每个变量的取值

my-model.ckpt.index

checkpoint                          记录目录下所有模型文件列表


.ckpt模型     图结构.meta与变量值.ckpt分离

from __future__ import print_functionimport tensorflow as tfimport numpy as np'''*********************自定义图运算******************''''''*********************自定义图运算******************''''''*********************自定义图运算******************''''''#**********************************************在一张图、会话中存入 再载入 变量************************************tf.reset_default_graph()W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:    sess.run(init)    print("weights:", sess.run(W))    print("biases:", sess.run(b))    print(W.name,b.name)    save_path = saver.save(sess, "my_net/save_net.ckpt")    print("Save to path: ", save_path)    saver.restore(sess, "my_net/save_net.ckpt")    print("weights:", sess.run(W))    print("biases:", sess.run(b))    print(W.name,b.name)tf.reset_default_graph()  #不然以下的w名称为 weight_1     ''''''#***********************************************在不同图、会话中存入 载入变量**************************************#----------------------------------save------------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')print(W.name)init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:    sess.run(init)    save_path = saver.save(sess, "my_net/save_net.ckpt")    print("Save to path: ", save_path)#---------------------------------reload-----------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!#1 WW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights2")W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")#print(W.name)# 自己定义图上运算 参数无需初始化 而将值直接按名称加载进来saver = tf.train.Saver()#2 saver = tf.train.Saver([W])              # **************************[对应名称的变量的张量] 列表形式只获取部分变量#1 saver = tf.train.Saver({'weights':WW})   #***********************{‘原名’: }形式 重命名变量  将原名weight的值放入WW中     %名字无需加上 :0 部分  --计算节点with tf.Session() as sess:       # 提取变量    saver.restore(sess, "my_net/save_net.ckpt")    #print("weights:", sess.run(W))    print("biases:", sess.run(W))    '''  '''*********************非自定义图运算******************''''''*********************非自定义图运算******************''''''*********************非自定义图运算******************'''#**********************************************直接加载图而无需重复定义图上的运算***********************************tf.reset_default_graph() init = tf.global_variables_initializer()saver = tf.train.import_meta_graph("C:/Users/Administrator/Desktop/my_net/save_net.ckpt.meta")  #加载图 同时 下面导入变量值with tf.Session() as sess:    sess.run(init)    # 提取变量    saver.restore(sess, "my_net/save_net.ckpt")    #通过张量的名称来获取张量 )#Tensor names must be of the form "<op_name>:<output_index>".    print(sess.run(tf.get_default_graph().get_tensor_by_name('weights:0')))                              #%名字需加上 :0 部分   因为是获取张量


.pb模型   freeze的模型,该模型已经是包含图和相应的参数了

import tensorflow as tffrom tensorflow.python.framework import graph_util  '''*********************非自定义图运算******************''''''*********************非自定义图运算******************''''''*********************非自定义图运算******************'''    ''' #***********************************************在不同图、会话中存入 载入变量**************************************#----------------------------------save------------------tf.reset_default_graph() #!!!!!!!!!!!!!!!!!W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')             #导出张量名为weights:0  计算节点名weights    b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')print(W.name)init = tf.global_variables_initializer()with tf.Session() as sess:    sess.run(init)    graph_def = tf.get_default_graph().as_graph_def()    #得到当前的图的 GraphDef 部分,==输入层到输出层的计算过程    output_graph_def = graph_util.convert_variables_to_constants(sess,   #计算图中的变量及其取值通过常量的方式保存于一个文件中                                                          graph_def, ['weights'])   ##需要保存【计算节点】的名字    %舍去无用节点 保存该节点下子图及变量值      with tf.gfile.GFile("model/w.pb", 'wb') as f:  #通过 tf.gfile.GFile 进行模型持久化        f.write(output_graph_def.SerializeToString())   # 序列化输出'''#---------------------------------reload-----------------from tensorflow.python.platform import gfile  tf.reset_default_graph() #!!!!!!!!!!!!!!!!!  with tf.Session() as sess:      model_filename = "Model/combined_model.pb"      with gfile.FastGFile(model_filename, 'rb') as f:          graph_def = tf.GraphDef()          graph_def.ParseFromString(f.read())        result = tf.import_graph_def(graph_def, return_elements=['weights:0'])   #得输出节点的值--【张量】    print(sess.run(result)) # [array([ 3.], dtype=float32)]  


参考:

http://blog.csdn.net/marsjhao/article/details/72829635  书译

http://blog.csdn.net/michael_yt/article/details/74737489

原创粉丝点击