如何编写训练和测试的 prototxt 配置文件——以 ResNet 为例

来源:互联网 发布:ai 软件下载 编辑:程序博客网 时间:2024/05/17 04:06
"""Generate train/test prototxt files for a CIFAR-10 ResNet with Caffe NetSpec."""
import os

# The working directory is switched before the remaining imports run, and
# this_dir (used for every data and output path below) is derived from it.
os.chdir('/home/wuwl/ResNet')

# NOTE(review): init_path and tools look like project-local modules that are
# resolved relative to the directory set above -- confirm before reordering.
import init_path
import caffe
import numpy as np
import tools
from caffe import layers as L, params as P, to_proto

this_dir = os.path.abspath(".")


def ResNet(split):
    """Build the network definition for one split.

    Args:
        split: 'train' selects the training LMDB, enables mirroring and makes
            BatchNorm use per-batch statistics; any other value selects the
            test LMDB without mirroring and with global statistics.

    Returns:
        The result of ``to_proto(acc, loss)``; ``str()`` of it is what
        ``make_net`` writes to the prototxt file.
    """
    cifar_dir = this_dir + '/caffe-master/examples/cifar10'
    mean_file = cifar_dir + '/mean.binaryproto'

    # Both splits subtract the mean image and crop to 28x28; only the
    # training split additionally mirrors its inputs.
    transform_param = dict(mean_file=mean_file, crop_size=28)
    if split == 'train':
        source = cifar_dir + '/cifar10_train_lmdb'
        transform_param['mirror'] = True
    else:
        source = cifar_dir + '/cifar10_test_lmdb'

    # ntop=2 makes the Data layer return two tops: images and labels.
    data, labels = L.Data(source=source,
                          backend=P.Data.LMDB,
                          batch_size=128,
                          ntop=2,
                          transform_param=transform_param)

    # Stem: a single 3x3 conv unit with 16 outputs; only its ReLU top is used.
    _, result = conv_BN_scale_relu(split, data, nout=16, ks=3, stride=1, pad=1)

    # Three stages of `repeat` residual blocks with 16, 32 and 64 outputs.
    # The first block of the 32- and 64-output stages gets
    # projection_stride=2, which makes ResNet_block put a strided 1x1 conv on
    # the shortcut; every other block keeps the identity shortcut.
    repeat = 3
    for stage, nout in enumerate((16, 32, 64)):
        for block in range(repeat):
            if stage > 0 and block == 0:
                projection_stride = 2
            else:
                projection_stride = 1
            result = ResNet_block(split, result, nout=nout, ks=3, stride=1,
                                  projection_stride=projection_stride, pad=1)

    pool = L.Pooling(result, pool=P.Pooling.AVE, global_pooling=True)
    IP = L.InnerProduct(pool, num_output=10,
                        weight_filler=dict(type='xavier'),
                        bias_filler=dict(type='constant'))
    acc = L.Accuracy(IP, labels)
    loss = L.SoftmaxWithLoss(IP, labels)
    return to_proto(acc, loss)


def conv_BN_scale_relu(split, bottom, nout, ks, stride, pad):
    """Stack Convolution -> BatchNorm -> Scale -> ReLU on top of `bottom`.

    Args:
        split: 'train' sets use_global_stats=False on the BatchNorm layer;
            any other value sets it to True.
        bottom: top of the preceding layer.
        nout: number of convolution outputs.
        ks: convolution kernel size.
        stride: convolution stride.
        pad: convolution padding.

    Returns:
        A ``(scale, relu)`` tuple: the Scale layer's top and the ReLU layer's
        top. ResNet_block feeds the Scale top into its element-wise sum.
    """
    conv = L.Convolution(bottom, kernel_size=ks, stride=stride,
                         num_output=nout, pad=pad, bias_term=True,
                         weight_filler=dict(type='xavier'),
                         bias_filler=dict(type='constant'))

    use_global_stats = split != "train"

    # All three BatchNorm param entries have lr_mult=0 and decay_mult=0.
    frozen_params = [dict(lr_mult=0, decay_mult=0) for _ in range(3)]
    BN = L.BatchNorm(conv,
                     batch_norm_param=dict(use_global_stats=use_global_stats),
                     in_place=True,
                     param=frozen_params)
    scale = L.Scale(BN, scale_param=dict(bias_term=True), in_place=True)
    relu = L.ReLU(scale, in_place=True)
    return scale, relu


def ResNet_block(split, bottom, nout, ks, stride, projection_stride, pad):
    """Build one residual block: two conv units summed with a shortcut.

    Args:
        split: forwarded to conv_BN_scale_relu.
        bottom: block input.
        nout: number of outputs of every convolution in the block.
        ks: kernel size of the two main-branch convolutions.
        stride: stride of the second main-branch convolution.
        projection_stride: stride of the first main-branch convolution. When
            it is not 1, the shortcut becomes a 1x1 conv unit with the same
            stride instead of the identity.
        pad: padding of the two main-branch convolutions.

    Returns:
        The top of the ReLU applied to the element-wise sum.
    """
    if projection_stride == 1:
        shortcut = bottom
    else:
        # 1x1 conv unit with pad 0; its Scale top (not the ReLU top) is summed.
        shortcut, _ = conv_BN_scale_relu(split, bottom, nout, 1,
                                         projection_stride, 0)

    _, relu1 = conv_BN_scale_relu(split, bottom, nout, ks,
                                  projection_stride, pad)
    scale2, _ = conv_BN_scale_relu(split, relu1, nout, ks, stride, pad)

    wise = L.Eltwise(scale2, shortcut, operation=P.Eltwise.SUM)
    wise_relu = L.ReLU(wise, in_place=True)
    return wise_relu


def make_net():
    """Write train.prototxt and test.prototxt under this_dir/res_net_model."""
    out_dir = this_dir + '/res_net_model'
    # open(..., 'w') does not create missing parent directories.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    for split in ('train', 'test'):
        # Build the definition first so a failure does not leave a truncated
        # file behind.
        net_text = str(ResNet(split))
        with open(out_dir + '/' + split + '.prototxt', 'w') as f:
            f.write(net_text)


if __name__ == '__main__':
    make_net()

1 0