tensorflow模型相关

来源:互联网 发布:windows7摄像头软件 编辑:程序博客网 时间:2024/05/19 18:15

最近学习tensorflow,对其中的模型的导入导出一直比较困惑,因此花了些力气研究了一下,最后归纳整理到本篇博客。

变量保存和恢复

       在tensorflow中,变量用来存储和更新参数。变量创建时可以赋予name和初始值,并且在执行模型的其他操作之前必须对变量进行初始化。比较简单的一个方法是添加一个对初始化所有变量的操作,在使用模型前先执行这个操作。比如:

#添加变量初始化操作init_op = tf.initialize_all_variables()#执行模型前先执行初始化操作with tf.Session() as sess:    # Run the init operation.    sess.run(init_op)    sess.run(...)

CheckPoint File

    Checkpoint文件是用来保存Graph中定义的变量的二进制文件,包含了从变量名和变量值的映射关系。

保存变量

       在tensorflow中保存和回复变量的方法是使用tf.train.Saver对象,利用Saver构造器可以给graph的变量添加save和restore的ops,将变量保存或从磁盘读取。

下面是保存变量的一个例子:

# 创建变量v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# 添加一个变量初始化的opinit_op = tf.initialize_all_variables()# 添加一个ops 保存并恢复全部变量saver = tf.train.Saver()# 创建一个Session执行Graphwith tf.Session() as sess:    sess.run(init_op)    # Do some work with the model.    ..    # 保存变量    save_path = saver.save(sess, "/tmp/model.ckpt")    print "Model saved in file: ", save_path

恢复变量

  同样用Saver可以恢复变量,恢复时不需要进行初始化,但必须提前声明与恢复数据匹配的变量来接收数据

# 创建变量.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")saver = tf.train.Saver()with tf.Session() as sess:    # Restore variables from disk.    saver.restore(sess, "/tmp/model.ckpt")     print "Model restored."

0 0
原创粉丝点击