TensorFlow从0开始(6)——保存与加载模型

来源:互联网 发布:知阅小说网_原创小说网 编辑:程序博客网 时间:2024/06/05 15:28

目的

学习TensorFlow的目的是能够训练模型,并利用已经训练好的模型对新数据进行预测。下文给出一个简单的保存模型、加载模型的过程。

保存模型

# Save step: build a tiny graph (output = input + 100), write its GraphDef to
# /tmp/load/test.pb, run it once so the variable `saved_result` holds the
# output, then checkpoint that variable to "checkpoint.data".
#
# NOTE(review): this uses pre-1.0 TensorFlow names (tf.initialize_all_variables,
# tf.train.SummaryWriter, tf.all_variables); confirm the TF version targeted.
import os
import shutil

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('summaries_dir', '/tmp/save_graph_logs', 'Summaries directory')

data = np.arange(10, dtype=np.int32)

with tf.Session() as sess:
    print("# build graph and run")
    input1 = tf.placeholder(tf.int32, [10], name="input")
    output1 = tf.add(input1, tf.constant(100, dtype=tf.int32), name="output")
    # The value stored in this variable depends on the data fed to `input`.
    saved_result = tf.Variable(data, name="saved_result")
    do_save = tf.assign(saved_result, output1)
    # Actually run the initializer (the original only built the op).
    sess.run(tf.initialize_all_variables())

    # Clear old event logs from the SAME directory the writer uses, so that
    # overriding --summaries_dir never deletes an unrelated path.
    shutil.rmtree(FLAGS.summaries_dir, ignore_errors=True)
    train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph)
    train_writer.close()  # flush the graph event to disk

    shutil.rmtree("/tmp/load", ignore_errors=True)
    # Binary protobuf of the graph. It is written before any Saver exists, so
    # it holds no save/restore ops; the loader must build its own Saver.
    tf.train.write_graph(sess.graph_def, "/tmp/load", "test.pb", as_text=False)

    # Compute output1 and assign it to `saved_result` in a single run.
    result, _ = sess.run([output1, do_save], {input1: data})
    saver = tf.train.Saver(tf.all_variables())
    saver.save(sess, "checkpoint.data")

模型图示(保存脚本会把计算图写入 /tmp/save_graph_logs,可用 `tensorboard --logdir=/tmp/save_graph_logs` 打开查看)


加载模型

# Load step: rebuild the graph from /tmp/load/test.pb in a fresh tf.Graph,
# expose the imported `saved_result` node to a Saver, restore its value from
# "checkpoint.data", and print it.
#
# A fresh graph is used because calling `graph.as_default()` without `with`
# is a no-op. Importing into the process-wide default graph would collide
# with the nodes the save step created there.
with tf.Graph().as_default() as persisted_graph:
    with tf.Session() as persisted_sess:
        print("load graph")
        with gfile.FastGFile("/tmp/load/test.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

        print("map variables")
        persisted_result = persisted_graph.get_tensor_by_name("saved_result:0")
        tf.add_to_collection(tf.GraphKeys.VARIABLES, persisted_result)
        # test.pb carries no save/restore ops, so build a Saver in this graph.
        # Errors are allowed to propagate: swallowing them left `saver`
        # undefined or bound to the save step's Saver from another graph.
        saver = tf.train.Saver(tf.all_variables())

        print("load data")
        saver.restore(persisted_sess, "checkpoint.data")
        print(persisted_result.eval())
        print("DONE")

显示结果

加载后打印的 saved_result 应为输入数据 0~9 各加 100:

[100 101 102 103 104 105 106 107 108 109]


0 0
原创粉丝点击