Mxnet图片分类(3)fine-tune
来源:互联网 发布:淘宝购物省钱技巧 编辑:程序博客网 时间:2024/06/05 23:07
使用fine-tun的方式训练模型的话首先需要下载相应的模型,然后按照自己的数据集修改相应的类别,最后训练。
系统: ubuntu14.04
Mxnet: 0.904
1.数据准备
train_iter = "/mxnet/tools/train-cat.rec"val_iter = "/mxnet/tools/train-cat_test.rec"batch_size=10num_epoch = 40train_dataiter = mx.io.ImageRecordIter( path_imgrec=train_iter, #mean_img=datadir+"/mean.bin", rand_crop=True, rand_mirror=True, data_shape=(3,224,224), batch_size=batch_size, preprocess_threads=1)test_dataiter = mx.io.ImageRecordIter( path_imgrec=val_iter, #mean_img=datadir+"/mean.bin", rand_crop=False, rand_mirror=False, data_shape=(3,224,224), batch_size=batch_size, preprocess_threads=1)
2.加载fine-tune模型
模型可以通过Mxnet Model Zoo下载。这里下载的是vgg16
sym,arg_params,aux_params=mx.model.load_checkpoint('model/vgg16',0)
3.修改类别
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='drop7'): """ symbol: the pretrained network symbol arg_params: the argument parameters of the pretrained model num_classes: the number of classes for the fine-tune datasets layer_name: the layer name before the last fully-connected layer """ all_layers = symbol.get_internals() net = all_layers[layer_name+'_output']#不要忘了'_output',vgg16的fc8的上一层是drop7 net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc8_new')#fine-tune修改修改最后一层名字 net = mx.symbol.SoftmaxOutput(data=net, name='softmax') new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k}) return (net, new_args)
4.训练模型
import logginghead = '%(asctime)-15s %(message)s'logging.basicConfig(level=logging.DEBUG, format=head)def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus): devs = [mx.gpu(i) for i in range(num_gpus)] mod = mx.mod.Module(symbol=symbol, context=devs) mod.fit(train, val, num_epoch=8, arg_params=arg_params, aux_params=aux_params, allow_missing=True, batch_end_callback = mx.callback.Speedometer(batch_size, 10), kvstore='device', optimizer='sgd', optimizer_params={'learning_rate':0.001}, initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2), eval_metric='acc') mod.save_checkpoint('./vggnew',num_epoch)#保存模型 metric = mx.metric.Accuracy() return mod.score(val, metric)
num_classes = 2 #2类batch_per_gpu = 16num_gpus = 1(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)b = mx.viz.plot_network(new_sym)#可视化网络结构b.view()batch_size = batch_per_gpu * num_gpus#(train, val) = get_iterators(batch_size)mod_score = fit(new_sym, new_args, aux_params, train_dataiter, test_dataiter, batch_size, num_gpus)assert mod_score > 0.77, "Low training accuracy."
参考文献:
[1]http://mxnet.io/how_to/finetune.html
环境的安装和数据集的制作可以参考
- Mxnet—faster-rcnn环境安装
- Mxnet图片分类(1)准备数据集
测试可以参考:
- Mxnet图片分类(4)利用训练好的模型进行测试
阅读全文
1 0
- Mxnet图片分类(3)fine-tune
- MXNet的预训练:fine-tune.py源码详解
- 使用caffe fine-tune一个单标签图像分类模型
- 使用caffe fine-tune一个单标签图像分类模型
- 使用caffe fine-tune一个单标签图像分类模型
- fine-tune convolutional network
- caffe fine-tune策略
- digits fine-tune方法
- Caffe之fine tune
- tensorflow & keras fine tune
- caffe 学习之 Fine-tune
- CNN训练之fine tune
- Mxnet图片分类(1)准备数据集
- Mxnet图片分类(2)训练模型
- keras面向小数据集的图像分类(VGG-16基础上fine-tune)实现(附代码)
- caffe— 使用模型进行fine tune
- yolo源码解析及fine-tune
- caffe深度学习(一)fine-tune
- Liunx基础命令(2)
- C++封装POSIX 线程库(六)线程池
- 我所理解的依赖注入IOC
- Linux_Nginx_Tomcat 安装笔记
- 使用certbot续期ssl证书renew时遇到问题
- Mxnet图片分类(3)fine-tune
- Kotlin 官方学习教程之返回和跳转
- 标准IO与文件IO 的区别
- 栈的实现
- Android 获取栈最顶层Activity和Application Context解决方案
- 网络流&费用流模版
- PHPCMS使用教程:设置站点信息
- 《iOS开发笔记—自定义UIAlertController》
- cdnjs [记录]