5.1 TensorFlow: Loading and Saving Graphs and Models

Source: Internet  Posted by: 现在开童装淘宝店  Editor: 程序博客网  Time: 2024/06/08 05:53


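I'm teaching myself TensorFlow, and the book I'm reading right now is 《TensorFlow技术解析与实战》 (TensorFlow: Technical Analysis and Practice). Honestly, the early chapters are a bit of a trap; I can't judge the later ones yet. The section on loading graphs and models is unclear, and the book's code doesn't even run =-=, which is a real pain in the... ahem. So I went back to reading the documentation and digging through blog posts to fill in the gaps.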



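# Usually we build a model and then run it in a session. The difference this time is that
# we save the model right after building it, then load the saved model in a session and run it.
# (TensorFlow 1.x API)
import os

import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # hide TensorFlow's INFO and WARNING logs

# Declare two variables
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
v2 = tf.Variable(tf.random_normal([2, 3]), name='v2')
init_op = tf.global_variables_initializer()  # initializes all variables

# Declare a tf.train.Saver for saving the model
# (pass write_version=tf.train.SaverDef.V1 to get the old V1 checkpoint format)
# saver = tf.train.Saver(write_version=tf.train.SaverDef.V1)
saver = tf.train.Saver()

os.makedirs('save', exist_ok=True)  # saver.save() needs the target directory to exist

# Save only the graph (no variable values)
if not os.path.exists('save/model.meta'):
    saver.export_meta_graph('save/model.meta')
print()

with tf.Session() as sess:
    sess.run(init_op)
    print('v1:', sess.run(v1))  # print v1 and v2 so we can compare them after restoring
    print('v2:', sess.run(v2))
    saver_path = saver.save(sess, 'save/model.ckpt')  # save the model to save/model.ckpt
    print('Model saved in file:', saver_path)
print()

with tf.Session() as sess:
    # Read the model persisted on disk back from the save path, so a model that was fully
    # trained (or trained up to some stage) can be used directly; no init_op is needed
    saver.restore(sess, 'save/model.ckpt')
    print('v1:', sess.run(v1))  # compare with the values printed before saving
    print('v2:', sess.run(v2))
    print('Model Restored')
print()

# Load the graph from the .meta file (written by saver.save() above), then restore the values.
# Start from an empty graph so the imported nodes keep their original names v1 and v2
# instead of clashing with the ones already built above.
tf.reset_default_graph()
saver = tf.train.import_meta_graph('save/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess, 'save/model.ckpt')
    # Fetch tensors by name; you can also run newly created tensors
    print('v1:', sess.run(tf.get_default_graph().get_tensor_by_name('v1:0')))
    print('v2:', sess.run(tf.get_default_graph().get_tensor_by_name('v2:0')))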


v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model saved in file: save/model.ckpt

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
Model Restored

v1: [[-0.78213912 -0.72646964]]
v2: [[-0.36301413 -0.99892306  0.21593148]
 [-1.09692276 -0.06931346  0.19474344]]
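Note that `saver.save(sess, 'save/model.ckpt')` does not create a single file called `model.ckpt`. With the default V2 format it writes a `.index` file and a `.data-00000-of-00001` shard holding the variable values, a `.meta` file holding the graph (which is why `import_meta_graph('save/model.ckpt.meta')` works even without the separate `export_meta_graph` call), plus a small `checkpoint` state file in the same directory. The sketch below checks this in a temporary directory. It assumes TensorFlow 2.x with the `tensorflow.compat.v1` shim (on TF 1.x, plain `import tensorflow as tf` should behave the same), and `save_dir` is just a hypothetical scratch location.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + tf.Session, like the TF 1.x code above

save_dir = tempfile.mkdtemp()  # hypothetical scratch directory
v1 = tf.Variable(tf.random_normal([1, 2]), name='v1')
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver_path = saver.save(sess, os.path.join(save_dir, 'model.ckpt'))

# One checkpoint "prefix" expands into several files on disk
files = sorted(os.listdir(save_dir))
print(files)

# The `checkpoint` state file lets you look up the newest prefix by directory
latest = tf.train.latest_checkpoint(save_dir)
print(latest == saver_path)
```

`tf.train.latest_checkpoint` is handy when you only know the directory: its result can be passed straight to `saver.restore`.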










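TensorFlow records the nodes of a computation graph, together with the metadata needed to run them, in a meta graph (MetaGraph). The meta graph is defined by the MetaGraphDef Protocol Buffer, and it is what the .meta file stores.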
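To see that a meta graph really is a `MetaGraphDef` protocol buffer, you can ask the saver for the proto directly: `Saver.export_meta_graph()` called without a filename just builds and returns it. The sketch below (same `tensorflow.compat.v1` assumption as before; the variable is illustrative) looks at three of its fields: `graph_def` (the graph's nodes), `saver_def` (the names of the save/restore ops that `import_meta_graph` uses to rebuild a `Saver`), and `collection_def` (graph collections such as the variable list).

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

v1 = tf.Variable(tf.zeros([1, 2]), name='v1')
saver = tf.train.Saver()

# With no filename, export_meta_graph only builds and returns the proto
meta_graph_def = saver.export_meta_graph()

# graph_def: every node in the graph, including the variable 'v1'
node_names = [node.name for node in meta_graph_def.graph_def.node]
# saver_def: names of the Saver's save and restore ops
restore_op = meta_graph_def.saver_def.restore_op_name
# collection_def: named collections, e.g. the list of global variables
collections = list(meta_graph_def.collection_def.keys())

print(type(meta_graph_def).__name__)
print('v1' in node_names, restore_op, collections)
```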




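# Save at iteration 1000; the step is appended to the file names:
# my_test_model-1000.index, my_test_model-1000.data-00000-of-00001, my_test_model-1000.meta
saver.save(sess, 'my_test_model', global_step=1000)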
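Since the graph usually doesn't change between saves, `Saver.save` also accepts `write_meta_graph=False`, which skips rewriting the `.meta` file. The sketch below (assuming `tensorflow.compat.v1`; the variable `w` and the temporary directory are illustrative) saves at step 1000 with the graph and at step 2000 without it, then lists what ended up on disk.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

save_dir = tempfile.mkdtemp()
prefix = os.path.join(save_dir, 'my_test_model')
w = tf.Variable(0.0, name='w')  # illustrative variable
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # First save: variable values plus the graph (.meta)
    path_1000 = saver.save(sess, prefix, global_step=1000)
    # Later save: values only; the .meta from step 1000 describes the same graph
    path_2000 = saver.save(sess, prefix, global_step=2000, write_meta_graph=False)

files = sorted(os.listdir(save_dir))
print(os.path.basename(path_1000), os.path.basename(path_2000))
print(files)
```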




# Keep at most the 4 most recent checkpoints; in addition, permanently keep one
# checkpoint per 2 hours of training (those are exempt from max_to_keep deletion)
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)
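`max_to_keep` is enforced each time `save()` runs: once the saver manages more than that many checkpoints, the oldest one's files are deleted (unless `keep_checkpoint_every_n_hours` has marked it to be kept). `Saver.last_checkpoints` lists the prefixes still managed, oldest first. Here is a minimal sketch with `max_to_keep=2` and five saves, again assuming `tensorflow.compat.v1`. The two-hour rule can't be shown in a quick script, so it is left at its default.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

save_dir = tempfile.mkdtemp()
prefix = os.path.join(save_dir, 'model')
w = tf.Variable(0.0, name='w')  # illustrative variable
saver = tf.train.Saver(max_to_keep=2)  # keep only the 2 newest checkpoints

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 6):
        saver.save(sess, prefix, global_step=step)

# Prefixes the saver still manages, oldest first
kept = [os.path.basename(p) for p in saver.last_checkpoints]
files = os.listdir(save_dir)
print(kept)
```

Steps 1 to 3 have had their `.index`, `.data` and `.meta` files removed by the time the loop ends.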



The constructor of `tf.train.Saver` in the TensorFlow source lists all the defaults:

  def __init__(self,
               var_list=None,
               reshape=False,
               sharded=False,
               max_to_keep=5,
               keep_checkpoint_every_n_hours=10000.0,
               # The default is ten thousand hours, which is amusing,
               # but we'd rather seize the day
               name=None,
               restore_sequentially=False,
               saver_def=None,
               builder=None,
               defer_build=False,
               allow_empty=False,
               write_version=saver_pb2.SaverDef.V2,
               pad_step_number=False,
               save_relative_paths=False):
    """Creates a `Saver`.

    The constructor adds ops to save and restore variables.

    `var_list` specifies the variables that will be saved and restored. It can
    be passed as a `dict` or a list:

    * A `dict` of names to variables: The keys are the names that will be
      used to save or restore the variables in the checkpoint files.
    * A list of variables: The variables will be keyed with their op name in
      the checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    The optional `reshape` argument, if `True`, allows restoring a variable from
    a save file where the variable had a different shape, but the same number
    of elements and type.  This is useful if you have reshaped a variable and
    want to reload it from an older checkpoint.

    The optional `sharded` argument, if `True`, instructs the saver to shard
    checkpoints per device.

    Args:
      var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
        names to `SaveableObject`s. If `None`, defaults to the list of all
        saveable objects.
      reshape: If `True`, allows restoring parameters from a checkpoint
        where the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep.
        Defaults to 5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A `Bool`, which if true, causes restore of different
        variables to happen sequentially within each device.  This can lower
        memory usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. This is only useful for specialty code that wants to recreate
        a `Saver` object for a previously built `Graph` that had a `Saver`.
        The `saver_def` proto should be the one returned by the
        `as_saver_def()` call of the `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
        Defaults to `BaseSaverBuilder()`.
      defer_build: If `True`, defer adding the save and restore ops to the
        `build()` call. In that case `build()` should be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default) raise an error if there are no
        variables in the graph. Otherwise, construct the saver anyway and make
        it a no-op.
      write_version: controls what format to use when saving checkpoints.  It
        also affects certain filepath matching logic.  The V2 format is the
        recommended choice: it is much more optimized than V1 in terms of
        memory required and latency incurred during restore.  Regardless of
        this flag, the Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: if True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default).  This is turned off by
        default.
      save_relative_paths: If `True`, will write relative paths to the
        checkpoint state file. This is needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique.
    """
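The `dict` form of `var_list` described in the docstring is what lets you restore a checkpoint into a graph whose variables have different names: the keys are names inside the checkpoint, the values are variables in the current graph. In the sketch below (assuming `tensorflow.compat.v1`; the name `other_name` is purely illustrative), one graph saves a variable called `v1` using the list form, and a second, separate graph restores that value into a variable called `other_name`. `tf.train.list_variables` is used to peek at the names stored in the checkpoint.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Graph 1: save a variable whose op name is 'v1' (list form keys by op name)
with tf.Graph().as_default():
    v1 = tf.Variable([[1.0, 2.0]], name='v1')
    saver = tf.train.Saver([v1])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, prefix)

# Names stored in the checkpoint, taken from (name, shape) pairs
ckpt_names = [name for name, _ in tf.train.list_variables(prefix)]

# Graph 2: a differently named variable, filled from checkpoint entry 'v1'
with tf.Graph().as_default():
    other = tf.Variable(tf.zeros([1, 2]), name='other_name')
    saver = tf.train.Saver({'v1': other})  # dict form: checkpoint name -> variable
    with tf.Session() as sess:
        saver.restore(sess, prefix)  # no initializer needed; restore assigns the value
        restored = sess.run(other)

print(ckpt_names, restored)
```

The same mapping trick also works for restoring only a subset of a larger checkpoint, because entries not named in `var_list` are simply ignored.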