MXNet Image Classification (3): Fine-tuning

Source: Internet · Editor: 程序博客网 · Date: 2024/06/05 23:07
To train a model by fine-tuning, first download the corresponding pretrained model, then adjust the output layer to the number of classes in your own dataset, and finally train.

OS: Ubuntu 14.04
MXNet: 0.904

1. Data preparation

import mxnet as mx

train_iter = "/mxnet/tools/train-cat.rec"
val_iter = "/mxnet/tools/train-cat_test.rec"
batch_size = 10
num_epoch = 40

train_dataiter = mx.io.ImageRecordIter(
    path_imgrec=train_iter,
    # mean_img=datadir+"/mean.bin",
    rand_crop=True,
    rand_mirror=True,
    data_shape=(3, 224, 224),
    batch_size=batch_size,
    preprocess_threads=1)

test_dataiter = mx.io.ImageRecordIter(
    path_imgrec=val_iter,
    # mean_img=datadir+"/mean.bin",
    rand_crop=False,
    rand_mirror=False,
    data_shape=(3, 224, 224),
    batch_size=batch_size,
    preprocess_threads=1)
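Note that rand_crop and rand_mirror are enabled only for the training iterator: they augment each sample on the fly, while the test iterator must see undistorted images. As a rough pure-Python sketch of what the mirror step does per sample (a simplification for illustration, not MXNet's actual C++ implementation):

```python
import random

def rand_mirror(img, p=0.5, rng=random):
    """Horizontally flip a row-major image with probability p,
    mimicking the effect of ImageRecordIter's rand_mirror=True."""
    if rng.random() < p:
        return [row[::-1] for row in img]
    return img

img = [[1, 2, 3],
       [4, 5, 6]]
print(rand_mirror(img, p=1.0))  # always flips: [[3, 2, 1], [6, 5, 4]]
```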

2. Load the pretrained model

The pretrained model can be downloaded from the MXNet Model Zoo. Here we use VGG-16.

sym, arg_params, aux_params = mx.model.load_checkpoint('model/vgg16', 0)
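load_checkpoint(prefix, epoch) expects two files on disk derived from the prefix and epoch: the network graph as JSON and the weights as a .params file. A pure-Python sketch of that naming convention:

```python
def checkpoint_files(prefix, epoch):
    # MXNet stores a checkpoint as <prefix>-symbol.json (the network graph)
    # plus <prefix>-<epoch, zero-padded to 4 digits>.params (the weights).
    return prefix + '-symbol.json', '%s-%04d.params' % (prefix, epoch)

print(checkpoint_files('model/vgg16', 0))
# ('model/vgg16-symbol.json', 'model/vgg16-0000.params')
```

So the call above requires model/vgg16-symbol.json and model/vgg16-0000.params to exist.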

3. Modify the output classes

def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='drop7'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes in the fine-tune dataset
    layer_name: the name of the layer before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name + '_output']  # don't forget '_output'; in VGG-16 the layer above fc8 is drop7
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc8_new')  # replace the last layer under a new name
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the pretrained weights of the replaced output layer (fc8 in VGG-16)
    new_args = {k: arg_params[k] for k in arg_params if 'fc8' not in k}
    return (net, new_args)
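The filtering step in new_args exists because the replaced output layer has a different shape (ImageNet's 1000 classes vs. num_classes), so its pretrained weights must not be carried over. A toy illustration with a plain dict standing in for the real NDArray parameters (the names below are a hypothetical subset of VGG-16's):

```python
# Toy stand-in for VGG-16's arg_params; real values are mx.nd.NDArray weights.
arg_params = {
    'conv1_1_weight': 'w0',
    'fc7_weight': 'w1',
    'fc8_weight': 'w2',  # old 1000-way output layer: must be dropped
    'fc8_bias': 'b2',
}
new_args = {k: arg_params[k] for k in arg_params if 'fc8' not in k}
print(sorted(new_args))  # ['conv1_1_weight', 'fc7_weight']
```

Everything except the replaced layer keeps its pretrained values; the new fc8_new layer is then filled in by the initializer during training.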

4. Train the model

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.fit(train, val,
        num_epoch=8,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,
        batch_end_callback=mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.001},
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    mod.save_checkpoint('./vggnew', num_epoch)  # save the model
    metric = mx.metric.Accuracy()
    return mod.score(val, metric)
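allow_missing=True is what lets fine-tuning work here: the argument dict no longer contains weights for the new output layer, so that layer falls back to the Xavier initializer while every other layer starts from the checkpoint. A simplified sketch of that merge logic (an assumption for illustration, not MXNet's actual set_params code):

```python
def merge_params(model_param_names, pretrained, init_value=0.0, allow_missing=True):
    # Params present in the checkpoint are copied; missing ones (e.g. the
    # new output layer) are freshly initialized instead of raising an error.
    merged = {}
    for name in model_param_names:
        if name in pretrained:
            merged[name] = pretrained[name]
        elif allow_missing:
            merged[name] = init_value
        else:
            raise ValueError('missing parameter: ' + name)
    return merged

print(merge_params(['fc7_weight', 'fc8_new_weight'], {'fc7_weight': 'ckpt'}))
# {'fc7_weight': 'ckpt', 'fc8_new_weight': 0.0}
```

With allow_missing=False, the same call would fail on fc8_new_weight, which is the default behavior when loading a checkpoint into an unchanged network.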
num_classes = 2  # two classes
batch_per_gpu = 16
num_gpus = 1
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
b = mx.viz.plot_network(new_sym)  # visualize the network structure
b.view()
batch_size = batch_per_gpu * num_gpus
# (train, val) = get_iterators(batch_size)
mod_score = fit(new_sym, new_args, aux_params, train_dataiter, test_dataiter, batch_size, num_gpus)
assert mod_score > 0.77, "Low training accuracy."
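eval_metric='acc' is plain top-1 classification accuracy; conceptually it computes the following (a pure-Python sketch, not mx.metric.Accuracy's implementation):

```python
def accuracy(labels, preds):
    # Fraction of samples whose argmax over class scores matches the label.
    correct = sum(1 for label, p in zip(labels, preds)
                  if label == max(range(len(p)), key=p.__getitem__))
    return correct / len(labels)

labels = [0, 1, 1, 0]                                        # ground truth
preds = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]     # softmax outputs
print(accuracy(labels, preds))  # 0.75
```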


References:

[1]http://mxnet.io/how_to/finetune.html

For environment setup and dataset preparation, see:

  1. MXNet faster-rcnn environment setup
  2. MXNet Image Classification (1): Preparing the dataset

For testing, see:

  1. MXNet Image Classification (4): Testing with the trained model