fine-tuning 预训练的模型文件

来源:互联网 发布:vb九九乘法表左上三角 编辑:程序博客网 时间:2024/06/07 16:29

首先介绍tensorflow 版本的,当你有完整的训练好的tensorflow 模型时你的文件夹里边会出现四个文件 

1、checkpoint 文件,这个文件当中存放的时预训练好的模型地址

2、model.ckpt.meta 文件,这个文件当中存放的是你预训练好的模型的grah,解析这个文件你能得到当初保存模型时所存储的变量的名称和形状

3、model.ckpt.index 文件和model.ckpt.data-00000-of-00001 之类的文件(具体名称可能有出入)。之所以将这两个文件一起介绍

  是因为这两个文件在重载 checkpoint文件读取其中的参数值时 这两个文件缺一不可。

import tensorflow as tfnew_saver = tf.train.import_meta_graph("./model.ckpt-1000.meta")#此处的‘model.ckpt-1000.meta’对应特定的*.ckpt.meta文件print "ModelV construct"all_vars = tf.trainable_variables()config = tf.ConfigProto()config.gpu_options.allow_growth = Truewith tf.Session(config = config) as sess:    for v in all_vars:        print v.name, ':', v.shape, sess.run(tf.shape(v))#以下的部分则是依据从graph中解析出的变量对参数值进行重载    new_saver.restore(sess, tf.train.latest_checkpoint('./'))    for v in all_vars:        print v.name,v.eval(sess)
下面贴出session上半部分的print内容


依据以上的文件所打印的内容可以得知预训练模型所保存的变量名称

然后依据所得知的变量名,构建自己的新的graph,(保证自己新的graph中变量的名称和预训练模型的变量名称一致)

然后构建一个saver将新的graph中变量存到saver中,用这个saver 去restore预训练模型。

下面贴出一个根据上边的预训练模型中的参数给新的graph的部分变量进行初始化的code

import tensorflow as tflist_restore = []with tf.name_scope('conv1'):    v1 = tf.Variable(tf.random_normal([16]), name="biases")    list_restore.append(v1)with tf.name_scope('conv2'):    v1 = tf.Variable(tf.random_normal([16]), name = 'biases')    list_restore.append(v1)with tf.name_scope('softmax_linear'):    v1 = tf.Variable(tf.random_normal([2]), name = 'biases')init = tf.global_variables_initializer()all_vars = tf.trainable_variables()new_saver = tf.train.Saver(list_restore)config = tf.ConfigProto()config.gpu_options.allow_growth = Truewith tf.Session(config = config) as sess:    sess.run(init)    new_saver.restore(sess, tf.train.latest_checkpoint('./'))        for v in list_restore:        print v.name,v.eval(sess)    print('*'*18)        for var in all_vars:        print var.name, var.shape, var.eval(sess)

我新构建的graph中前两个v1变量的name与预训练模型保存参数name相同,并且我将这两个变量放在一个list中构建saver。

再用这个saver去restore预训练模型;注意,最后一个name为“softmax_linear/biases:0”的变量v1虽然与预训练模型的变量名(name)一样

但是我并没有将其加入saver中,所以最终我得到的三个变量是前两个v1初始化是由预训练模型的参数对其进行初始化,

而最后一个v1变量初始化是由其变量原先构建时指定的初始化方式进行初始化


********************************************分割线*********************************************************************************

下面介绍如何对预训练好的caffemodel文件进行新网络的初始化

直接附一个github抓的代码#来自于Drsleep先生的deeplabv1版本的tensorflow实现


"""Extract parameters of the DeepLab-LargeFOV model   from the provided .caffemodel file.   This scripts extracts and saves the network skeleton with names and shape of the parameters, as well as all the corresponding weights.To run the script, PyCaffe should be installed."""from __future__ import print_functionimport argparseimport osimport sysfrom six.moves import cPickledef get_arguments():    """Parse all the arguments provided from the CLI.        Returns:      A list of parsed arguments.    """    parser = argparse.ArgumentParser(description="Extract model parameters of DeepLab-LargeFOV from the provided .caffemodel.")    parser.add_argument("--caffemodel", type=str, default = "train2_iter_8000.caffemodel",                        help="Caffemodel from which the parameters will be extracted.")    parser.add_argument("--output_dir", type=str, default="/home/yanyu/deeplab-lfov/util/",                        help="Whether to store the network skeleton and weights.")    parser.add_argument("--pycaffe_path", type=str, default="/home/yanyu/caffe/python",                        help="Path to PyCaffe (e.g., 'CAFFE_ROOT/python').")    return parser.parse_args()def main():    """Extract and save network skeleton with the corresponding weights.        Raises:      ImportError: PyCaffe module is not found."""    args = get_arguments()    sys.path.append(args.pycaffe_path)    try:        import caffe    except ImportError:        raise    # Load net definition.    net = caffe.Net('deploy.prototxt', args.caffemodel, caffe.TEST) #关键在这一句,对于 caffemodel进行解析, 此处的*.prototxt文件是你自己新构建的网络        # Check the existence of output_dir.    if not os.path.exists(args.output_dir):        os.makedirs(args.output_dir)        # Net skeleton with parameters names and shapes.    # In TF, the filter shape is as follows: [ks, ks, input_channels, output_channels],    # while in Caffe it looks like this: [output_channels, input_channels, ks, ks].    net_skeleton = list()     for name, item in net.params.iteritems():        net_skeleton.append([name + '/w', item[0].data.shape[::-1]]) # See the explanataion on filter formats above.        net_skeleton.append([name + '/b', item[1].data.shape])        with open(os.path.join(args.output_dir, 'net_skeleton_yen.ckpt'), 'wb') as f:        cPickle.dump(net_skeleton, f, protocol=cPickle.HIGHEST_PROTOCOL)        # Net weights.     net_weights = dict()    for name, item in net.params.iteritems():        net_weights[name + '/w'] = item[0].data.transpose(2, 3, 1, 0) # See the explanation on filter formats above.        net_weights[name + '/b'] = item[1].data    with open(os.path.join(args.output_dir,'net_weights_yen.ckpt'), 'wb') as f:        cPickle.dump(net_weights, f, protocol=cPickle.HIGHEST_PROTOCOL)    del net, net_skeleton, net_weightsif __name__ == '__main__':    main()

这个代码直接将一个新的网络用预训练好的caffemodel文件对其进行初始化,而返回的两个文件一个是list<保存的是参数的name和shape>保存文件名为‘net_skeleton_yen.ckpt’

另外一个是dirt<保存的时name 和相应的参数值> 保存文件名为‘net_weights_yen.ckpt’

需要注意的是,你新构建的prototxt文件的layer中的变量(weigts 和biases)的name 只有和生成caffemodel文件的所定义的一致时














原创粉丝点击