存储与读取变量

来源:互联网 发布:帝国仿内涵吧网源码 编辑:程序博客网 时间:2024/05/15 15:14

1.存储模型并且指定存储的变量

# coding=utf-8
"""Save a model, choosing which variables go into the checkpoint.

Builds a small TensorFlow 1.x graph with two variables, ``v0`` and ``v1``,
and writes only those two to ``./checkpoint/model0.ckpt``, stored under the
checkpoint keys ``'x'`` and ``'y'`` respectively.
"""
import os

import tensorflow as tf

# The placeholder and the multiply op are not needed for saving; they only
# give the saved graph (model0.ckpt.meta) some structure besides variables.
x = tf.placeholder(shape=[1], dtype=tf.float32, name='xx')
variable_0 = tf.get_variable('v0', [1], tf.float32,
                             initializer=tf.random_normal_initializer(mean=1))
variable_1 = tf.get_variable('v1', [1], tf.float32,
                             initializer=tf.random_normal_initializer(mean=1))
output = tf.multiply(x, variable_1, name='mul')

# Save only variable_0 and variable_1, under the checkpoint keys 'x' and 'y'.
# (The key 'x' is unrelated to the placeholder named 'xx'.)
# Alternatively, tf.train.Saver() with no argument saves all variables
# under their own names.
saver = tf.train.Saver({'x': variable_0, 'y': variable_1})

initial_op = tf.global_variables_initializer()

save_path = './checkpoint/model0.ckpt'
# Saver.save does not create missing directories; TF 1.x fails with a
# ValueError when the parent directory does not exist, so create it first.
save_dir = os.path.dirname(save_path)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

with tf.Session() as sess:
    sess.run(initial_op)
    saver.save(sess, save_path)
2.读取模型与变量
# coding=utf-8
"""Restore checkpointed values into differently named variables.

The save script stored ``v0`` under the checkpoint key ``'x'`` and ``v1``
under ``'y'``. The dict passed to Saver below maps those same keys onto
``variable_2`` (``v2``) and ``variable_3`` (``v3``). After ``restore``,
``variable_2`` therefore holds the value that was saved for ``v0``.
"""
import tensorflow as tf

x = tf.placeholder(shape=[1], dtype=tf.float32, name='xx')
variable_2 = tf.get_variable('v2', [1], tf.float32,
                             initializer=tf.random_normal_initializer(mean=1))
variable_3 = tf.get_variable('v3', [1], tf.float32,
                             initializer=tf.random_normal_initializer(mean=1))
output = tf.multiply(x, variable_2, name='mul')

# Choose which variable each checkpoint key is loaded into.
# Alternatively, tf.train.Saver() with no argument would look the values up
# by the variables' own names ('v2', 'v3'), which this checkpoint lacks.
saver = tf.train.Saver({'x': variable_2, 'y': variable_3})

initial_op = tf.global_variables_initializer()

with tf.Session() as sess:
    # Not strictly required here: restore() itself assigns every variable in
    # the Saver's var_list, and v2/v3 are the only variables in this graph.
    sess.run(initial_op)
    saver.restore(sess, './checkpoint/model0.ckpt')
    print(sess.run(variable_2))
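结果:[ 1.50878465]

(variable_2 得到的是检查点中 v0 的值:保存时用键 'x' 对应 variable_0,读取时又用同一个键 'x' 对应 variable_2。由于 v0 是随机初始化的,每次运行保存脚本得到的具体数值都会不同。)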

3.直接载入模型
# coding=utf-8
"""Rebuild the saved graph from its .meta file and restore its values.

No variables are declared in Python here. ``import_meta_graph`` recreates
the graph written by the save script (``model0.ckpt.meta``) and returns a
Saver for it. ``restore`` then loads the saved variable values.
"""
import tensorflow as tf

# Load the graph structure.
saver = tf.train.import_meta_graph('./checkpoint/model0.ckpt.meta')
# Alternatively, tf.train.Saver() could be built now that the imported
# variables exist in the default graph.

initial_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initial_op)
    # Load the saved values (overwrites the freshly initialized ones).
    saver.restore(sess, './checkpoint/model0.ckpt')
    print(tf.trainable_variables())
    # 'v0:0' is output 0 of the op named 'v0', i.e. the variable the save
    # script created with tf.get_variable('v0', ...).
    a = tf.get_default_graph().get_tensor_by_name('v0:0')
    print(sess.run(a))
结果(省略了 print(tf.trainable_variables()) 的输出;下面打印的是检查点中 v0 的值,因此与第 2 步相同):
[ 1.50878465]


 
原创粉丝点击