Mxnet多任务(multi-task)训练
来源:互联网 发布:mba培训网络机构 华章 编辑:程序博客网 时间:2024/06/03 21:56
github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。
1、数据制作
需要自己生成*.lst文件,里面内容如下:
index task1标签 task2标签 task3标签 图片路径(这行是说明,不需要写入,每一列用\t隔开)2476 0.000000 0.000000 1.000000 photo_02_8159/00022552.jpg 7623 3.000000 2.000000 2.000000 photo_03_7397/00029434.jpg14149 0.000000 0.000000 1.000000 photo_05_15560/00060839.jpg6874 3.000000 1.000000 2.000000 photo_03_7397/00028414.jpg6048 0.000000 0.000000 1.000000 photo_02_8159/00027259.jpg14479 3.000000 3.000000 2.000000 photo_05_15560/00065068.jpg10429 2.000000 0.000000 1.000000 photo_04_15224/00040186.jpg6949 3.000000 0.000000 1.000000 photo_03_7397/00028521.jpg81 3.000000 3.000000 2.000000 photo_01_19992/00002536.jpg11725 2.000000 0.000000 1.000000 photo_05_15560/00051778.jpg1517 2.000000 3.000000 2.000000 photo_02_8159/00021245.jpg
具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):
fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别 fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别 fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别 #分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到 s1 = mx.symbol.SoftmaxOutput(data=fc1, name='softmax1') s2 = mx.symbol.SoftmaxOutput(data=fc2, name='softmax2') s3 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax3') return mx.symbol.Group([s1,s2,s3])
3、编写ImageRecordIter选项
train = mx.io.ImageRecordIter( path_imgrec='/path/to/train.rec', label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],#label名称,于softmax名称一样,后面要加入_label label_width=3, #重要,需要设置label宽度为3,因为有3个任务 data_shape=[3,224,224], batch_size=64 ) val = mx.io.ImageRecordIter( path_imgrec='/path/to/val.rec', label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'], label_width=3, batch_size=64, data_shape=[3,224,224], )
4、定义多任务训练迭代器
class MultiTask_iter(mx.io.DataIter): def __init__(self, data_iter): super(MultiTask_iter,self).__init__('multitask_iter') self.data_iter = data_iter self.batch_size = self.data_iter.batch_size @property def provide_data(self): return self.data_iter.provide_data @property def provide_label(self): provide_label = self.data_iter.provide_label[0] # the name of the label if corresponding to the model you define in get_fine_tune_model() function return [('softmax1_label', [provide_label[1][0]]),#需要注意的地方 ('softmax2_label', [provide_label[1][0]]), ('softmax3_label', [provide_label[1][0]])] def hard_reset(self): self.data_iter.hard_reset() def reset(self): self.data_iter.reset() def next(self): batch = self.data_iter.next() #需要注意的地方 label = batch.label[0] ll = label.asnumpy() label1 = mx.nd.array(ll[:,0]).astype('float32') label2 = mx.nd.array(ll[:,1]).astype('float32') label3 = mx.nd.array(ll[:,2]).astype('float32') # we set task 2 as: if label>0 or not return mx.io.DataBatch(data=batch.data, label=[label1,label2,label3], \ pad=batch.pad, index=batch.index)
5、定义正确率计算方法
class Multi_Accuracy(mx.metric.EvalMetric): """Calculate accuracies of multi label""" def __init__(self, num=None): super(Multi_Accuracy, self).__init__('multi-accuracy', num) def update(self, labels, preds): mx.metric.check_label_shapes(labels, preds) if self.num is not None: assert len(labels) == self.num for i in range(len(labels)): pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32') label = labels[i].asnumpy().astype('int32') mx.metric.check_label_shapes(label, pred_label) if i is None: self.sum_metric += (pred_label.flat == label.flat).sum() self.num_inst += len(pred_label.flat) else: self.sum_metric[i] += (pred_label.flat == label.flat).sum() self.num_inst[i] += len(pred_label.flat)
6、训练
train = MultiTask_iter(train)#调用多任务迭代器,其中train参数就是第3步的东西 val = MultiTask_iter(val) new_sym = get_symbol(10,50,image_shape) optimizer_params = { 'learning_rate': 0.001, 'momentum' : args.mom, 'wd' : args.wd, } initializer = mx.init.Xavier(factor_type="in", magnitude=2.34) model = mx.mod.Module( context = devs, symbol = new_sym, data_names=['data'], label_names=['softmax1_label','softmax2_label','softmax3_label'] ) saveroot = args.save_result+'/' + args.save_name checkpoint = mx.callback.do_checkpoint(saveroot) model.fit(train, begin_epoch=0, num_epoch=100000, eval_data=val, eval_metric=Multi_Accuracy(num=3),#需要注意的地方 optimizer='sgd', optimizer_params=optimizer_params, initializer=initializer, allow_missing=True, batch_end_callback=mx.callback.Speedometer(64, 50), epoch_end_callback=checkpoint )
阅读全文
0 0
- Mxnet多任务(multi-task)训练
- 多任务学习(Multi-task learning)
- Multi-task learning(多任务学习)简介
- 多任务学习(Multi-Task Learning, MTL)
- 多任务学习(Multi-task learning)-1
- 多任务学习方法( Multi-task learning )介绍
- 两个Multi-task learning(多任务学习)的代码
- 17.3.13 多任务学习 Multi-task learning
- MXNet 多rec参与训练
- 迁移学习(transfer learning)、多任务学习(multi-task learning)、深度学习(deep learning)概念摘抄
- 多任务学习“Facial Landmark Detection by Deep Multi-task Learning”
- 多任务学习“Fine-grained Recognition in the Wild: A Multi-Task Domain Adaptation Approach”
- 人脸属性多任务学习:Heterogeneous Face Attribute Estimation: A Deep Multi-Task Learning Approach
- 多任务学习“Rotating Your Face Using Multi-task Deep Neural Network”
- 多任务学习“Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics”
- Multi-task
- task多任务处理
- Task(任务)
- python @property,description,修饰符用法
- linux vi命令详解
- HTTP协议—— 简单认识TCP/IP协议
- ros kinect v2安装 编译错误及解决
- Token的简单解释
- Mxnet多任务(multi-task)训练
- LeetCode_21_Merge Two Sorted Lists
- 如何预防久坐伤身?
- 【Java集合源码剖析】Java集合框架
- 使用 webpack 优化资源
- poi的word文档结构介绍
- 【原】Android
- 集合案例---模拟斗地主发牌
- Python笔记1:开始