TensorFlow 入门 3 ——变量管理和模型持久化
来源:互联网 发布:淘宝店有差评怎么办 编辑:程序博客网 时间:2024/06/01 10:04
变量管理
TensorFlow 提供了通过变量名称来创建或者获取一个变量的机制。通过这个机制,在不同的函数中可以直接通过变量名称来使用变量,而不需要将变量通过参数的形式到处传递。TensorFlow中通过变量名称获取变量的机制主要是通过tf.get_variable和tf.variable_scope函数实现的。
tf.get_variable
TensorFLow还提供了tf.get_variable函数来创建或者获取变量,tf.get_variable用于创建变量时,其功能和tf.Variable基本是等价的。tf.get_variable中的初始化方法(initializer)的参数和tf.Variable的初始化过程也类似,initializer函数和tf.Variable的初始化方法是一一对应的。
#以下两个定义是等价的##首先根据"v"这个名称来创建一个参数,如果创建失败(比如已经有同名的参数),那么这个程序就会报错。(防止变量重复创建)v = tf.get_variable("v", shape=[1], inittializer=tf.constant_initializer(1.0))v = tf.Variable(tf.constant(1.0, shape=[1]), name="v"))##**The best way to create a variable is to call the tf.get_variable function.**
TensorFlow中提供的initializer函数和随机数以及常数生成函数大部分是意义对应的。
tf.get_variable和tf.Variable最大的区别就在于指定变量名称的参数。对于tf.Variable函数,变量名称是一个可选的参数。对于tf.get_variable函数,变量名称是一个必填的参数,tf.get_variable会根据这个名称去创建或者获取变量。
tf.variable_scope
如果需要通过tf.get_variable获取一个已经创建的变量,需要通过tf.variable_scope函数来生成一个上下文管理器,并明确指定在这个上下文管理器中,tf.get_variable将直接获取已创建的变量。下面一段代码说明了如何通过tf.variable_scope函数来控制tf.get_variable函数获取创建过的变量。
通过tf.variable_scope函数可以控制tf.get_variable函数的语义。当tf.variable_scope函数的参数reuse=True生成上下文管理器时,该上下文管理器内的所有的tf.get_variable函数会直接获取已经创建的变量,如果变量不存在则报错;当tf.variable_scope函数的参数reuse=False或者None时创建的上下文管理器中,tf.get_variable函数则直接创建新的变量,若同名的变量已经存在则报错。
另tf.variable_scope函数是可以嵌套使用的。嵌套的时候,若某层上下文管理器未声明reuse参数,则该层上下文管理器的reuse参数与其外层保持一致。
tf.variable_scope函数提供了一个管理变量命名空间的方式。在tf.variable_scope中创建的变量,名称.name中名称前面会加入命名空间的名称,并通过“/”来分隔命名空间的名称和变量的名称。tf.get_variable(“foou/baru/u”, [1]),可以通过带命名空间名称的变量名来获取其命名空间下的变量。
模型持久化
当我们使用 tensorflow 训练神经网络的时候,模型持久化对于我们的训练有很重要的作用。
如果我们的神经网络比较复杂,训练数据比较多,那么我们的模型训练就会耗时很长,如果在训练过程中出现某些不可预计的错误,导致我们的训练意外终止,那么我们将会前功尽弃。为了避免这个问题,我们就可以通过模型持久化(保存为CKPT格式)来暂存我们训练过程中的临时数据。
如果我们训练的模型需要提供给用户做离线的预测,那么我们只需要前向传播的过程,只需得到预测值就可以了,这个时候我们就可以通过模型持久化(保存为PB格式)只保存前向传播中需要的变量并将变量的值固定下来,这个时候只需用户提供一个输入,我们就可以通过模型得到一个输出给用户。
持久化代码实现
TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是tf.train.Saver
类。以下代码给出了保存TensorFlow计算图的方法。
保存
import tensorflow as tf#声明两个变量并计算他们的和v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")result = v1 + v2init_op = tf.global_initialize_variables()#声明tf.train.Saver类用于保存模型saver = tf.train.Saver()with tf.Session() as sess: sess.run(init_op) #将模型保存到/path/to/model/model.ckpt文件。 saver.save(sess, "/path/to/model/model.ckpt")
上面的代码实现了持久化一个简单的TensorFlow模型的功能。在这个段代码中,通过saver.save函数将TensorFlow模型保存到/path/to/model/model.ckpt文件中。TensorFlow模型一般会保存在后缀为.ckpt的文件中。同时在这个文件目录下会出现三个文件。这是因为TensorFlow会将计算图的结构和图上参数取值分开保存。
- checkpoint文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
- model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构 。TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。
- model.ckpt文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader类来查看model.ckpt文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。
读取
# Part2: 加载TensorFlow模型的方法 import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" print(sess.run(result)) # [ 3.] # Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图 import tensorflow as tf saver = tf.train.import_meta_graph("Model/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]
在上面给出的程序中,默认保存和加载了TensorFlow计算图上定义的全部变量。有时可能只需要保存或者加载部分变量。 比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。为了保存或者加载部分变量,在声明tf.train.Saver
类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。
tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{“已保存的变量的名称name”: 重命名变量名},saver = tf.train.Saver({“v1”:u1, “v2”: u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。
# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名 import tensorflow as tf # 声明的变量名称name与已保存的模型中的变量名称name不一致 u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") result = u1 + u2 # 若直接生命Saver类对象,会报错变量找不到 # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 saver = tf.train.Saver({"v1": u1, "v2": u2}) with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") print(sess.run(result)) # [ 3.]
使用tf. train. Saver 会保存运行TensorFlow 程序所需要的全部信息,然而有时并不需要某些信息。比如在测试或者离线预测时,只需要知道如何从神经网络的输入层经过前向传播计算得到输出层即可,而不需要类似于变量初始化、模型保存等辅助节点的信息。在迁移学习中,会遇到类似的情况。而且,将变量取值和计算图结构分成不同的文件存储有时候也不方便,于是TensorFlow 提供了convert_variables_to_constants
函数,通过这个函数可以将计算图中的变量及其取值通过常量的方式保存,这样整个TensorFlow 计算图可以统一存放在一个文件中。下面的程序提供了一个样例。
import tensorflow as tf from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 init_op = tf.global_variables_initializer()with tf.Session() as sess: sess.run(init op) #导出当前计算图的GraphDef 部分,只需要这一部分就可以完成从输入层到输出层的计算过程。 graph_def = tf.get_default_graph().as_graph_def() #将图中的变量及其取值转化为常量,同时将图中不必要的节点去掉。在5.4.2 小节中将会看 #到一些系统运算也会被转化为计算图中的节点(比如变量初始化操作)。如果只关心程序中定 #义的某些计算时,和这些计算无关的节点就没有必要导出并保存了。在下面一行代码中,最 #后一个参数[ 'add'] 给出了需要保存的节点名称。add 节点是上面定义的两个变量相加的 #操作。注意这里给出的是计算节点的名称,所以没有后面的O output_graph_def = graph_util.convert_variables_to_constants(sess , graph_def, ['add']) #将导出的模型存入文件。 with tf.gfile.GFile("/path/to/model/combined_model.pb" , "wb") as f: f.write(output_graph_def.SerializeToString())
通过下面的程序可以直接计算定义的加法运算的结果。当只需要得到计算图中某个节点的取值时,这提供了一个更加方便的方法。(这个可以用来实现迁移学习)
import tensorflow as tf from tensorflow.python.platform import gfile with tf.Session() as sess: model_filename = "Model/combined_model.pb" #读取保存的模型文件,并将文件解析成对应的GraphDef Protocol Buffer。 with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) result = tf.import_graph_def(graph_def, return_elements=["add:0"]) print(sess.run(result)) # [array([ 3.], dtype=float32)]
持久化原理及数据格式
TensorFlow 是一个通过图的形式来表述计算的编程系统,TensorFlow 程序中的所有计算都会被表达为计算图上的节点。TensorFlow 通过元图( MetaGraph )来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow 中元图是由MetaGraphDef Protocol Buffer 定义的。MetaGraphDef 中的内容就构成了TensorFlow 持久化时的第一个文件。以下代码给出了MetaGraphDef 类型的定义。
message MetaGraphDef { MetaInfoDef meta_info_def = 1; GraphDef graph_def = 2; SaverDef saver_def = 3; map<string, CollectionDef> collection_def = 4; map<string, SignatureDef> signature_def = 5;}
从上面的代码中可以看到,元图中主要记录了5 类信息。保存MetaGraphDef 信息的文件默认以meta 为后缀名,文件model. ckpt. meta 中存储的就是元图的数据。为了方便调试, TensorFlow 提供了export_ meta _graph 函数,这个函数支持以Json 格式导出MetaGraphDef Protocol Buffer 。以下代码展示了如何使用这个函数。
import tensorflow as tf#定义变量相加的计算。v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1" )v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2" )result1 = v1 + v2saver = tf.train.Saver()#通过export_meta_graph 函数导出TensorFlow 计算图的元图,并保存为json 格式。saver.export_meta_graph("/path/to/model.ckpt.meda.json", as_text=True)
通过上面给出的代码,可以将计算图元图以Json 的格式导出并存储在model.ckpt.meta.json 文件中。下文将结合model.ckpt.meta.json 文件具体介绍TensorFlow 元图中存储的信息。
meta_info_def属性
meta_info_def 属性是通过MetalnfoDef 定义的,它记录了TensorFlow 计算图中的元数据以及TensorFlow 程序中所有使用到的运算方法的信息。下面是MetalnfoDef Protocol Buffer 的定义:
message MetaInfoDef { string meta_graph_version = 1; #计算图的版本号 OpList stripped_op_list = 2; google.protobuf.Any any_info = 3; repeated string_tags = 4; #用户指定的一些标签}
stripped_op_list 属性记录了TensorFlow 计算图上使用到的所有运算方法的信息。注意stripped_op_list属性保存的是TensorFlow 运算方法的信息,所以如果某一个运算在TensorFlow 计算图中出现了多次,那么在stripped_op_list也只会出现一次。
stripped_op_list 属性的类型是OpList。OpList 类型是一个OpDef 类型
的列表,以下代码给出了OpDef 类型的定义:
message OpDef { string name = 1; repeated ArgDef input_arg = 2; repeated ArgDef output arg = 3 ; repeated AttrDef attr = 4 ; string summary = 5; string description = 6; OpDeprecation deprecation = 8; boo1 is_commutative = 18; bool is_aggregate = 16 bool is_stateful = 17; bool allows_uninitialized_input = 19 ;
OpDef 类型中前四个属性定义了一个运算最核心的信息。OpDef 中的第一个属性name定义了运算的名称,这也是一个运算唯一的标识符。在TensorFlow 计算图元图的其他属性中,比如下面将要介绍的GraphDef 属性,将通过运算名称来引用不同的运算。OpDef 的第二和第三个属性为input_arg 和output_arg,它们定义了运算的输入和输出。因为输入输出都可以有多个,所以这两个属性都是列表(repeated) 。第四个属性attr
给出了其他的运算参数信息。
graph_def属性
graph_def
属性主要记录了TensorFlow 计算图上的节点信息。TensorFlow 计算图的每一个节点对应了TensorFlow 程序中的一个运算。因为在meta _info_def
属性中己经包含了所有运算的具体信息,所以graph_def
属性只关注运算的连接结构。graph_def
属性是通过GraphDef Protocol Buffer 定义的, GraphDef 主要包含了一个NodeDef 类型的列表。以下代码给出了GraphDef 和NodeDef 类型中包含的信息:
message GraphDef { repeated NodeDef node = 1; VersionDef versions = 4 ;} ;message NodeDef { string name = 1; string op = 2; repeated string input = 3; string device = 4; map<string, AttrValue> attr = 5;}
GraphDef 中的versions 属性比较简单,它主要存储了TensorFlow 的版本号。GraphDef的主要信息都存在node属性中,它记录了TensorFlow 计算图上所有的节点信息。
- 和其他属性类似,NodeDef 类型中有一个名称属性name ,它是一个节点的唯一标识符。在TensorFlow 程序中可以通过节点的名称来获取相应的节点。
- NodeDef 类型中的op 属性给出了该节点使用的TensorFlow 运算方法的名称,通过这个名称可以在TensorFlow 计算图元图的meta info def 属性中找到该运算的具体信息。
- NodeDef 类型中的input 属性是一个字符串列表,它定义了运算的输入。input 属性中每个字符串的取值格式为node:src_output ,其中node 部分给出了一个节点的名称, src _output部分表明了这个输入是指定节点的第几个输出。当src_output 为0 时,可以省略: src_output这个部分。比如node:0 表示名称为node 的节点的第一个输出,它也可以被记为node 。
- NodeDef 类型中的device 属性指定了处理这个运算的设备。运行TensorFlow 运算的设备可以是本地机器的CPU 或者GPU ,也可以是一台远程的机器CPU 或者GPU 。
saver_def 属性
saver_def
属性中记录了持久化模型时需要用到的一些参数,比如保存到文件的文件名、保存操作和加载操作的名称以及保存频率、清理历史记录等。saver_def
属性的类型为SaverDef,其定义如下。
message SaverDef { string filename_tensor_name = 1 ; string save_tensor_name = 2; string restore_op_name = 3; int32 max_to_keep = 4; bool sharded = 5; float keep_checkpoint_every_n_hours = 6; enum CheckpointFormatVersion { LEGACY = 0; V1 = 1; V2 = 2; } CheckpointFormatVersion version = 7;}
filename_tensor_name 属性给出了保存文件名的张量名称,这个张量就是节点save/Const 的第一个输出。save_tensor_name 属性给出了持久化TensorFlow 模型的运算所对应的节点名称。从上面的文件中可以看出,这个节点就是在graph_def 属性中给出的save/control_dependency 节点。和持久化TensorFlow 模型运算对应的是加载TensorFlow 模型的运算,这个运算的名称由restore_op_name 属性指定。max_to_keep 属性和keep_checkpoint_every_n_hours 属性设定了tf.train.Saver 类清理之前保存的模型的策略。比如当max_to_keep 为5 的时候,在第六次调用saver.save 时,第一次保存的模型就会被自动删除。通过设置keep_checkpoint_every_n_hours ,每n 小时可以在max_t_keep 的基础上多保存一个模型。
collection_def属性
在TensorFlow 的计算图( tf. Graph) 中可以维护不同集合, 而维护这些集合的底层实现就是通过collection_def 这个属性。collection_def 属性是一个从集合名称到集合内容的映射,其中集合名称为字符串,而集合内容为CollectionDef Protocol Buffer 。以下代码给出了CollectionDef 类型的定义。
message CollectionDef { message NodeList { repeated string value = 1; } message BytesList { repeated bytes value = 1; } message Int64List { repeated int64 va1ue = 1 [packed = true]; } message FloatList { repeated f1oat value = 1 [packed = true]; } message AnyList { repeated google.protobuf.Any value = 1; } oneof kind { NodeList node_list = 1; BytesList bytes_list = 2; Int64List int64_list = 3 ; FloatList f1oat_list = 4; AnyList any_list = 5; }}
通过上面的定义可以看出, TensorFlow 计算图上的集合主要可以维护4 类不同的集合。NodeList 用于维护计算图上节点的集合。BytesList 可以维护字符串或者系列化之后的Procotol Buffer 的集合。比如张量是通过Protocol Buffer 表示的, 而张量的集合是通过BytesList 维护的。
mode1.ckpt 文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSliceProtocol Buffer 定义的。SavedSlice 类型中保存了变量的名称、当前片段的信息以及变量取值。TensroFlow 提供了tf.train.NewCheckpoin tReader 类来查看mode1.ckpt 文件中保存的变量信息。
最后一个文件的名字是固定的,叫checkpoint。这个文件是tf.train.Saver 类自动生成且自动维护的。在checkpoint 文件中维护了由一个tf.train.Saver 类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow 模型文件被删除时,这个模型所对应的文件名也会从checkpoint 文件中删除。checkpoint 中内容的格式为CheckpointState Protocol Buffer。
- TensorFlow 入门 3 ——变量管理和模型持久化
- Tensorflow学习笔记(6)——变量管理和模型持久化
- TensorFlow模型的保存和持久化
- Tensorflow 模型持久化
- tensorflow--模型持久化
- Tensorflow中的模型持久化
- tensorflow 模型的持久化
- Tensorflow基础:模型持久化
- 5.4 TensorFlow模型持久化
- 【TensorFlow】模型持久化tf.train.Saver—上(八)
- 【TensorFlow】模型持久化tf.train.Saver—下(九)
- TensorFlow MNIST LeNet 模型持久化
- Tensorflow模型持久化与恢复
- Tensorflow模型持久化的代码实现
- Tensorflow深度学习入门——优化训练MNIST数据和调用训练模型识别图片
- MNIST 数字识别和数据持久化--step by step 入门TensorFlow(三)
- 【TensorFlow】MNIST(代码重构+模型持久化)
- 持久化框架——编程模型
- java.lang.UnsupportedClassVersionError
- Centos7安装Mysql
- C语言 fread()与fwrite()函数说明与示例
- 11个Java 开源 socket框架
- ContextLoaderListener加载过程(最详细版)
- TensorFlow 入门 3 ——变量管理和模型持久化
- JDK动态代理笔记
- dubbo高级篇-13 Dubbo服务集群-集群容错模式
- java程序在JVM中的运行顺序:
- JavaScript语言基础
- 【NOIP2017提高组正式赛】Day1T3逛公园
- 1034. 有理数四则运算(20)
- 第一次考试
- HDU 2897-邂逅明下 博弈论初步 巴什博弈