170615 windows 下 tensorflow1.2.0rc2 模型的保存与恢复

来源:互联网 发布:电信网络电视机顶盒 编辑:程序博客网 时间:2024/06/03 14:31

1. save and restore 英文教程,极力推荐
2. save and restore 中文教程,翻译地道
3. save and restore 简书教程,老版可用
4. save and restore youtube教程,视频操作
5. save and restore youtube代码),下载运行
6. save and restore python3.6 @ windows support - 官方说明
7. Stack Over 热心解答

首先感谢Ankit答疑和Hui的热心解答,今天总算弄懂如何保存与恢复模型了!官方的解释,连个具体的例子都没看到,中间只有个框架,然后就是各种是省略号,只想说工作真心做得不好!
这里写图片描述
推荐大家看完1,7后可运行下面两个程序:
注意: 在前述参考文章里面,都没有提到新建文件夹用于保存于加载模型,这个可能与系统有关,may be 在linux下模型可以直接在程序路径下保存与加载,但是在windows下必须要新建一个单独的文件夹才行。所以在保存与加载时要指明文件夹路径。

save.py

import tensorflow as tf#Prepare to feed input, i.e. feed_dict and placeholdersw1 = tf.placeholder("float", name="w1")w2 = tf.placeholder("float", name="w2")b1= tf.Variable(2.0,name="bias")feed_dict ={w1:4,w2:8}#Define a test operation that we will restorew3 = tf.add(w1,w2)w4 = tf.multiply(w3,b1,name="op_to_restore")sess = tf.Session()sess.run(tf.global_variables_initializer())#Create a saver object which will save all the variablessaver = tf.train.Saver()#Run the operation by feeding inputprint (sess.run(w4,feed_dict))#Prints 24 which is sum of (w1+w2)*b1 #Now, save the graphsaver.save(sess, 'op\\my_test_model',global_step=1000)

restore.py

pythonimport tensorflow as tfsess=tf.Session()    #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('op\\my_test_model-1000.meta')#saver.restore(sess,'op\\my_test_model-1000') # restore the data according to the saving namesaver.restore(sess,tf.train.latest_checkpoint('op\\')) # restore the data according to the latest check point# Now, let's access and create placeholders variables and# create feed-dict to feed new data#graph = tf.get_default_graph()w1 = graph.get_tensor_by_name("w1:0")w2 = graph.get_tensor_by_name("w2:0")feed_dict ={w1:13.0,w2:17.0}###Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")#print (sess.run(op_to_restore,feed_dict))##This will print 60 which is calculated ##using new values of w1 and w2 and saved value of b1. #Add more to the current graphadd_on_op = tf.multiply(op_to_restore,2)print (sess.run(add_on_op,feed_dict))#This will print 120.
阅读全文
0 0
原创粉丝点击