深度学习(四)-Tensor Flow的变量创建、初始化、保存和加载

来源:互联网 发布:linux 查看死机日志 编辑:程序博客网 时间:2024/06/06 08:39

一.创建

# Create a variable.w = tf.Variable(<initial-value>, name=<optional-name>)

其中为初始值,既可以是随意值,也可以是常数
例:

# Create two variables.#标准差为0.35的784*200的任意数的矩阵weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),                      name="weights")#维度为200的全是零的向量biases = tf.Variable(tf.zeros([200]), name="biases")

二.初始化

变量的初始化必须在模型的其它操作运行之前先明确地完成。

1.最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。

# 定义一个操作对所有变量进行初始化.init_op = tf.initialize_all_variables()

之后运行模型时,首先要运行该操作,对所有变量进行初始化

2.用别的变量的初始值来定义另一个变量

# Create another variable with the same value as 'weights'.w2 = tf.Variable(weights.initialized_value(), name="w2")# Create another variable with twice the value of 'weights'w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

3.批量初始化一部分变量
tf.initialize_variables(var_list, name=’init’)
该操作可以初始化var_list中的所有变量

4.初始某个变量
tf.Variable.initialized_value()

# Initialize 'v' with a random tensor.v = tf.Variable(tf.truncated_normal([10, 40]))# Use `initialized_value` to guarantee that `v` has been initialized before its value is used to initialize `w`.# The random values are picked only once.w = tf.Variable(v.initialized_value() * 2.0)

三.保存和加载
1.用tf.train.Saver()创建一个Saver来管理模型中的所有变量,返回类型为Saver(仍然必须调用该save()方法来保存模型。将这些参数传递给构造函数不会自动保存变量)。
2.tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None)也可以保存变量。

此方法运行由构造函数添加的用于保存变量的ops。它需要一个已经启动的会话。要保存的变量也必须已初始化。

该方法返回新创建的检查点文件的路径。这个路径可以直接传递给一个调用restore()。

ARGS:

sess:用于保存变量的会话。
save_path:字符串。检查点文件名的路径。如果是保护程序 sharded,这是分片的检查点文件名的前缀。
global_step:如果提供了全局步号附加 save_path到创建检查点文件名。可选参数可以是Tensor,Tensor名称或整数。
latest_filename:包含最新检查点文件名列表的协议缓冲区文件的可选名称。保存在与检查点文件相同的目录中的文件由保存程序自动管理以跟踪最近的检查点。默认为“检查点”。
返回:

一个字符串:保存变量的路径。如果保护程序被分片,则该字符串以:’-nnnnn’结尾,其中’nnnnn’是创建的分片数。

# Create some variables.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# Add an op to initialize the variables.init_op = tf.initialize_all_variables()# Add ops to save and restore all the variables.#如果你不给tf.train.Saver()传入任何参数,那么saver将处理graph中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, save the# variables to disk.with tf.Session() as sess:  sess.run(init_op)  # Do some work with the model.  ..  # Save the variables to disk.  save_path = saver.save(sess, "/tmp/model.ckpt")  print "Model saved in file: ", save_path

用同一个Saver对象来恢复变量。当从文件中恢复变量时,不需要事先对它们做初始化

# Create some variables.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...#不用初始化# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess:  # Restore variables from disk.  saver.restore(sess, "/tmp/model.ckpt")  print "Model restored."  # Do some work with the model

注:
1.如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore()函数被运行时,它的值才会发生改变。
2.如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化op。(不太懂??)

# Create some variables.v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# Add ops to save and restore only 'v2' using the name "my_v2"saver = tf.train.Saver({"my_v2": v2})# Use the saver object normally after that.
阅读全文
0 0
原创粉丝点击