mxnet多层感知器、卷积神经网络测试【转】

来源:互联网 发布:淘宝网的女款加绒衬衣 编辑:程序博客网 时间:2024/05/22 03:47

来自:http://blog.csdn.net/xinfeng2005/article/details/53380700?locationNum=8&fps=1

# coding=utf-8
"""Evaluate a saved MXNet MNIST classifier on the MNIST test and training sets.

The script loads a FeedForward checkpoint, shows the first 25 test images
together with the predicted confidence, and prints the accuracy reported by
``model.score`` for the test set and for the training set.
"""
import mxnet as mx
import matplotlib.pyplot as plt
import numpy as np
import struct
import pickle


def ImageToFloat(img):
    """Convert a stack of 28x28 uint8 images into network input.

    The array is reshaped to ``(N, 1, 28, 28)``, cast to ``float32`` and
    divided by 255, where ``N`` is ``img.shape[0]``.
    """
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32) / 255


def _load_idx_labels(path):
    """Read an IDX label file and return its labels as an int8 array."""
    # The file is binary: it must be opened with 'rb', otherwise the bytes
    # are mangled on Windows and struct.unpack fails on Python 3.
    with open(path, 'rb') as flbl:
        # 8-byte header: two big-endian uint32 values (magic, item count).
        struct.unpack(">II", flbl.read(8))
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        return np.frombuffer(flbl.read(), dtype=np.int8)


def _load_idx_images(path, count):
    """Read an IDX image file and return a ``(count, rows, cols)`` uint8 array."""
    with open(path, 'rb') as fimg:
        # 16-byte header: magic, item count, rows, cols (big-endian uint32).
        _magic, _num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8)
    return pixels.reshape(count, rows, cols)


# Use the LeNet checkpoint (prefix 'lenet', epoch 10).
model = mx.model.FeedForward.load('lenet', 10)
# Alternative: use the MLP checkpoint instead.
# model = mx.model.FeedForward.load('mpl_mnist', 10)

# Per-image recognition results on the test set.
label = _load_idx_labels('./mnist/t10k-labels-idx1-ubyte')
image = _load_idx_images('./mnist/t10k-images-idx3-ubyte', len(label))

plt.subplot(5, 5, 1)
for x in range(25):
    plt.subplot(5, 5, x + 1)
    # Invert the pixel values before displaying with the reversed grey map.
    plt.imshow(255 - image[x], cmap='Greys_r')
    # Slice (not index) so the single image keeps its leading batch axis.
    prob = model.predict(ImageToFloat(image[x:x + 1]))[0]
    print('Classified as %d with probability %f' % (prob.argmax(), max(prob)))
    plt.title('%s %s' % (str(label[x]), str(max(prob))))
    plt.axis('off')
plt.show()

val_iter = mx.io.NDArrayIter(ImageToFloat(image), label, batch_size=100)
print('Text accuracy: %f%%' % (model.score(val_iter) * 100,))

# Recognition accuracy on the training set.
label_train = _load_idx_labels('./mnist/train-labels-idx1-ubyte')
image_train = _load_idx_images('./mnist/train-images-idx3-ubyte',
                               len(label_train))

plt.axis('off')
plt.imshow(255 - image_train[0], cmap='Greys_r')
prob = model.predict(ImageToFloat(image_train[0:1]))[0]
print('Classified as %d with probability %f' % (prob.argmax(), max(prob)))
plt.show()

train_iter = mx.io.NDArrayIter(ImageToFloat(image_train), label_train,
                               batch_size=100)
print('Train accuracy: %f%%' % (model.score(train_iter) * 100,))

训练迭代10次(epoch)后的识别精度(程序输出中的"Text accuracy"即测试集精度,"Train accuracy"为训练集精度):
MLP:Text accuracy: 97.390000%,Train accuracy: 98.821667%
LeNet:Text accuracy: 99.170000%,Train accuracy: 99.995000%

阅读全文
0 0
原创粉丝点击