新手上手Tensorflow之手写数字识别应用(2)
来源:互联网 发布:攀枝花学院知乎 编辑:程序博客网 时间:2024/05/18 11:48
本系列为应用TensorFlow实现手写数字识别应用的全过程的代码实现及细节讨论。按照实现流程,分为如下几部分:
1. 模型训练并保存模型
2. 通过鼠标输入数字并保存
2. 图像预处理
4. 读入模型对输入的图片进行识别
本文重点讨论模型的保存以及读入问题。
关于TensorFlow模型训练的部分,算法实现部分的论文、博客以及源码很多很多,相信大家也看了很多了,这里就不过多讨论。重点是,我们如何把我们训练的模型保存以及如何读入的问题。
训练完模型后,我们会得到模型参数的训练结果。如果我们想之后分享这个结果或者用来进行测试,就要保存这个结果了。TensorFlow提供了Saver类来保存和恢复(save/restore)模型参数。
1. Saver类保存和恢复参数的用法
首先来看一下官方给出的demo:
#Saving variables# Create some variables.import osos.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。import shutilcheckpoints_dir = './checkpoint1940/'if os.path.exists(checkpoints_dir): shutil.rmtree(checkpoints_dir)os.makedirs(checkpoints_dir)checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')import tensorflow as tfv1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)inc_v1 = v1.assign(v1+1)dec_v2 = v2.assign(v2-1)# Add an op to initialize the variables.init_op = tf.global_variables_initializer()# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, initialize the variables, do some work, and save the# variables to disk.with tf.Session() as sess: sess.run(init_op) # Do some work with the model. inc_v1.op.run() dec_v2.op.run() # Save the variables to disk. save_path = saver.save(sess, checkpoint_prefix) print("Model saved in file: %s" % save_path)
运行结果:
import osos.environ['TF_CPP_MIN_LOG_LEVEL']='2' #屏蔽乱七八糟的输出信息--强迫症患者。。。import shutilcheckpoints_dir = './checkpoint1940/'#if os.path.exists(checkpoints_dir):# shutil.rmtree(checkpoints_dir)#os.makedirs(checkpoints_dir)checkpoint_prefix = os.path.join(checkpoints_dir, 'model.ckpt')import tensorflow as tftf.reset_default_graph() #Clears the default graph stack and resets the global default graph.# Create some variables.v1 = tf.get_variable("v1", shape=[3])v2 = tf.get_variable("v2", shape=[5])#Note that when you restore variables from a file you do not have to initialize them beforehand# Add ops to save and restore all the variables.saver = tf.train.Saver()# Later, launch the model, use the saver to restore variables from disk, and# do some work with the model.with tf.Session() as sess: # Restore variables from disk. saver.restore(sess, checkpoint_prefix) print("Model restored.") # Check the values of the variables print("v1 : %s" % v1.eval()) print("v2 : %s" % v2.eval())
运行结果:
2. save/restore过程的技术细节
(1)checkpoint 文件
TensorFlow的Saver类是通过操作checkpoint文件来实现对变量(Variable)的存储和恢复。checkpoint文件是二进制的文件,存放着按照固定格式存储的“变量名-Tensor值”map对。一般来说,checkpoint文件有四种:
其中,checkpoint文件可以直接用记事本打开,里面存放的是最新模型的path和所有模型的path;
.meta stores the graph structure, .data stores the values of each variable in the graph, .index identifies the checkpiont. So in the example above: import_meta_graph uses the .meta, and saver.restore uses the .data and .index
(2)graph structure数据结构
当我们用默认的方式saver = tf.train.Saver()创建saver对象的时候,saver将持有graph里的所有的变量。那当我们分开save和restore的时候,就会出现集中情况:
- restore时候的saver对象持有的variable是在save的时候的saver持有variable的一个子集:也就是训练时候的变量我们在测试的时候不一定都用,这时候我们就可以选取其子集创建使用,这种情况是没问题的;
- restore时候的saver对象持有的variable在save的时候saver并没有持有。也就是说,我们在测试的时候定义了一个新的变量,这个变量在save的时候没有出现,那么这时候如果restore,因为保存的变量中没有这个新的变量,所以就会报错。例如,我们在上面的restore的python程序中,在v2变量下面加一个v3变量,v2 = tf.get_variable(“v2”, shape=[5]),运行一下,就会出现 NotFoundError 错误:
NotFoundError (see above for traceback): Key v3 not found in checkpoint [[Node: save/RestoreV2_2 = RestoreV2[dtypes=[DT_FLOAT],_device=”/job:localhost/replica:0/task:0/device:CPU:0”](_arg_save/Const_0_0, save/RestoreV2_2/tensor_names, save/RestoreV2_2/shape_and_slices)]]
(3)控制Saver的数据结构
根据上面的情况,我们只要确保restore时候的变量在save时都出现过就好了。但这样会给编程造成很大的不变。因为我们在测试的时候,很有可能会创建一些新的变量。针对这种情况,TensorFlow有两种方式可以解决:
- 创建Saver的时候,定义要保存的变量;这样我们在restore的时候,也一样定义要restore的变量,就好了;
v1 = tf.Variable(..., name='v1')v2 = tf.Variable(..., name='v2')# Pass the variables as a dict:saver = tf.train.Saver({'v1': v1, 'v2': v2})# Or pass them as a list.saver = tf.train.Saver([v1, v2])# Passing a list is equivalent to passing a dict with the variable op names# as keys:saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
- 但是有时候graph的结构比较复杂,要保存的变量很多,要一一对应还是很麻烦的。怎么办?采用tf.train.import_meta_graph()方法
#Create a saver.saver = tf.train.Saver(...variables...)#Remember the training_op we want to run by adding it to a collection.tf.add_to_collection('train_op', train_op)sess = tf.Session()for step in xrange(1000000): sess.run(train_op) if step % 1000 == 0: # Saves checkpoint, which by default also exports a meta_graph # named 'my-model-global_step.meta'. saver.save(sess, 'my-model', global_step=step)
在save的时候我们保存我们想要保存的变量,当然可以直接默认保存全部;在restore的时候,我们先导入保存的模型的数据结构,就可以了。
with tf.Session() as sess: new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') new_saver.restore(sess, 'my-save-dir/my-model-10000') # tf.get_collection() returns a list. In this example we only want the # first one. train_op = tf.get_collection('train_op')[0] for step in xrange(1000000): sess.run(train_op)
Reference
- TensorFlow, why there are 3 files after saving the model?:
https://stackoverflow.com/questions/41265035/tensorflow-why-there-are-3-files-after-saving-the-model
- 新手上手Tensorflow之手写数字识别应用(2)
- 新手上手Tensorflow之手写数字识别应用(1)
- 新手上手Tensorflow之手写数字识别应用(3)
- tensorflow识别手写数字(2)
- tensorflow识别手写数字
- Tensorflow手写数字识别
- TensorFlow实战(一)手写数字识别
- TensorFlow学习笔记之源码分析(2)----手写数字识别mnist example
- TensorFlow在MNIST中的应用 识别手写数字(OpenCV+TensorFlow+CNN)
- tensorflow-mnist手写数字识别
- TensorFlow实现识别手写数字
- TensorFlow实时识别手写数字(数字通过鼠标输入)
- TensorFlow学习笔记(一):手写数字识别之softmax回归
- TensorFlow实战:手写数字识别之K近邻
- Tensorflow实战之用softmax Regression识别手写数字
- TensorFlow学习笔记(四):手写数字识别之LSTM网络
- Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
- TensorFlow MNIST 手写数字识别之过拟合
- 今天想起论语十则,记下来勉励自己
- 解决iPhoneX push过程中tabbar上移的问题
- freemarker导出word文档对图片拉伸或拉长的处理
- 云星数据---Scala实战系列(精品版)】:Scala入门教程058-Scala实战源码-Scala 正则 Regex
- kali忘记密码
- 新手上手Tensorflow之手写数字识别应用(2)
- Openstack api 学习文档 & restclient使用文档
- tips_for_sequence_1
- 项目反思--学校大学生活动中心预约系统--[Json、Json字符串、Json数组、对象、List的互相转换]
- 【工业互联网】安筱鹏:从工业云到工业互联网平台演进的五个阶段
- Learning multiple layers of representation理解
- 大型网站技术架构(八)——网站的安全架构
- oracle学习笔记:分析函数
- Servlet乱码处理方式/转发重定向