transplant in FCN

Source: Internet. Edited by: 程序博客网. Date: 2024/06/08 23:22

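FCN's surgery.transplant function copies learnable parameters from one net to another. Its immediate purpose is to copy the parameters of the fully connected layers in the VGG classification model correctly into the corresponding layers of the target net, which in FCN are the convolutional counterparts of those layers (same values, different shape). The code is as follows: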

from __future__ import print_function  # lets the print calls below run on Python 2 and 3


def transplant(new_net, net, suffix=''):
    """
    Transfer weights by copying matching parameters, coercing parameters of
    incompatible shape, and dropping unmatched parameters.

    The coercion is useful to convert fully connected layers to their
    equivalent convolutional layers, since the weights are the same and only
    the shapes are different.  In particular, equivalent fully connected and
    convolution layers have shapes O x I and O x I x H x W respectively for O
    output channels, I input channels, H kernel height, and W kernel width.

    Both `net` and `new_net` arguments must be instantiated `caffe.Net`s.
    """
    for p in net.params:
        p_new = p + suffix
        # Layer missing from the target net: nothing to copy into.
        if p_new not in new_net.params:
            print('dropping', p)
            continue
        for i in range(len(net.params[p])):
            # Target layer has fewer blobs (e.g. no bias): skip the rest.
            if i > (len(new_net.params[p_new]) - 1):
                print('dropping', p, i)
                break
            if net.params[p][i].data.shape != new_net.params[p_new][i].data.shape:
                print('coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p_new][i].data.shape)
            else:
                print('copying', p, ' -> ', p_new, i)
            # Element-wise copy through the flat iterators, independent of shape.
            new_net.params[p_new][i].data.flat = net.params[p][i].data.flat
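To follow the copy / coerce / drop decisions without a Caffe installation, here is a minimal sketch that replays the same branching on plain dictionaries. The helper name `plan_transplant`, the layer names, and the shapes are all illustrative assumptions (the shapes are meant to resemble a VGG-16 style fc6 and a convolutionalized counterpart). It only reports what transplant would decide and copies no data.

```python
def plan_transplant(new_shapes, old_shapes, suffix=''):
    """Return the (action, layer, blob_index) decisions transplant would make.

    Both arguments are hypothetical stand-ins for `net.params`: dicts that map
    a layer name to a list of blob shapes (weights first, then bias).
    """
    plan = []
    for name, blobs in old_shapes.items():
        new_name = name + suffix
        if new_name not in new_shapes:
            # No layer of that name in the target net.
            plan.append(('dropping', name, None))
            continue
        for i, shape in enumerate(blobs):
            if i > len(new_shapes[new_name]) - 1:
                # The target layer has fewer blobs than the source layer.
                plan.append(('dropping', name, i))
                break
            if shape != new_shapes[new_name][i]:
                # Same values, different shape (e.g. FC -> convolution).
                plan.append(('coercing', name, i))
            else:
                plan.append(('copying', name, i))
    return plan


# Assumed source: a classifier with an inner-product fc6 and an fc8 head.
classifier = {
    'conv1_1': [(64, 3, 3, 3), (64,)],
    'fc6': [(4096, 25088), (4096,)],
    'fc8': [(1000, 4096), (1000,)],
}
# Assumed target: fc6 is a 7x7 convolution over 512 channels, and fc8 is absent.
fcn = {
    'conv1_1': [(64, 3, 3, 3), (64,)],
    'fc6': [(4096, 512, 7, 7), (4096,)],
}

for step in plan_transplant(fcn, classifier):
    print(step)
```

In this toy setup the fc6 weights are the only blob reported as coerced: 25088 equals 512 x 7 x 7, so the two shapes hold the same number of values.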

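ndarray.flat returns a flatiter object, a 1-D iterator over the array's elements in row-major (C) order. Assigning to it writes the values element by element, so the last line of transplant copies the data regardless of the two blobs' shapes. The copy is exact as long as both blobs hold the same number of elements, which is the case for a fully connected weight matrix and its convolutional counterpart.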
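The effect of the `.flat` assignment can be checked with NumPy alone. The sketch below uses small made-up sizes (O=2 outputs, I=3 channels, H=W=2) rather than real VGG weights. It copies a fully connected weight matrix into an O x I x H x W array, then confirms that applying the resulting kernel to one full-size input window gives the same values as the original matrix product.

```python
import numpy as np

O, I, H, W = 2, 3, 2, 2  # toy sizes, chosen only for illustration

# Fully connected weights: one row of I*H*W coefficients per output.
fc_weights = np.arange(O * I * H * W, dtype=np.float32).reshape(O, I * H * W)

# Destination with the convolutional shape, initially zero.
conv_weights = np.zeros((O, I, H, W), dtype=np.float32)

# Same kind of assignment as the last line of transplant: elements are
# written in row-major order, so the shapes may differ when the sizes agree.
conv_weights.flat = fc_weights.flat

# One input patch of shape I x H x W.
patch = np.linspace(0.0, 1.0, I * H * W, dtype=np.float32).reshape(I, H, W)

# Inner-product layer: weights times the flattened patch.
fc_out = fc_weights.dot(patch.ravel())
# Convolution kernel applied to a single window: multiply and sum per output.
conv_out = (conv_weights * patch).sum(axis=(1, 2, 3))

print(conv_weights.shape)
print(np.allclose(fc_out, conv_out))
```

The two outputs agree because both the flat copy and the flattening of the input patch follow the same row-major element order.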
surgery.transplant is called in solve.py as surgery.transplant(solver.net, vgg_net):

import sys
sys.path.append('/home/my/caffe-master/caffe-master/python')

import caffe
import surgery, score

import numpy as np
import os

# Optional: name the process after the working directory.
try:
    import setproctitle
    setproctitle.setproctitle(os.path.basename(os.getcwd()))
except ImportError:
    pass

vgg_weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'
vgg_proto = '../ilsvrc-nets/VGG_ILSVRC_16_layers_deploy.prototxt'
weights = '../ilsvrc-nets/vgg16-fcn.caffemodel'

# init
caffe.set_mode_gpu()
# caffe.set_device(int(sys.argv[0]))
caffe.set_device(7)

#solver = caffe.SGDSolver('solver.prototxt')
#solver.net.copy_from(weights)
solver = caffe.SGDSolver('solver.prototxt')
# Load the VGG classifier, transplant its weights into the FCN, then free it.
vgg_net = caffe.Net(vgg_proto, vgg_weights, caffe.TRAIN)
surgery.transplant(solver.net, vgg_net)
del vgg_net

# surgeries
interp_layers = [k for k in solver.net.params.keys() if 'up' in k]
surgery.interp(solver.net, interp_layers)

# scoring
test = np.loadtxt('../data/sift-flow/test.txt', dtype=str)

for _ in range(50):
    solver.step(2000)
    # N.B. metrics on the semantic labels are off b.c. of missing classes;
    # score manually from the histogram instead for proper evaluation
    score.seg_tests(solver, False, test, layer='score_sem', gt='sem')
    score.seg_tests(solver, False, test, layer='score_geo', gt='geo')