Simplifying MXNet Data Input and Training
MXNet's data input pipeline is fairly cumbersome. This article wraps it so that plain numpy arrays can be fed in directly, simplifying the interface into something like TensorFlow's feed-dict style.
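In use, the wrapper ends up looking roughly like this (a minimal sketch; build_network, MXNetWrapper, x, y and net are all defined in the code below, while img_array and label_array are hypothetical stand-ins for your own numpy data):

x, y, net = build_network()
mnw = MXNetWrapper(net=net)
# img_array: numpy array of inputs; label_array: numpy array of labels
mnw.train(feed_inputs={x: img_array}, feed_labels={y: label_array}, batch_size=100)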
The code is based on MNIST, so download the dataset first (see the article "mnist读取..."), then change the directory used by get_mnist() in the code below.
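Concretely, get_mnist() below expects the four gzipped IDX files under a single directory (the path shown is the code's default mnist_dir; point it at wherever you put the downloads):

../../dataset/train-images-idx3-ubyte.gz
../../dataset/train-labels-idx1-ubyte.gz
../../dataset/t10k-images-idx3-ubyte.gz
../../dataset/t10k-labels-idx1-ubyte.gz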
As an aside, most algorithms these days are evaluated directly on CIFAR and ImageNet; MNIST is no longer representative, and algorithms that do well on MNIST often fail on the more realistic CIFAR and ImageNet data. MNIST is used here purely for convenience.
Here is the code:
#encoding=utf-8
import logging
logging.basicConfig(level=logging.INFO)
import mxnet as mx
import numpy as np
import gzip, struct
import time

class MXNetWrapper():
    def __init__(self, net, devices=[mx.cpu()]):
        '''
        net: mx.symbol.Symbol, the training network architecture
        devices: the devices to run the network on, mx.cpu() or mx.gpu(), etc.
        '''
        assert isinstance(net, mx.symbol.Symbol)
        self.net = net
        self.model = None
        self.infer_model = None
        self.devices = devices
        self.binded = False

    def _make_data(self, feed_inputs, feed_labels=None, batch_size=1):
        '''
        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
            k: mx.symbol.Symbol
            v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        batch_size: size of one batch
        return: mx.io.NDArrayIter
        '''
        data = {}
        for k in feed_inputs.keys():
            assert isinstance(k, mx.symbol.Symbol)
            data[k.name] = feed_inputs.get(k)
        label = None
        if feed_labels is not None:
            label = {}
            for k in feed_labels.keys():
                assert isinstance(k, mx.symbol.Symbol)
                label[k.name] = feed_labels.get(k)
        mxd = mx.io.NDArrayIter(data=data, label=label, batch_size=batch_size)
        return mxd

    def train(self, feed_inputs, feed_labels, batch_size=100,
              learning_rate=0.0321, optimizer='sgd', momentum=0.9,
              eval_metric=mx.metric.CrossEntropy(),
              params_initializer=mx.initializer.Xavier(),
              sync_params=True):
        '''
        train the network
        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
            k: mx.symbol.Symbol
            v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        batch_size: size of one batch
        learning_rate: the learning rate (step size) of this network
        optimizer: how to optimize the network, e.g. 'sgd' (stochastic gradient descent)
        eval_metric: metric used to evaluate the network,
            mx.metric.CrossEntropy(), mx.metric.Accuracy() or anything else in mx.metric.*
        sync_params: synchronize params across multiple devices
        '''
        mxd = None
        if not self.binded:
            # first call: remember the hyper-parameters, build and bind the module
            self.bind_batch_size = batch_size
            self.learning_rate = learning_rate
            self.optimizer = optimizer
            self.momentum = momentum
            self.eval_metric = eval_metric
            self.params_initializer = params_initializer
            self.data_names = [k.name for k in feed_inputs.keys()]
            self.label_names = [k.name for k in feed_labels.keys()]
            mxd = self._make_data(feed_inputs, feed_labels, batch_size)
            self.model = mx.module.Module(symbol=self.net,
                                          context=self.devices,
                                          data_names=self.data_names,
                                          label_names=self.label_names)
            self.model.bind(data_shapes=mxd.provide_data,
                            label_shapes=mxd.provide_label)
            self.model.init_params(initializer=self.params_initializer)
            optimizer_params = {'learning_rate': self.learning_rate,
                                'momentum': self.momentum}
            self.model.init_optimizer(optimizer=self.optimizer,
                                      optimizer_params=optimizer_params)
            self.binded = True
        if mxd is None:
            mxd = self._make_data(feed_inputs, feed_labels, batch_size)
        for data_batch in mxd:
            self.model.forward(data_batch, is_train=True)
            self.model.backward()
            self.model.update()
            self.model.update_metric(self.eval_metric, data_batch.label)
        for name, val in self.eval_metric.get_name_value():
            print 'train %s:%f' % (name, val)
        if sync_params:
            self.sync_params()
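    # Note: the Module is created and bound once, on the first train() call;
    # evaluate() below reuses the batch size remembered in self.bind_batch_size,
    # because the bound executor expects input of that shape.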
    def evaluate(self, feed_inputs, feed_labels, eval_metric=mx.metric.Accuracy()):
        '''
        evaluate the network
        feed_inputs: python dict {k: v, ...}
        feed_labels: python dict {k: v, ...}
            k: mx.symbol.Symbol
            v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        eval_metric: metric used to evaluate the network,
            mx.metric.CrossEntropy(), mx.metric.Accuracy() or anything else in mx.metric.*
        '''
        mxd = self._make_data(feed_inputs, feed_labels, self.bind_batch_size)
        print self.model.score(eval_data=mxd, eval_metric=eval_metric)

    def inference(self, feed_inputs):
        '''
        feed_inputs: python dict {k: v, ...}
            k: mx.symbol.Symbol
            v: numpy.ndarray, shape (batch_size, x, y, z, ...)
        return: numpy.ndarray
        '''
        mxd = self._make_data(feed_inputs, batch_size=1)
        if self.infer_model is None:
            # a separate module bound with batch_size=1, so inference inputs
            # need not match the shapes bound during training
            self.infer_model = mx.module.Module(symbol=self.net,
                                                context=self.devices,
                                                data_names=self.data_names,
                                                label_names=self.label_names)
            self.infer_model.bind(data_shapes=mxd.provide_data)
        arg_params, aux_params = self.model.get_params()
        self.infer_model.set_params(arg_params, aux_params, allow_missing=False)
        return self.infer_model.predict(mxd).asnumpy()

    def sync_params(self):
        '''
        synchronize params across devices
        '''
        arg_params, aux_params = self.model.get_params()
        self.model.set_params(arg_params, aux_params)

    @staticmethod
    def format_to_mx_data(data_batch):
        '''
        move the color-channel axis so the data matches mxnet's format,
        e.g. [batch_size,224,224,3] to [batch_size,3,224,224]
        data_batch: numpy.ndarray, shape (batch_size, x, y, z)
        '''
        assert len(np.shape(data_batch)) == 4, "other shapes need no format"
        data_batch = np.swapaxes(data_batch, 1, 3)
        data_batch = np.swapaxes(data_batch, 2, 3)
        return data_batch

def read_image(image_path):
    '''
    read mnist images
    '''
    with gzip.open(image_path, 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(num, rows, cols)
        image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32) / 255
    return image

def read_label(label_path):
    '''
    read mnist labels
    '''
    with gzip.open(label_path) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    return label

def get_mnist():
    mnist_dir = '../../dataset/%s'
    train_img = read_image(mnist_dir % 'train-images-idx3-ubyte.gz')
    train_label = read_label(mnist_dir % 'train-labels-idx1-ubyte.gz')
    val_img = read_image(mnist_dir % 't10k-images-idx3-ubyte.gz')
    val_label = read_label(mnist_dir % 't10k-labels-idx1-ubyte.gz')
    print 'load data finished'
    return train_img, train_label, val_img, val_label

def build_network():
    x = mx.sym.Variable('x')
    y = mx.sym.Variable('y')
    # first conv layer
    conv1 = mx.sym.Convolution(data=x, kernel=(4,4), num_filter=15)
    relu1 = mx.sym.Activation(data=conv1, act_type="relu")
    pool1 = mx.sym.Pooling(data=relu1, pool_type="max", kernel=(2,2), stride=(2,2))
    # second conv layer
    conv2 = mx.sym.Convolution(data=pool1, kernel=(2,2), num_filter=15)
    relu2 = mx.sym.Activation(data=conv2, act_type="relu")
    pool2 = mx.sym.Pooling(data=relu2, pool_type="max", kernel=(2,2), stride=(2,2))
    # flatten layer
    flat = mx.sym.Flatten(data=pool2)
    # fully connected
    #fc1 = mx.sym.FullyConnected(data=flat, num_hidden=100)
    #tanh1 = mx.sym.Activation(data=fc1, act_type="tanh")
    fc2 = mx.sym.FullyConnected(data=flat, num_hidden=10)
    # softmax loss
    net = mx.sym.SoftmaxOutput(data=fc2, label=y, name='softmax')
    return x, y, net
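# Demo: train the small CNN above on MNIST. num_loop_data is how many samples
# are converted and handed to train() per call; larger chunks mean fewer
# numpy-to-NDArrayIter conversions per epoch (see the note below the code).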
def train_cnn():
    batch_size = 101
    learning_rate = 0.111
    num_epochs = 10
    num_loop_data = 2000
    x, y, net = build_network()
    mnw = MXNetWrapper(net=net)
    train_img, train_label, val_img, val_label = get_mnist()
    loops = len(train_img) / num_loop_data
    datalabel = zip(train_img, train_label)
    for e in xrange(num_epochs):
        print 'epoch:', e
        np.random.shuffle(datalabel)
        dls = None
        for i in xrange(loops):
            dls = datalabel[i*num_loop_data:(i+1)*num_loop_data]
            dls = map(np.array, zip(*dls))
            mnw.train(feed_inputs={x: dls[0]},
                      feed_labels={y: dls[1]},
                      batch_size=batch_size,
                      learning_rate=learning_rate,
                      eval_metric=mx.metric.Accuracy())
        mnw.evaluate(feed_inputs={x: dls[0]},
                     feed_labels={y: dls[1]},
                     eval_metric=mx.metric.MSE())
    result = mnw.inference(feed_inputs={x: val_img[:10]})
    for index, v in enumerate(result):
        print np.argmax(v), '--', val_label[index]

if __name__ == '__main__':
    train_cnn()

As shown above, because the numpy data has to be converted on every call, it pays to hand train() a larger chunk of data at a time, so that less time is spent on conversion. For inference, a new model must be built and bound to the input's shape; otherwise, if the input shape differs from the shape bound during training, an error may be raised (the code already handles this, so it works as-is here).
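Relatedly, if you later feed the wrapper image data in NHWC layout (the usual output of image loaders), the static helper format_to_mx_data above rearranges it into MXNet's NCHW layout. A minimal sketch with a dummy array:

import numpy as np
imgs = np.zeros((8, 224, 224, 3), dtype=np.float32)  # dummy NHWC batch
imgs = MXNetWrapper.format_to_mx_data(imgs)          # now shape (8, 3, 224, 224), NCHW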