TensorFlow Saver (saving weights)

Source: Internet | Editor: 程序博客网 (Programmer Blog Network) | Date: 2024/05/22 08:06

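The Saver is used to save and restore variables. Note that before saving, you must create the target folder yourself. The saving code does not create the folder automatically, and if the folder is missing, an error is raised.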
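A minimal sketch of creating the folder before saving, using only the Python standard library (os.makedirs). ensure_checkpoint_dir is a hypothetical helper name, not part of TensorFlow:

```python
import os


def ensure_checkpoint_dir(ckpt_path):
    """Create the parent folder of a checkpoint path such as "my_net/save_net.ckpt".

    Hypothetical helper (not a TensorFlow API). Returns the folder that now exists.
    """
    parent = os.path.dirname(ckpt_path) or "."  # no folder part means the current directory
    os.makedirs(parent, exist_ok=True)  # creates nested folders; no error if it already exists
    return parent
```

Calling ensure_checkpoint_dir("my_net/save_net.ckpt") right before saver.save(...) creates my_net if it is missing and does nothing if it already exists.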

The core saving code creates a Saver and then saves the session:

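# excerpt: W, b and init are defined in the complete save script below
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)  # give the variables their initial values
    save_path = saver.save(sess, "my_net/save_net.ckpt")  # fails if my_net/ does not exist
    print("Save to path: ", save_path)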

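The core restoring code is shown next. Variables are matched by name: the name of each variable determines which saved value it receives.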

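import numpy as np
import tensorflow as tf

# same names, shapes and dtypes as the saved variables; the values are placeholders
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# no init step needed: restore() assigns the saved values
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))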
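The name matching can be pictured with a toy model in NumPy. This is an analogy, not TensorFlow's checkpoint format: save_by_name and restore_by_name are hypothetical helpers that store values in a dict keyed by name and copy them back by name, so the order in which variables are declared does not matter, while shape and dtype must agree:

```python
import numpy as np


def save_by_name(variables):
    """Toy checkpoint: a dict mapping each variable name to a copy of its value."""
    return {name: np.array(value, copy=True) for name, value in variables.items()}


def restore_by_name(checkpoint, variables):
    """Copy saved values into `variables` (a name -> ndarray dict), matching by name.

    Each target must have the same shape and dtype as the saved value,
    loosely mirroring the rule that restored variables keep the saved shape and dtype.
    """
    for name, target in variables.items():
        saved = checkpoint[name]  # KeyError if the name was never saved
        if saved.shape != target.shape or saved.dtype != target.dtype:
            raise ValueError(f"{name}: saved {saved.shape}/{saved.dtype}, "
                             f"declared {target.shape}/{target.dtype}")
        target[...] = saved  # overwrite in place; position in the dict is irrelevant


ckpt = save_by_name({
    "weights": np.array([[1, 2, 3], [3, 4, 5]], dtype=np.float32),
    "biases": np.array([[1, 2, 3]], dtype=np.float32),
})

# Declared in the opposite order with placeholder values, like the restore script
fresh = {
    "biases": np.arange(3, dtype=np.float32).reshape((1, 3)),
    "weights": np.arange(6, dtype=np.float32).reshape((2, 3)),
}
restore_by_name(ckpt, fresh)
print("weights:", fresh["weights"])
print("biases:", fresh["biases"])
```

Even though fresh lists biases first, each array receives the saved value stored under its own name, which is why the restore script only needs the same names, shapes, and dtypes, not the original initial values.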

The complete code is as follows:

Code for saving variables:

# View more python tutorials on my Youtube and Youku channel!!!
# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
# Youku video tutorial: http://i.youku.com/pythontutorial

"""Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly."""
from __future__ import print_function
import tensorflow as tf
import numpy as np

# Save to file
# remember to define the same dtype and shape when restoring
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

# tf.initialize_all_variables() is no longer valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
    init = tf.initialize_all_variables()
else:
    init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    # the folder my_net/ must already exist
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    print("Save to path: ", save_path)
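The version check in the script parses tf.__version__ by hand. A slightly more robust sketch of the same decision uses a tuple comparison, assuming version strings that start with major.minor (for example "0.11.0" or "1.15.0"); needs_legacy_initializer is a hypothetical helper, not a TensorFlow function:

```python
import re


def needs_legacy_initializer(version):
    """Return True for TensorFlow versions older than 0.12 (hypothetical helper).

    Per the comment in the save script, those versions use tf.initialize_all_variables(),
    while 0.12 and later use tf.global_variables_initializer().
    """
    # Keep only the leading digits of the first two components, so suffixes like "0rc1" are ignored.
    major, minor = (int(re.match(r"\d+", part).group()) for part in version.split(".")[:2])
    return (major, minor) < (0, 12)
```

In the save script this could replace the if/else with init = tf.initialize_all_variables() if needs_legacy_initializer(tf.__version__) else tf.global_variables_initializer().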

Code for restoring variables:

# View more python tutorials on my Youtube and Youku channel!!!
# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg
# Youku video tutorial: http://i.youku.com/pythontutorial

"""Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly."""
from __future__ import print_function
import tensorflow as tf
import numpy as np

################################################
# restore variables
# redefine the same shape and same type for your variables
# run this as a separate script: in the same graph as the save code, these variables
# would be renamed (e.g. weights_1) and would no longer match the names in the checkpoint
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# no init step needed: restore() assigns the saved values
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))