tensorflow预训练简单模型及权重文件复用初始化复杂模型
来源:互联网 发布:网络爬虫的目的和意义 编辑:程序博客网 时间:2024/06/06 03:54
笔者在学习YOLO网络的过程中,遇到了预训练问题。我在网上搜到的大部分相关问题都是在说如何利用之前的预训练权重文件做fineturning
问题:如何预训练简单网络,然后复用权重文件初始化复杂网络?
在yolo中,我需要预训练前20层网络,使其能在物体分类上达到不错的准确率
然后复用这个简单网络的权重文件初始化正式的YOLO网络,这是我遇到的实际问题
一开始,我想问题的突破口有三个:
1、npy文件保存与读取
本来我是用ckpt文件作为权重文件的载体,但是因为不太了解ckpt内部结构,结合网上一些介绍,就想着能不能另辟蹊径呢?
然而我想的是先把手上的资源利用起来,所以有了2和3
2、tf.get_variable()
受Salvador Dali的启发,要保留权重文件不过就是要保存变量,既要保存变量就得从变量的载体下手,要了解这个函数有哪些参数,返回值是什么,还有其他哪些类似功能的函数
3、tf.train.Saver()
保留权重文件的另一个方向就是在如何保留入手,这个应该是要和variable变量函数结合起来操作的吧,但一开始我并不知道如何下手
所以还是到官网,看文档,了解这函数的参数列表,返回值等等,再结合相关的博文,自己做一点测试。
利用 tf.get_variable()和 tf.train.Saver()解决上述问题
首先了解下tf.get_variable
其中collections这个参数有什么作用呢?
可以清楚地看到所有variable都会默认为global
这个参数可以将特定的variable设为local
collections=[tf.GraphKeys.LOCAL_VARIABLES]
为什么要这么做呢?接下来看完tf.train.Saver()的介绍,你就清楚了
Saver的参数var_list可以选择一个列表的变量,这些个变量的op name会在checkpoint files中作为keys
这说明什么?说明Saver可以任意选择要保存global变量或是local变量,或者两者都保存呢
而这个var_list可以是什么呢?
tf.global_variables()和tf.local_variables()都可以返回一个列表的global或local变量呢,两个加起来不就可以随意保存我想要的变量了吗
到这里是不是有思路了?
我可以将简单网络的变量,全部设为global,然后将用save(global)保存下来
然后restore回复杂网络不就可以了?
那么问题来了,真的有那么简单吗?
复杂网络的变量是全部设为global,还是部分设为global部分设为local,哪一部分设为local呢
让我做个测试验证一下吧
我要验证的是
当我把一部分global变量保存下来之后,在restore之前加入另一部分glocal变量,这个restore会失败吗?
或者在restore之前加入另一部分local变量,这个restore会成功吗?
下面是我用mnist数据集做的一个小测试
import tensorflow as tfimport osoutput='output'output_dir=os.path.join(os.path.abspath(output), 'weights')ckpt_file = os.path.join(output_dir, 'save.ckpt')from tensorflow.examples.tutorials.mnist import input_datamnist=input_data.read_data_sets("MNIST_data/",one_hot=True)sess=tf.InteractiveSession()x=tf.placeholder(tf.float32,[None,784])W=tf.get_variable("W",initializer=tf.zeros([784,10]),collections=[tf.GraphKeys.GLOBAL_VARIABLES])b=tf.get_variable("b",initializer=tf.zeros([10]),collections=[tf.GraphKeys.GLOBAL_VARIABLES])#b1=tf.get_variable("b1",initializer=tf.zeros([10]),collections=[tf.GraphKeys.LOCAL_VARIABLES]) y=tf.nn.softmax(tf.matmul(x,W)+b)y_=tf.placeholder(tf.float32,[None,10])cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)tf.global_variables_initializer().run()tf.local_variables_initializer().run()variable_to_restore = tf.global_variables()#+tf.local_variables()saver = tf.train.Saver(variable_to_restore, max_to_keep=None)is_train=False#True就训练,False为检测with tf.variable_scope("weights",reuse=True): if is_train: for i in range(1000): batch_xs,batch_ys=mnist.train.next_batch(100) train_step.run({x:batch_xs,y_:batch_ys}) if i % 100 ==0: print('Saving checkpoint file to: {}'.format(output_dir)) saver.save(sess,ckpt_file) correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})) else: model_file=tf.train.latest_checkpoint(output_dir) saver.restore(sess,model_file) print(sess.run(W))#这里是把之前保存的变量取出来观察一下 print(sess.run(b)) correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
结果是:
当我把GLOBAL的W和b保存下来之后,加了一个b1的local变量,再restore,没有报错
设is_train=False,测试结果也正常
当我把GLOBAL的W和b保存下来之后,加了一个b1的global变量,(collections=[tf.GraphKeys.GLOBAL_VARIABLES])
出现了一下错误
OK,至此,问题不就解决了吗?
先将简单网络的变量都设为global然后save下来
然后将复杂网络中较之简单网络多出来的变量统统设为local
然后用初始化为Saver(global)的对象1去restore,就可以达到利用简单网络额权重文件初始化复杂网络的目的啦
最后用初始化为Saver(global+local)的对象2去save,就可以把复杂网络的权重文件保存下来啦
- tensorflow预训练简单模型及权重文件复用初始化复杂模型
- tensorflow 加载预训练模型
- tensorflow 模型训练
- TensorFlow VGG-16 预训练模型
- Tensorflow加载预训练模型和保存模型
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
- 利用TensorFlow训练简单的二分类神经网络模型
- tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
- 【TensorFlow】神经网络模型训练及完整程序实例(五)
- tensorflow ssd mobilenet模型训练
- TensorFlow on Android:训练模型
- 神经网络模型中的权重参数初始化问题
- tensorflow之inception_v3模型的部分加载及权重的部分恢复(23)---《深度学习》
- TensorFlow使用C++加载使用训练好的模型,.cc文件代码实现的相关类及方法总结
- 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件
- MySQL存储过程详解
- 线程安全与可重入函数strtok_r()
- 20170810
- 8.11
- Metasploit的Docker安装及其Eternal Blue(永恒之蓝)渗透实现
- tensorflow预训练简单模型及权重文件复用初始化复杂模型
- [hihocoder1546]集合计数
- 2017.8.11
- 2017/8/11
- Action的创建与访问方式
- 919
- 20多条总结学完SymPy库
- Day13
- 拓扑排序 [HNOI2015]菜肴制作