Testing MXNet Convolutional Neural Network Training on the MNIST Dataset

Source: Internet  Editor: 程序博客网  Time: 2024/05/16 17:54
import logging

import numpy as np
import mxnet as mx

logging.getLogger().setLevel(logging.DEBUG)

batch_size = 100

# load MNIST and wrap the arrays in data iterators
mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fully connected layer
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fully connected layer (one output per digit class)
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')

# create a trainable module on the CPU
lenet_model = mx.mod.Module(symbol=lenet, context=mx.cpu())

# train for 10 epochs with SGD, logging speed and accuracy every 100 batches
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='sgd',
                optimizer_params={'learning_rate':0.1},
                eval_metric='acc',
                batch_end_callback=mx.callback.Speedometer(batch_size, 100),
                num_epoch=10)
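The feature-map sizes in the network above can be checked by hand. The following is a minimal sketch in plain Python (no MXNet required) that assumes 28x28 MNIST inputs and the operator defaults of stride 1 and no padding for the convolutions, with floor rounding for pooling. The helper names `conv_out` and `lenet_shapes` are hypothetical and used only for illustration.

```python
def conv_out(size, kernel, stride=1, pad=0):
    """Spatial output size of a convolution or pooling layer (floor rounding)."""
    return (size + 2 * pad - kernel) // stride + 1


def lenet_shapes(input_size=28):
    """Return (layer name, channels, height/width) for each conv/pool layer."""
    shapes = []
    s = conv_out(input_size, kernel=5)        # conv1: 5x5 kernel, 20 filters
    shapes.append(("conv1", 20, s))
    s = conv_out(s, kernel=2, stride=2)       # pool1: 2x2 max pooling, stride 2
    shapes.append(("pool1", 20, s))
    s = conv_out(s, kernel=5)                 # conv2: 5x5 kernel, 50 filters
    shapes.append(("conv2", 50, s))
    s = conv_out(s, kernel=2, stride=2)       # pool2: 2x2 max pooling, stride 2
    shapes.append(("pool2", 50, s))
    return shapes


shapes = lenet_shapes()
# Flatten turns the last feature map into a vector of channels * height * width
flatten_dim = shapes[-1][1] * shapes[-1][2] ** 2

for name, channels, size in shapes:
    print(f"{name}: {channels} x {size} x {size}")
print("flatten:", flatten_dim)
```

Under these assumptions the spatial size goes 28, 24, 12, 8, 4, so the first fully connected layer receives 50 x 4 x 4 = 800 inputs.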

Output:

INFO:root:train-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:train-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:t10k-labels-idx1-ubyte.gz exists, skip to downloada
INFO:root:t10k-images-idx3-ubyte.gz exists, skip to downloada
INFO:root:Epoch[0] Batch [100]  Speed: 722.13 samples/sec       accuracy=0.103366
INFO:root:Epoch[0] Batch [200]  Speed: 713.60 samples/sec       accuracy=0.115500
INFO:root:Epoch[0] Batch [300]  Speed: 714.94 samples/sec       accuracy=0.110900
INFO:root:Epoch[0] Batch [400]  Speed: 709.44 samples/sec       accuracy=0.111200
INFO:root:Epoch[0] Batch [500]  Speed: 714.26 samples/sec       accuracy=0.114600
INFO:root:Epoch[0] Train-accuracy=0.113434
INFO:root:Epoch[0] Time cost=83.928
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [100]  Speed: 716.48 samples/sec       accuracy=0.161683
INFO:root:Epoch[1] Batch [200]  Speed: 675.00 samples/sec       accuracy=0.591100
INFO:root:Epoch[1] Batch [300]  Speed: 668.75 samples/sec       accuracy=0.861500
INFO:root:Epoch[1] Batch [400]  Speed: 647.97 samples/sec       accuracy=0.899400
INFO:root:Epoch[1] Batch [500]  Speed: 666.97 samples/sec       accuracy=0.920600
INFO:root:Epoch[1] Train-accuracy=0.932828
INFO:root:Epoch[1] Time cost=88.947
INFO:root:Epoch[1] Validation-accuracy=0.940800
INFO:root:Epoch[2] Batch [100]  Speed: 660.08 samples/sec       accuracy=0.944653
INFO:root:Epoch[2] Batch [200]  Speed: 650.96 samples/sec       accuracy=0.954200
INFO:root:Epoch[2] Batch [300]  Speed: 669.57 samples/sec       accuracy=0.958800
INFO:root:Epoch[2] Batch [400]  Speed: 644.97 samples/sec       accuracy=0.963200
INFO:root:Epoch[2] Batch [500]  Speed: 654.75 samples/sec       accuracy=0.967100
INFO:root:Epoch[2] Train-accuracy=0.969394
INFO:root:Epoch[2] Time cost=91.671
INFO:root:Epoch[2] Validation-accuracy=0.973100
INFO:root:Epoch[3] Batch [100]  Speed: 660.64 samples/sec       accuracy=0.970990
INFO:root:Epoch[3] Batch [200]  Speed: 669.49 samples/sec       accuracy=0.974400
INFO:root:Epoch[3] Batch [300]  Speed: 650.88 samples/sec       accuracy=0.973900
INFO:root:Epoch[3] Batch [400]  Speed: 665.29 samples/sec       accuracy=0.976800
INFO:root:Epoch[3] Batch [500]  Speed: 664.31 samples/sec       accuracy=0.976000
INFO:root:Epoch[3] Train-accuracy=0.978384
INFO:root:Epoch[3] Time cost=90.576
INFO:root:Epoch[3] Validation-accuracy=0.981600
INFO:root:Epoch[4] Batch [100]  Speed: 657.94 samples/sec       accuracy=0.978416
INFO:root:Epoch[4] Batch [200]  Speed: 651.82 samples/sec       accuracy=0.980100
INFO:root:Epoch[4] Batch [300]  Speed: 653.96 samples/sec       accuracy=0.982100
INFO:root:Epoch[4] Batch [400]  Speed: 647.17 samples/sec       accuracy=0.982400
INFO:root:Epoch[4] Batch [500]  Speed: 656.77 samples/sec       accuracy=0.981900
INFO:root:Epoch[4] Train-accuracy=0.984646
INFO:root:Epoch[4] Time cost=91.804
INFO:root:Epoch[4] Validation-accuracy=0.983400
INFO:root:Epoch[5] Batch [100]  Speed: 649.50 samples/sec       accuracy=0.983069
INFO:root:Epoch[5] Batch [200]  Speed: 649.20 samples/sec       accuracy=0.984600
INFO:root:Epoch[5] Batch [300]  Speed: 647.68 samples/sec       accuracy=0.985200
INFO:root:Epoch[5] Batch [400]  Speed: 658.71 samples/sec       accuracy=0.985900
INFO:root:Epoch[5] Batch [500]  Speed: 646.41 samples/sec       accuracy=0.984900
INFO:root:Epoch[5] Train-accuracy=0.987071
INFO:root:Epoch[5] Time cost=92.219
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Batch [100]  Speed: 645.74 samples/sec       accuracy=0.985842
INFO:root:Epoch[6] Batch [200]  Speed: 653.40 samples/sec       accuracy=0.987800
INFO:root:Epoch[6] Batch [300]  Speed: 646.12 samples/sec       accuracy=0.987800
INFO:root:Epoch[6] Batch [400]  Speed: 641.82 samples/sec       accuracy=0.988100
INFO:root:Epoch[6] Batch [500]  Speed: 643.05 samples/sec       accuracy=0.986900
INFO:root:Epoch[6] Train-accuracy=0.989192
INFO:root:Epoch[6] Time cost=96.044
INFO:root:Epoch[6] Validation-accuracy=0.986100
INFO:root:Epoch[7] Batch [100]  Speed: 653.00 samples/sec       accuracy=0.987327
INFO:root:Epoch[7] Batch [200]  Speed: 650.61 samples/sec       accuracy=0.988800
INFO:root:Epoch[7] Batch [300]  Speed: 649.02 samples/sec       accuracy=0.989100
INFO:root:Epoch[7] Batch [400]  Speed: 644.93 samples/sec       accuracy=0.990000
INFO:root:Epoch[7] Batch [500]  Speed: 554.87 samples/sec       accuracy=0.988700
INFO:root:Epoch[7] Train-accuracy=0.990202
INFO:root:Epoch[7] Time cost=94.743
INFO:root:Epoch[7] Validation-accuracy=0.987600
INFO:root:Epoch[8] Batch [100]  Speed: 649.92 samples/sec       accuracy=0.988812
INFO:root:Epoch[8] Batch [200]  Speed: 654.07 samples/sec       accuracy=0.990800
INFO:root:Epoch[8] Batch [300]  Speed: 656.73 samples/sec       accuracy=0.990700
INFO:root:Epoch[8] Batch [400]  Speed: 653.70 samples/sec       accuracy=0.990900
INFO:root:Epoch[8] Batch [500]  Speed: 631.36 samples/sec       accuracy=0.990200
INFO:root:Epoch[8] Train-accuracy=0.991616
INFO:root:Epoch[8] Time cost=92.349
INFO:root:Epoch[8] Validation-accuracy=0.988500
INFO:root:Epoch[9] Batch [100]  Speed: 647.88 samples/sec       accuracy=0.990792
INFO:root:Epoch[9] Batch [200]  Speed: 635.89 samples/sec       accuracy=0.991900
INFO:root:Epoch[9] Batch [300]  Speed: 637.18 samples/sec       accuracy=0.991700
INFO:root:Epoch[9] Batch [400]  Speed: 640.23 samples/sec       accuracy=0.992300
INFO:root:Epoch[9] Batch [500]  Speed: 640.93 samples/sec       accuracy=0.991900
INFO:root:Epoch[9] Train-accuracy=0.992828
INFO:root:Epoch[9] Time cost=93.533
INFO:root:Epoch[9] Validation-accuracy=0.988700
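
The per-epoch validation accuracy can be pulled out of a log like the one above with a regular expression. This is a rough sketch, assuming the log text has been saved to a string; the name `validation_accuracy` and the two-line `sample` are hypothetical stand-ins for illustration, with the sample values copied from the first and last epochs of the log.

```python
import re

# Matches entries such as "Epoch[3] Validation-accuracy=0.981600"
VAL_RE = re.compile(r"Epoch\[(\d+)\] Validation-accuracy=([0-9.]+)")


def validation_accuracy(log_text):
    """Map epoch number -> validation accuracy found in the log text."""
    return {int(epoch): float(acc) for epoch, acc in VAL_RE.findall(log_text)}


# Two entries copied from the log, standing in for the full text
sample = (
    "INFO:root:Epoch[0] Validation-accuracy=0.113500\n"
    "INFO:root:Epoch[9] Validation-accuracy=0.988700\n"
)

acc = validation_accuracy(sample)
for epoch in sorted(acc):
    print(f"epoch {epoch}: {acc[epoch]:.4f}")
```

In this run the validation accuracy rises from 0.1135 after epoch 0 to 0.9887 after epoch 9, with most of the gain arriving during epoch 1.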
