新手上手Tensorflow之手写数字识别应用(2)

来源:互联网 发布:攀枝花学院知乎 编辑:程序博客网 时间:2024/05/18 11:48

本系列为应用TensorFlow实现手写数字识别应用的全过程的代码实现及细节讨论。按照实现流程,分为如下几部分:
1. 模型训练并保存模型
2. 通过鼠标输入数字并保存
2. 图像预处理
4. 读入模型对输入的图片进行识别
本文重点讨论模型的保存以及读入问题。
关于TensorFlow模型训练的部分,算法实现部分的论文、博客以及源码很多很多,相信大家也看了很多了,这里就不过多讨论。重点是,我们如何把我们训练的模型保存以及如何读入的问题。
训练完模型后,我们会得到模型参数的训练结果。如果我们想之后分享这个结果或者用来进行测试,就要保存这个结果了。TensorFlow提供了Saver类来保存和恢复(save/restore)模型参数。
1. Saver类保存和恢复参数的用法
首先来看一下官方给出的demo:

#Saving variables# Create some variables.import osos.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。import shutilcheckpoints_dir = './checkpoint1940/'if os.path.exists(checkpoints_dir):    shutil.rmtree(checkpoints_dir)os.makedirs(checkpoints_dir)checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')import tensorflow as tfv1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)inc_v1 = v1.assign(v1+1)dec_v2 = v2.assign(v2-1)# Add an op to initialize the variables.init_op = tf.global_variables_initializer()# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, and save the# variables to disk.with tf.Session() as sess:  sess.run(init_op)  # Do some work with the model.  inc_v1.op.run()  dec_v2.op.run()  # Save the variables to disk.  save_path = saver.save(sess, checkpoint_prefix)  print("Model saved in file: %s" % save_path)

运行结果:
saveresult

import osos.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。import shutilcheckpoints_dir = './checkpoint1940/'#if os.path.exists(checkpoints_dir):#    shutil.rmtree(checkpoints_dir)#os.makedirs(checkpoints_dir)checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')import tensorflow as tftf.reset_default_graph() #Clears the default graph stack and resets the global default graph.# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])#Note that when you restore variables from a file you do not have to initialize them beforehand# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess:  # Restore variables from disk.  saver.restore(sess, checkpoint_prefix)  print("Model restored.")  # Check the values of the variables  print("v1 : %s" % v1.eval())  print("v2 : %s" % v2.eval())

运行结果:
restore result
2. save/restore过程的技术细节
(1)checkpoint 文件
TensorFlow的Saver类是通过操作checkpoint文件来实现对变量(Variable)的存储和恢复。checkpoint文件是二进制的文件,存放着按照固定格式存储的“变量名-Tensor值”map对。一般来说,checkpoint文件有四种:
checkpointfile
其中,checkpoint文件可以直接用记事本打开,里面存放的是最新模型的path和所有模型的path;

.meta stores the graph structure, .data stores the values of each variable in the graph, .index identifies the checkpiont. So in the example above: import_meta_graph uses the .meta, and saver.restore uses the .data and .index

(2)graph structure数据结构
当我们用默认的方式saver = tf.train.Saver()创建saver对象的时候,saver将持有graph里的所有的变量。那当我们分开save和restore的时候,就会出现集中情况:

  • restore时候的saver对象持有的variable是在save的时候的saver持有variable的一个子集:也就是训练时候的变量我们在测试的时候不一定都用,这时候我们就可以选取其子集创建使用,这种情况是没问题的;
  • restore时候的saver对象持有的variable在save的时候saver并没有持有。也就是说,我们在测试的时候定义了一个新的变量,这个变量在save的时候没有出现,那么这时候如果restore,因为保存的变量中没有这个新的变量,所以就会报错。例如,我们在上面的restore的python程序中,在v2变量下面加一个v3变量,v2 = tf.get_variable(“v2”, shape=[5]),运行一下,就会出现 NotFoundError 错误:
    NotFoundError (see above for traceback): Key v3 not found in checkpoint [[Node: save/RestoreV2_2 = RestoreV2[dtypes=[DT_FLOAT],_device=”/job:localhost/replica:0/task:0/device:CPU:0”](_arg_save/Const_0_0, save/RestoreV2_2/tensor_names, save/RestoreV2_2/shape_and_slices)]]

(3)控制Saver的数据结构
根据上面的情况,我们只要确保restore时候的变量在save时都出现过就好了。但这样会给编程造成很大的不变。因为我们在测试的时候,很有可能会创建一些新的变量。针对这种情况,TensorFlow有两种方式可以解决:
- 创建Saver的时候,定义要保存的变量;这样我们在restore的时候,也一样定义要restore的变量,就好了;

v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')# Pass the variables as a dict:saver = tf.train.Saver({'v1': v1, 'v2': v2})# Or pass them as a list.saver = tf.train.Saver([v1, v2])# Passing a list is equivalent to passing a dict with the variable op names# as keys:saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
  • 但是有时候graph的结构比较复杂,要保存的变量很多,要一一对应还是很麻烦的。怎么办?采用tf.train.import_meta_graph()方法
#Create a saver.saver = tf.train.Saver(...variables...)#Remember the training_op we want to run by adding it to a collection.tf.add_to_collection('train_op', train_op)sess = tf.Session()for step in xrange(1000000):    sess.run(train_op)    if step % 1000 == 0:        # Saves checkpoint, which by default also exports a meta_graph        # named 'my-model-global_step.meta'.        saver.save(sess, 'my-model', global_step=step)

在save的时候我们保存我们想要保存的变量,当然可以直接默认保存全部;在restore的时候,我们先导入保存的模型的数据结构,就可以了。

with tf.Session() as sess:  new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')  new_saver.restore(sess, 'my-save-dir/my-model-10000')  # tf.get_collection() returns a list. In this example we only want the  # first one.  train_op = tf.get_collection('train_op')[0]  for step in xrange(1000000):    sess.run(train_op)

Reference
- TensorFlow, why there are 3 files after saving the model?:
https://stackoverflow.com/questions/41265035/tensorflow-why-there-are-3-files-after-saving-the-model

阅读全文
0 0
原创粉丝点击