fine-tuning 预训练的模型文件
来源:互联网 发布:vb九九乘法表左上三角 编辑:程序博客网 时间:2024/06/07 16:29
首先介绍tensorflow 版本的,当你有完整的训练好的tensorflow 模型时你的文件夹里边会出现四个文件
1、checkpoint 文件,这个文件当中存放的时预训练好的模型地址
2、model.ckpt.meta 文件,这个文件当中存放的是你预训练好的模型的grah,解析这个文件你能得到当初保存模型时所存储的变量的名称和形状
3、model.ckpt.index 文件和model.ckpt.data-00000-of-00001 之类的文件(具体名称可能有出入)。之所以将这两个文件一起介绍
是因为这两个文件在重载 checkpoint文件读取其中的参数值时 这两个文件缺一不可。
import tensorflow as tfnew_saver = tf.train.import_meta_graph("./model.ckpt-1000.meta")#此处的‘model.ckpt-1000.meta’对应特定的*.ckpt.meta文件print "ModelV construct"all_vars = tf.trainable_variables()config = tf.ConfigProto()config.gpu_options.allow_growth = Truewith tf.Session(config = config) as sess: for v in all_vars: print v.name, ':', v.shape, sess.run(tf.shape(v))#以下的部分则是依据从graph中解析出的变量对参数值进行重载 new_saver.restore(sess, tf.train.latest_checkpoint('./')) for v in all_vars: print v.name,v.eval(sess)下面贴出session上半部分的print内容
依据以上的文件所打印的内容可以得知预训练模型所保存的变量名称
然后依据所得知的变量名,构建自己的新的graph,(保证自己新的graph中变量的名称和预训练模型的变量名称一致)
然后构建一个saver将新的graph中变量存到saver中,用这个saver 去restore预训练模型。
下面贴出一个根据上边的预训练模型中的参数给新的graph的部分变量进行初始化的code
import tensorflow as tflist_restore = []with tf.name_scope('conv1'): v1 = tf.Variable(tf.random_normal([16]), name="biases") list_restore.append(v1)with tf.name_scope('conv2'): v1 = tf.Variable(tf.random_normal([16]), name = 'biases') list_restore.append(v1)with tf.name_scope('softmax_linear'): v1 = tf.Variable(tf.random_normal([2]), name = 'biases')init = tf.global_variables_initializer()all_vars = tf.trainable_variables()new_saver = tf.train.Saver(list_restore)config = tf.ConfigProto()config.gpu_options.allow_growth = Truewith tf.Session(config = config) as sess: sess.run(init) new_saver.restore(sess, tf.train.latest_checkpoint('./')) for v in list_restore: print v.name,v.eval(sess) print('*'*18) for var in all_vars: print var.name, var.shape, var.eval(sess)
我新构建的graph中前两个v1变量的name与预训练模型保存参数name相同,并且我将这两个变量放在一个list中构建saver。
再用这个saver去restore预训练模型;注意,最后一个name为“softmax_linear/biases:0”的变量v1虽然与预训练模型的变量名(name)一样
但是我并没有将其加入saver中,所以最终我得到的三个变量是前两个v1初始化是由预训练模型的参数对其进行初始化,
而最后一个v1变量初始化是由其变量原先构建时指定的初始化方式进行初始化
********************************************分割线*********************************************************************************
下面介绍如何对预训练好的caffemodel文件进行新网络的初始化
直接附一个github抓的代码#来自于Drsleep先生的deeplabv1版本的tensorflow实现
"""Extract parameters of the DeepLab-LargeFOV model from the provided .caffemodel file. This scripts extracts and saves the network skeleton with names and shape of the parameters, as well as all the corresponding weights.To run the script, PyCaffe should be installed."""from __future__ import print_functionimport argparseimport osimport sysfrom six.moves import cPickledef get_arguments(): """Parse all the arguments provided from the CLI. Returns: A list of parsed arguments. """ parser = argparse.ArgumentParser(description="Extract model parameters of DeepLab-LargeFOV from the provided .caffemodel.") parser.add_argument("--caffemodel", type=str, default = "train2_iter_8000.caffemodel", help="Caffemodel from which the parameters will be extracted.") parser.add_argument("--output_dir", type=str, default="/home/yanyu/deeplab-lfov/util/", help="Whether to store the network skeleton and weights.") parser.add_argument("--pycaffe_path", type=str, default="/home/yanyu/caffe/python", help="Path to PyCaffe (e.g., 'CAFFE_ROOT/python').") return parser.parse_args()def main(): """Extract and save network skeleton with the corresponding weights. Raises: ImportError: PyCaffe module is not found.""" args = get_arguments() sys.path.append(args.pycaffe_path) try: import caffe except ImportError: raise # Load net definition. net = caffe.Net('deploy.prototxt', args.caffemodel, caffe.TEST) #关键在这一句,对于 caffemodel进行解析, 此处的*.prototxt文件是你自己新构建的网络 # Check the existence of output_dir. if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) # Net skeleton with parameters names and shapes. # In TF, the filter shape is as follows: [ks, ks, input_channels, output_channels], # while in Caffe it looks like this: [output_channels, input_channels, ks, ks]. net_skeleton = list() for name, item in net.params.iteritems(): net_skeleton.append([name + '/w', item[0].data.shape[::-1]]) # See the explanataion on filter formats above. net_skeleton.append([name + '/b', item[1].data.shape]) with open(os.path.join(args.output_dir, 'net_skeleton_yen.ckpt'), 'wb') as f: cPickle.dump(net_skeleton, f, protocol=cPickle.HIGHEST_PROTOCOL) # Net weights. net_weights = dict() for name, item in net.params.iteritems(): net_weights[name + '/w'] = item[0].data.transpose(2, 3, 1, 0) # See the explanation on filter formats above. net_weights[name + '/b'] = item[1].data with open(os.path.join(args.output_dir,'net_weights_yen.ckpt'), 'wb') as f: cPickle.dump(net_weights, f, protocol=cPickle.HIGHEST_PROTOCOL) del net, net_skeleton, net_weightsif __name__ == '__main__': main()
这个代码直接将一个新的网络用预训练好的caffemodel文件对其进行初始化,而返回的两个文件一个是list<保存的是参数的name和shape>保存文件名为‘net_skeleton_yen.ckpt’
另外一个是dirt<保存的时name 和相应的参数值> 保存文件名为‘net_weights_yen.ckpt’
需要注意的是,你新构建的prototxt文件的layer中的变量(weigts 和biases)的name 只有和生成caffemodel文件的所定义的一致时
- fine-tuning 预训练的模型文件
- fine-tuning:利用已有模型训练其他数据集
- fine-tuning:利用已有模型训练其他数据集
- fine-tuning:利用已有模型训练其他数据集
- fine-tuning的二三事
- 关于fine-tuning:利用已有模型训练其他数据集
- Windows下caffe用fine-tuning训练好的caffemodel来进行图像分类
- fine-tuning
- pytorch学习笔记(十一):fine-tune 预训练的模型
- Keras入门-预训练模型fine-tune(ResNet)
- 迁移学习与fine-tuning的区别
- caffe中fine-tuning的那些事
- Tensorflow Fine-Tuning 的一些说明
- MXNet的预训练:fine-tune.py源码详解
- Caffe fine-tuning 学习
- YOLOv2如何fine-tuning?
- 深度学习fine-tuning
- YOLOv2如何fine-tuning?
- springboot跑批
- mysql更换版本后,数据data文件夹导入
- Printf函数和cout函数参数执行顺序以及自增与自减
- spring-boot配置(一):@Configuration,@ConfigurationProperties和application.yml
- A4纸网页打印
- fine-tuning 预训练的模型文件
- html中用div代替textarea实现输入框高度随输入内容变化
- 适用于 Windows 的自定义脚本扩展
- 数据结构基础JAVA 实现表、栈和队列
- hdu6239 Interview 期望+拉格朗日插值法|生成函数 推公式
- sass或scss文件如何避免被编译?
- js根据时间戳换算过去间隔
- 桥接模式
- jQuery教程 2 语法