TensorFlow学习笔记(五)

来源:互联网 发布:卖家网数据套餐 编辑:程序博客网 时间:2024/05/23 12:36

模型保存与载入

        在深度学习或强化学习中,我们训练一个模型常常需要较长的时间,因而我们萌生了想要将模型记录下来的想法,如何做到呢?下面我们来学习两种方法解决这一问题,参考自TensorFlow Programmers' Guide。

        1.最最最最基本的方法

        利用tf.train.Saver类实现模型的保存与载入。tf.train.Saver类的构造函数为所有的(或指定的)变量在graph中加入了save和restore的ops,并提供了运行这些ops的方法,我们只需要指定写入或读取的文件路径即可。

        a)保存:

        首先定义变量,然后定义一个tf.train.Saver类型对象,利用Saver对象的save方法对模型进行存储,官网的示例如下:

# Create some variables.v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)inc_v1 = v1.assign(v1+1)dec_v2 = v2.assign(v2-1)# Add an op to initialize the variables.init_op = tf.global_variables_initializer()# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, and save the# variables to disk.with tf.Session() as sess:sess.run(init_op)# Do some work with the model.inc_v1.op.run()dec_v2.op.run()# Save the variables to disk.save_path = saver.save(sess, "/tmp/model.ckpt")print("Model saved in file: %s" % save_path)

        b)载入
        与保存一样,也需要先定义变量,然后利用Saver对象的restore方法对模型进行载入:

tf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess:# Restore variables from disk.saver.restore(sess, "/tmp/model.ckpt") print("Model restored.")  # Check the values of the variables  print("v1 : %s" % v1.eval())  print("v2 : %s" % v2.eval())
        在上面的示例中我们首先将v1,v2变量按照name进行存储,格式为ckpt,也即checkpoint文件格式,然后利用restore()函数对指定的参数进行读取,也是按照name进行读取的。如果在对tf.train.Saver()进行初始化时,如果没有传入任何的参数,则默认为记录graph中的所有变量,所有的变量都按照变量被创建时传入的参数进行存储,这样做一个好处,比方说我们某次训练的模型中有一个变量名字是"weights",但是我们想要在载入后的模型中将其命名为"params",就很happy了。此外,我们还可以只保存/重建部分变量的值,或者说,我们想要将某个模型作为当前模型的部分模型,例如将一个五层的模型作为当前模型(六层)的前面五层,贴近DRL来说,比如我们可以用ImageNet比赛的某些开源模型作为视觉部分的预网络,然后对整个DRL模型进行训练。

        我们可以在构造tf.train.Saver()对象时传入一些参数来指定save和load的变量及其名字:

        1)传入一个变量列表,这些变量将按照它们自己的names进行存储;

        2)传入Python字典,其中keys是names,values是变量值。

tf.reset_default_graph()# Create some variables.v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)# Add ops to save and restore only `v2` using the name "v2"saver = tf.train.Saver({"v2": v2})# Use the saver object normally after that.with tf.Session() as sess:  # Initialize v1 since the saver will not.  v1.initializer.run()  saver.restore(sess, "/tmp/model.ckpt")  print("v1 : %s" % v1.eval())  print("v2 : %s" % v2.eval())
        下面列出一些值得注意的点:

        1)我们可以创建多个Saver对象用于存储模型中不同的变量子集,同一个变量可以被存入多个Saver类型对象中;

        2)To inspect the variables in a checkpoint, you can use the inspect_checkpoint library, particularly the print_tensors_in_checkpoint_file function;

        3)Saver类型对象默认使用tf.Variable.name属性作为所存变量的名字,然而,我们在创建Saver类型对象时,可能会想要去自己给所存对象命名,所以Saver类型提供了这一功能。

        4)我们可以简单地将restore理解为模型参数的初始化,毕竟我们要自己去重新定义一个模型,然后对参数进行载入。


        上面介绍了按照变量进行存储/重建的方法,该方法有它的优点,比如可以存储/重建变量的subsets,但也有它的缺点。如果我们想要重建某一个模型,必须要将整个模型的结构重新定义一遍,然后载入对应名字变量的值,与我们所期望的不一致,我们当然想要几句代码就能从一个文件中直接读取出整个模型,包括其结构,而不是又自己去定义一遍模型,此时,就有了另外一种方法,tf.train.import_meta_graph。

        下面我们举例进行说明:

import tensorflow as tfw1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')saver = tf.train.Saver()sess = tf.Session()sess.run(tf.global_variables_initializer())saver.save(sess, 'my_test_model')
        其中我们使用了Saver类型中的save方法,该方法将隐式地调用export_meta_graph函数来将模型进行导出,在使用该函数时,我们可以传入迭代步数,比如:

saver.save(sess,'my_test_model', global_step=1000)
        然后,我们对模型进行导入:

sess = tf.Session()new_saver = tf.train.import_meta_graph('my-model.meta')new_saver.restore(sess, tf.train.latest_checkpoint('./'))all_vars = tf.get_collection('vars')for v in all_vars:    v_ = sess.run(v)    print(v_)
        我么首先利用import_meta_graph方法将模型图导入,然后利用restore方法将图对应的参数的值导入至当前的script,get_collection返回一个list,想要用哪个自己获取就是了。

        下面介绍一下调用save方法得到的几个文件:

        1)meta图

        这是一个拟定的缓存,包含了该TensorFlow图的完整信息,比如所有变量等,文件以.meta结束。

        2)检查点文件

        该文件是一个二进制文件,包含所有的权重、偏移、梯度和所有其他存储的变量或值,在老版本中以.ckpt结束,但是在0.11版本之后以这个形式出现了,而是包含如下文件:

        a)mymodel.data-00000-of-00001

        b)mymodel.index

        3)checkpoint**

        其中.data文件包含训练变量。

        最后,给出import_meta_graph方法的函数签名:

import _meta_graph(meta_graph_or_file,clear_devices=False,import_scope=None,**kwargs)
        该方法可以从文件中将保存的graph中的所有节点加载到当前的default graph中,并返回一个saver,也就是说,我们在保存的时候,不仅仅将变量的值保存下来,还将graph中的节点保存下来了,即模型的结构也得以保存,这一点在上面几个文件介绍中可以看出来。
        最后举一个有意义的例子:

with tf.Session() as sess:        new_saver=tf.train.import_data_graph('my-save-dir/my-model-10000.meta')        new_saver.restore(sess,'my-save-dir/my-model-10000')        y=tf.get_collection('predict_network')[0]        graph=tf.get_default_graph()        input_x=graph.get_operation_by_name('input_x').outputs[0]        keep_prob=graph.get_operation_by_name('keep_prob').output[0]        sess.run(y,feed_dict)={input_x:....,keep_prob:1.0}
        我们可以利用该网络进行预测,这就很开心了。







原创粉丝点击