Tensorflow saver(save weight)
来源:互联网 发布:手机音乐闪光灯软件 编辑:程序博客网 时间:2024/05/22 08:06
saver用于变量的读取操作,需要注意的是,在变量存储时,需要事先建立好一个文件夹。存储代码并不会自动新建文件夹,需要你人为手动建立,不然会报错。
主要保存代码为,建立一个saver,保存sess
saver = tf.train.Saver()with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "my_net/save_net.ckpt") print("Save to path: ", save_path)
主要存储代码为:用变量名字识别应该哪个变量接受哪个值
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")# not need init stepsaver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess, "my_net/save_net.ckpt") print("weights:", sess.run(W)) print("biases:", sess.run(b))
完整代码如下:
变量存储代码
# View more python tutorials on my Youtube and Youku channel!!!# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg# Youku video tutorial: http://i.youku.com/pythontutorial"""Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly."""from __future__ import print_functionimport tensorflow as tfimport numpy as np# Save to file# remember to define the same dtype and shape when restoreW = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')# tf.initialize_all_variables() no long valid from# 2017-03-02 if using tensorflow >= 0.12if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1: init = tf.initialize_all_variables()else: init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess: sess.run(init) save_path = saver.save(sess, "my_net/save_net.ckpt") print("Save to path: ", save_path)
读取代码
# View more python tutorials on my Youtube and Youku channel!!!# Youtube video tutorial: https://www.youtube.com/channel/UCdyjiB5H8Pu7aDTNVXTTpcg# Youku video tutorial: http://i.youku.com/pythontutorial"""Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly."""from __future__ import print_functionimport tensorflow as tfimport numpy as np################################################# restore variables# redefine the same shape and same type for your variablesW = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")# not need init stepsaver = tf.train.Saver()with tf.Session() as sess: saver.restore(sess, "my_net/save_net.ckpt") print("weights:", sess.run(W)) print("biases:", sess.run(b))
0 0
- Tensorflow saver(save weight)
- Tensorflow---monitoring,saver
- TensorFlow学习--Saver
- tensorflow学习(4):保存模型Saver.save()的参数命名机制以及restore并创建手写字体识别引擎
- tensorflow关于tf.train.Saver()
- TensorFlow之saver的用法
- Tensorflow Save
- Tensorflow系列——Saver的用法
- tensorflow saver 保存和恢复指定 tensor
- TensorFlow saver之指定变量的存取
- Tensorflow小样例-Saver模型保存读取
- tensorflow加载saver.restore目录报错
- tensorflow加载saver.restore目录报错
- TensorFlow利用saver保存和提取参数
- tensorflow 模型的保存与恢复(Saver)
- tensorflow saver restore固定的layer
- Saver
- Tensorflow-save model
- xmlspy学习之如何写仅含文本复合元素
- 润乾报表autobig标签展现、打印、导出问题总结
- oracle 数据类型
- 高斯消元
- android 多媒体和相机详解八
- Tensorflow saver(save weight)
- AlertDialog对话框小结
- H264编码技术
- Retrofit2.0使用详解
- oracle 外部表详解
- 如何查找程序中的逻辑错误
- C++访问控制符详解
- CRC32算法实现
- poj1364King_差分约束系统