TensorFlow Study Notes: Saving Variables and Networks

Source: Internet | Editor: 程序博客网 | Date: 2024/04/29 13:24

Contents:
1. Saving the network
2. Restoring the saved variables from the path the network was saved to

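```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (under 2.x: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior())

## Save to file
# Remember to define the same dtype and shape when restoring.
# Step 1: uncomment this block and run it once to write the checkpoint.
# import os
# os.makedirs("logs", exist_ok=True)  # Saver.save() fails if the parent directory is missing
# W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
# b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
#
# init = tf.global_variables_initializer()
#
# saver = tf.train.Saver()
#
# with tf.Session() as sess:
#     sess.run(init)
#     save_path = saver.save(sess, "logs/save_net.ckpt")
#     print("Save to path: ", save_path)

## Restore variables
# Step 2: comment the save block out again and run this part in a fresh process,
# so the graph does not already contain 'weights'/'biases' (duplicates would be
# renamed 'weights_1', 'biases_1' and no longer match the checkpoint).
# Redefine the variables with the same names, shapes and dtypes; the initial
# values below are only placeholders that saver.restore() overwrites.
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name='weights')
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name='biases')

# No initializer is needed: restore() assigns every variable from the checkpoint.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "logs/save_net.ckpt")
    print("Weights: ", sess.run(W))
    print("Biases: ", sess.run(b))
```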
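The code above uses the TensorFlow 1.x graph API (`tf.Session`, `tf.train.Saver`), which TensorFlow 2.x only keeps under `tf.compat.v1`. As a minimal sketch, assuming TensorFlow 2.x with eager execution, the same save-then-restore round trip can be written with `tf.train.Checkpoint`. One difference worth noting: `Checkpoint` matches values by the keyword names passed to it (`weights=` and `biases=` here), not by the variables' `name=` argument, but the restored variables must still have the same shape and dtype.

```python
import os

import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

os.makedirs("logs", exist_ok=True)

# --- Save ---
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name="weights")
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name="biases")

# The keyword names ("weights", "biases") become the lookup keys in the checkpoint.
ckpt = tf.train.Checkpoint(weights=W, biases=b)
save_path = ckpt.save("logs/save_net")  # returns a prefix such as "logs/save_net-1"
print("Save to path:", save_path)

# List what was written to the checkpoint as (key, shape) pairs.
for key, shape in tf.train.list_variables(save_path):
    print(key, shape)

# --- Restore ---
# Fresh variables with the same shape and dtype; the zeros are placeholders.
W_restored = tf.Variable(tf.zeros((2, 3), dtype=tf.float32))
b_restored = tf.Variable(tf.zeros((1, 3), dtype=tf.float32))

# Reuse the same keyword names so each saved value finds its variable.
# In eager mode, values are assigned immediately to variables that already exist.
tf.train.Checkpoint(weights=W_restored, biases=b_restored).restore(save_path)
print("Weights:", W_restored.numpy())
print("Biases:", b_restored.numpy())
```

Unlike the 1.x version, saving and restoring can run in the same process here, because nothing depends on global graph names.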