Simplifying MXNet data input and training


MXNet's data input is fairly cumbersome. This post wraps it so that the inputs are plain numpy arrays, simplifying it into a TensorFlow-style feed interface.
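Here is a minimal sketch of the intended call style (MXNetWrapper and build_network come from the full listing below; the random arrays are just stand-ins for real MNIST batches):

    import numpy as np
    import mxnet as mx

    x, y, net = build_network()
    mnw = MXNetWrapper(net=net)
    img_batch   = np.random.rand(200, 1, 28, 28).astype(np.float32)  # fake images
    label_batch = np.random.randint(0, 10, 200)                      # fake labels
    mnw.train(feed_inputs={x: img_batch},      # keys are mx.symbol Variables,
              feed_labels={y: label_batch},    # values are plain numpy arrays
              batch_size=100)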

The code is based on MNIST, so download the dataset first (see the earlier post on reading MNIST...), then edit the dir used by get_mnist() in the code below.
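Concretely, the only line that needs editing is the path pattern inside get_mnist(), shown here with the value from the listing:

    # get_mnist() builds paths from this pattern; change it to wherever
    # your four gzipped IDX files actually live:
    mnist_dir = '../../dataset/%s'
    # expected files under that directory:
    #   train-images-idx3-ubyte.gz   train-labels-idx1-ubyte.gz
    #   t10k-images-idx3-ubyte.gz    t10k-labels-idx1-ubyte.gz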

As an aside: most algorithms today are benchmarked directly on CIFAR and ImageNet. MNIST is no longer representative, and methods that score well on MNIST are often ineffective on the more realistic CIFAR and ImageNet data. MNIST is used here purely for convenience.

Here is the code:

#encoding=utf-8
import logging
logging.basicConfig(level=logging.INFO)
import mxnet as mx
import numpy as np
import gzip, struct
import time


class MXNetWrapper():
    def __init__(self, net, devices=[mx.cpu()]):
        '''
        net: mx.symbol.Symbol, the training network architecture
        devices: the devices to run the network on, mx.cpu(), mx.gpu(), etc.
        '''
        assert isinstance(net, mx.symbol.Symbol)
        self.net = net
        self.model = None
        self.infer_model = None
        self.devices = devices
        self.binded = False

    def _make_data(self, feed_inputs, feed_labels=None, batch_size=1):
        '''
        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
        k: mx.symbol.Symbol
        v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        batch_size: size of one batch

        return: mx.io.NDArrayIter
        '''
        # map the symbol keys to their names, as NDArrayIter expects strings
        data = {}
        for k in feed_inputs.keys():
            assert isinstance(k, mx.symbol.Symbol)
            data[k.name] = feed_inputs.get(k)

        label = None
        if feed_labels is not None:
            label = {}
            for k in feed_labels.keys():
                assert isinstance(k, mx.symbol.Symbol)
                label[k.name] = feed_labels.get(k)

        mxd = mx.io.NDArrayIter(data=data,
                                label=label,
                                batch_size=batch_size)
        return mxd

    def train(self,
              feed_inputs,
              feed_labels,
              batch_size=100,
              learning_rate=0.0321,
              optimizer='sgd',
              momentum=0.9,
              eval_metric=mx.metric.CrossEntropy(),
              params_initializer=mx.initializer.Xavier(),
              sync_params=True):
        '''
        train the network

        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
        k: mx.symbol.Symbol
        v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        batch_size: size of one batch
        learning_rate: the learning rate (step size) of the optimizer
        optimizer: how to optimize the network, e.g. 'sgd' (stochastic gradient descent)
        eval_metric: metric used to evaluate the network;
                     mx.metric.CrossEntropy(), mx.metric.Accuracy() or anything else in mx.metric.*
        sync_params: synchronize params across multiple devices after training
        '''
        mxd = None
        # lazily build, bind and initialize the module on the first call
        if not self.binded:
            self.bind_batch_size = batch_size
            self.learning_rate = learning_rate
            self.optimizer = optimizer
            self.momentum = momentum
            self.eval_metric = eval_metric
            self.params_initializer = params_initializer
            self.data_names = [k.name for k in feed_inputs.keys()]
            self.label_names = [k.name for k in feed_labels.keys()]
            mxd = self._make_data(feed_inputs, feed_labels, batch_size)
            self.model = mx.module.Module(symbol=self.net,
                                          context=self.devices,
                                          data_names=self.data_names,
                                          label_names=self.label_names)
            self.model.bind(data_shapes=mxd.provide_data, label_shapes=mxd.provide_label)
            self.model.init_params(initializer=self.params_initializer)
            optimizer_params = {'learning_rate': self.learning_rate, 'momentum': self.momentum}
            self.model.init_optimizer(optimizer=self.optimizer, optimizer_params=optimizer_params)
            self.binded = True

        if mxd is None:
            mxd = self._make_data(feed_inputs, feed_labels, batch_size)

        for data_batch in mxd:
            self.model.forward(data_batch, is_train=True)
            self.model.backward()
            self.model.update()
            self.model.update_metric(self.eval_metric, data_batch.label)

        for name, val in self.eval_metric.get_name_value():
            print 'train %s:%f' % (name, val)
        if sync_params:
            self.sync_params()

    def evaluate(self, feed_inputs, feed_labels, eval_metric=mx.metric.Accuracy()):
        '''
        evaluate the network

        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
        k: mx.symbol.Symbol
        v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        eval_metric: metric used to evaluate the network;
                     mx.metric.CrossEntropy(), mx.metric.Accuracy() or anything else in mx.metric.*
        '''
        mxd = self._make_data(feed_inputs, feed_labels, self.bind_batch_size)
        print self.model.score(eval_data=mxd, eval_metric=eval_metric)

    def inference(self, feed_inputs):
        '''
        feed_inputs: python dict {k: v, ...}
        k: mx.symbol.Symbol
        v: numpy.ndarray, shape (batch_size, x, y, z, ...)

        return: numpy.ndarray
        '''
        mxd = self._make_data(feed_inputs, batch_size=1)
        if self.infer_model is None:
            # a separate module bound with batch_size=1, so inference does not
            # have to match the batch shape the training module was bound to
            self.infer_model = mx.module.Module(symbol=self.net,
                                                context=self.devices,
                                                data_names=self.data_names,
                                                label_names=self.label_names)
            self.infer_model.bind(data_shapes=mxd.provide_data)
        # copy the latest trained weights into the inference module
        arg_params, aux_params = self.model.get_params()
        self.infer_model.set_params(arg_params, aux_params, allow_missing=False)
        return self.infer_model.predict(mxd).asnumpy()

    def sync_params(self):
        '''
        synchronize params across devices
        '''
        arg_params, aux_params = self.model.get_params()
        self.model.set_params(arg_params, aux_params)

    @staticmethod
    def format_to_mx_data(data_batch):
        '''
        move the color-channel axis of the data
        to satisfy mxnet's NCHW data format

        example: [batch_size,224,224,3] to [batch_size,3,224,224]

        data_batch: numpy.ndarray, shape (batch_size, x, y, z)
        '''
        assert len(np.shape(data_batch)) == 4, "only a 4-D batch needs formatting"
        data_batch = np.swapaxes(data_batch, 1, 3)
        data_batch = np.swapaxes(data_batch, 2, 3)
        return data_batch


def read_image(image_path):
    '''
    read mnist images
    '''
    with gzip.open(image_path, 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(num, rows, cols)
    image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32) / 255
    return image


def read_label(label_path):
    '''
    read mnist labels
    '''
    with gzip.open(label_path) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    return label


def get_mnist():
    mnist_dir = '../../dataset/%s'
    train_img   = read_image(mnist_dir % 'train-images-idx3-ubyte.gz')
    train_label = read_label(mnist_dir % 'train-labels-idx1-ubyte.gz')

    val_img   = read_image(mnist_dir % 't10k-images-idx3-ubyte.gz')
    val_label = read_label(mnist_dir % 't10k-labels-idx1-ubyte.gz')
    print 'load data finished'
    return train_img, train_label, val_img, val_label


def build_network():
    x = mx.sym.Variable('x')
    y = mx.sym.Variable('y')

    # first conv layer
    conv1 = mx.sym.Convolution(data=x, kernel=(4,4), num_filter=15)
    relu1 = mx.sym.Activation(data=conv1, act_type="relu")
    pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))

    # second conv layer
    conv2 = mx.sym.Convolution(data=pool1, kernel=(2,2), num_filter=15)
    relu2 = mx.sym.Activation(data=conv2, act_type="relu")
    pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))

    # flatten layer
    flat = mx.sym.Flatten(data=pool2)

    # fully connected
    #fc1   = mx.sym.FullyConnected(data=flat, num_hidden=100)
    #tanh1 = mx.sym.Activation(data=fc1, act_type="tanh")

    fc2 = mx.sym.FullyConnected(data=flat, num_hidden=10)
    # softmax loss
    net = mx.sym.SoftmaxOutput(data=fc2, label=y, name='softmax')

    return x, y, net


def train_cnn():
    batch_size = 101
    learning_rate = 0.111
    num_epochs = 10
    num_loop_data = 2000
    x, y, net = build_network()

    mnw = MXNetWrapper(net=net)

    train_img, train_label, val_img, val_label = get_mnist()

    loops = len(train_img) / num_loop_data
    datalabel = zip(train_img, train_label)
    for e in xrange(num_epochs):
        print 'epoch:', e
        np.random.shuffle(datalabel)
        dls = None
        # feed the wrapper num_loop_data samples at a time
        for i in xrange(loops):
            dls = datalabel[i*num_loop_data:(i+1)*num_loop_data]
            dls = map(np.array, zip(*dls))
            mnw.train(feed_inputs={x: dls[0]},
                      feed_labels={y: dls[1]},
                      batch_size=batch_size,
                      learning_rate=learning_rate,
                      eval_metric=mx.metric.Accuracy())

        mnw.evaluate(feed_inputs={x: dls[0]},
                     feed_labels={y: dls[1]},
                     eval_metric=mx.metric.MSE())

        result = mnw.inference(feed_inputs={x: val_img[:10]})
        for index, v in enumerate(result):
            print np.argmax(v), '--', val_label[index]


if __name__ == '__main__':
    train_cnn()
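One usage note, as a sketch: the devices argument of MXNetWrapper takes a list of contexts, so the same wrapper should run unchanged on a GPU (assuming a GPU build of MXNet is installed):

    x, y, net = build_network()
    mnw = MXNetWrapper(net=net, devices=[mx.gpu(0)])   # default is [mx.cpu()]
    # train() defaults to sync_params=True, which calls get_params()/set_params()
    # after each call to pull the updated weights back from the device(s)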

As noted above, each call involves converting the data, so pass train() a fairly large chunk of data at a time; that way less of the total time is spent on conversion. For inference, a second module has to be created and bound to the input data's shape, otherwise feeding data whose shape differs from the shape bound during training may raise an error (the code above already handles this, so it can be used as-is).
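Also shape-related: the listing includes a format_to_mx_data() helper for images that arrive channels-last. A minimal sketch of its use (the zero array is just a stand-in batch):

    batch_nhwc = np.zeros((8, 224, 224, 3), dtype=np.float32)  # [batch, H, W, C]
    batch_nchw = MXNetWrapper.format_to_mx_data(batch_nhwc)
    print batch_nchw.shape   # -> (8, 3, 224, 224), mxnet's [batch, C, H, W]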

