Mxnet多任务(multi-task)训练

来源:互联网 发布:mba培训网络机构 华章 编辑:程序博客网 时间:2024/06/03 21:56

github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。

1、数据制作
需要自己生成*.lst文件,里面内容如下:

index   task1标签   task2标签    task3标签    图片路径(这行是说明,不需要写入,每一列用\t隔开)2476    0.000000    0.000000    1.000000    photo_02_8159/00022552.jpg 7623    3.000000    2.000000    2.000000    photo_03_7397/00029434.jpg14149   0.000000    0.000000    1.000000    photo_05_15560/00060839.jpg6874    3.000000    1.000000    2.000000    photo_03_7397/00028414.jpg6048    0.000000    0.000000    1.000000    photo_02_8159/00027259.jpg14479   3.000000    3.000000    2.000000    photo_05_15560/00065068.jpg10429   2.000000    0.000000    1.000000    photo_04_15224/00040186.jpg6949    3.000000    0.000000    1.000000    photo_03_7397/00028521.jpg81      3.000000    3.000000    2.000000    photo_01_19992/00002536.jpg11725   2.000000    0.000000    1.000000    photo_05_15560/00051778.jpg1517    2.000000    3.000000    2.000000    photo_02_8159/00021245.jpg

具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):

    fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别    fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别    fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别    #分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到    s1 = mx.symbol.SoftmaxOutput(data=fc1, name='softmax1')     s2 = mx.symbol.SoftmaxOutput(data=fc2, name='softmax2')    s3 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax3')    return  mx.symbol.Group([s1,s2,s3])

3、编写ImageRecordIter选项

    train = mx.io.ImageRecordIter(        path_imgrec='/path/to/train.rec',        label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],#label名称,于softmax名称一样,后面要加入_label        label_width=3, #重要,需要设置label宽度为3,因为有3个任务        data_shape=[3,224,224],        batch_size=64    )    val = mx.io.ImageRecordIter(        path_imgrec='/path/to/val.rec',        label_name=['softmax1_label', 'softmax2_label', 'softmax3_label'],        label_width=3,        batch_size=64,        data_shape=[3,224,224],    )

4、定义多任务训练迭代器

class MultiTask_iter(mx.io.DataIter):    def __init__(self, data_iter):        super(MultiTask_iter,self).__init__('multitask_iter')        self.data_iter = data_iter        self.batch_size = self.data_iter.batch_size    @property    def provide_data(self):        return self.data_iter.provide_data    @property    def provide_label(self):        provide_label = self.data_iter.provide_label[0]        # the name of the label if corresponding to the model you define in get_fine_tune_model() function        return [('softmax1_label', [provide_label[1][0]]),#需要注意的地方        ('softmax2_label', [provide_label[1][0]]),        ('softmax3_label', [provide_label[1][0]])]    def hard_reset(self):        self.data_iter.hard_reset()    def reset(self):        self.data_iter.reset()    def next(self):        batch = self.data_iter.next()        #需要注意的地方        label = batch.label[0]        ll = label.asnumpy()        label1 = mx.nd.array(ll[:,0]).astype('float32')        label2 = mx.nd.array(ll[:,1]).astype('float32')        label3 = mx.nd.array(ll[:,2]).astype('float32')        # we set task 2 as: if label>0 or not        return mx.io.DataBatch(data=batch.data, label=[label1,label2,label3], \                pad=batch.pad, index=batch.index)

5、定义正确率计算方法

class Multi_Accuracy(mx.metric.EvalMetric):    """Calculate accuracies of multi label"""    def __init__(self, num=None):        super(Multi_Accuracy, self).__init__('multi-accuracy', num)    def update(self, labels, preds):        mx.metric.check_label_shapes(labels, preds)        if self.num is not None:            assert len(labels) == self.num        for i in range(len(labels)):            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')            label = labels[i].asnumpy().astype('int32')            mx.metric.check_label_shapes(label, pred_label)            if i is None:                self.sum_metric += (pred_label.flat == label.flat).sum()                self.num_inst += len(pred_label.flat)            else:                self.sum_metric[i] += (pred_label.flat == label.flat).sum()                self.num_inst[i] += len(pred_label.flat)

6、训练

    train = MultiTask_iter(train)#调用多任务迭代器,其中train参数就是第3步的东西    val = MultiTask_iter(val)    new_sym = get_symbol(10,50,image_shape)    optimizer_params = {            'learning_rate': 0.001,            'momentum' : args.mom,            'wd' : args.wd,           }    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)    model = mx.mod.Module(        context       = devs,        symbol        = new_sym,        data_names=['data'],        label_names=['softmax1_label','softmax2_label','softmax3_label']    )    saveroot = args.save_result+'/' + args.save_name    checkpoint = mx.callback.do_checkpoint(saveroot)    model.fit(train,              begin_epoch=0,              num_epoch=100000,              eval_data=val,              eval_metric=Multi_Accuracy(num=3),#需要注意的地方              optimizer='sgd',              optimizer_params=optimizer_params,              initializer=initializer,              allow_missing=True,              batch_end_callback=mx.callback.Speedometer(64, 50),              epoch_end_callback=checkpoint              )
原创粉丝点击