Tensorflow 保存和加载模型

来源:互联网 发布:python宝典 编辑:程序博客网 时间:2024/06/05 05:32
import tensorflow as tfdef save_model():    v1 = tf.Variable(tf.constant(1.0,shape=[1]),name='v1')    v2 = tf.Variable(tf.constant(2.0,shape=[1]),name='v2')    res = tf.add(v1,v2,name='add_res')    saver = tf.train.Saver()           with tf.Session() as sess:        sess.run(tf.global_variables_initializer())        saver.save(sess,'./save/model.ckpt')def restore_model():    saver = tf.train.import_meta_graph('./save/model.ckpt.meta')            with tf.Session() as sess:        saver.restore(sess,'./save/model.ckpt')        print (sess.run(tf.get_default_graph().get_tensor_by_name('add_res:0')))if __name__=='__main__':        save_mode() #第二次运行本程序时,需要注释这一句。    restore_model()    

原创粉丝点击