Tensorflow基础:模型持久化

来源:互联网 发布:java项目目录结构 编辑:程序博客网 时间:2024/06/15 06:41

为了让训练结果可以复用,需要将训练得到的神经网络模型持久化。本文将介绍通过Tensorflow程序来持久化一个训练好的模型,并从持久化之后的模型文件中还原被保存的模型。

持久化代码实现

Tensorflow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver类。以下代码给出了保存Tensorflow计算图的方法。

保存模型

import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1], name="v1"))v2 = tf.Variable(tf.constant(2.0, shape=[1], name="v2"))result = v1 + v2init_op = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:    sess.run(init_op)    saver.save(sess, ".\model\model.ckpt")

这段代码中,通过saver.save函数将Tensorflow模型保存到了model.ckpt文件中。实际上,保存了三个文件:

model

  1. model.ckpt.meta:它保存了Tensorflow计算图的结构,这里可以简单理解为神经网络的网络结构
  2. model.ckpt:它保存了Tensorflow程序中每一个变量的取值
  3. checkpoint:它保存了一个目录下所有的模型文件列表

加载模型

以下代码给出了加载这个已经保存的Tensorflow模型的方法:

import tensorflow as tfv1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")result = v1 + v2#init_op = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess:    #sess.run(init_op)    #saver.save(sess, ".\model\model.ckpt")    saver.restore(sess, ".\model\model.ckpt")    print(sess.run(result))

这段加载模型的代码基本上和保存模型的代码是一样的。唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。

如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。以下代码给出了一个样例:

import tensorflow as tfsaver = tf.train.import_meta_graph(".\model\model.ckpt.meta")with tf.Session() as sess:    saver.restore(sess, ".\model\model.ckpt")    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

在上面给出的程序中,默认保存和加载了Tensorflow计算图上定义的全部变量。但是有时只需要保存或者加载部分变量。可以在声明tf.train.Saver类是,提供一个列表来指定需要保存或者加载的变量。例:saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

原创粉丝点击