MXNet | Handwritten Digit (MNIST) Recognition Competition

Source: Internet | Published under: Win32 system programming e-books | Editor: 程序博客网 | Time: 2024/05/01 23:32

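The MNIST handwritten digit dataset was created by Yann LeCun; each record represents a 28×28-pixel image. It has become a standard dataset for evaluating classifiers on simple image inputs, and neural networks are powerful models for this kind of image classification task. Kaggle runs a long-standing competition on this dataset.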
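In the Kaggle CSV files each image is stored as one row of 784 pixel columns. According to the competition's data description, column pixelx holds the pixel at row i and column j of the image, with x = i * 28 + j (both zero-based). The base-R sketch below reshapes such a row back into a 28 × 28 matrix; the input vector is made up (pixel x simply holds the value x) so the index arithmetic is easy to check.

```r
# Hypothetical stand-in for one row of train.csv without the label column:
# 784 pixel values, where pixel number x holds the value x
pixels <- 0:783

# Kaggle stores pixels row by row (x = i * 28 + j, zero-based),
# so fill the matrix by rows to recover the 28 x 28 image
img <- matrix(pixels, nrow = 28, ncol = 28, byrow = TRUE)

# Zero-based row i = 2, column j = 5 should be pixel number 2 * 28 + 5 = 61;
# R indices are one-based, hence the + 1
i <- 2
j <- 5
img[i + 1, j + 1]
# [1] 61
```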

Competition page: https://www.kaggle.com/c/digit-recognizer

If downloading the dataset is difficult, you can get it from my Baidu Netdisk: link http://pan.baidu.com/s/1sl50KjV, password: ca56

Read the dataset. Here we use read_csv from the readr package, which reads files quickly and efficiently.

```r
# change this to the folder that contains train.csv and test.csv
setwd("F:\\迅雷下载\\mnist")
require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')
```

Convert the training and test sets to numeric matrices, and split the training set into pixels and labels:

```r
train <- data.matrix(train)
test <- data.matrix(test)
# column 1 of train.csv is the label, the remaining columns are pixels
train.x <- train[,-1]
train.y <- train[,1]
```

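Scale the pixel values from [0, 255] to [0, 1], and transpose the matrices so that each column holds one image (pixels × examples), the column-major layout used by the MXNet R package: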

```r
train.x <- t(train.x/255)
test <- t(test/255)
```
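To see what these two lines do to the shape and range of the data, here is a toy base-R sketch with a hypothetical 3-image, 4-pixel matrix in place of the real data. On the real data the same operations turn train.x into a 784 × 42000 matrix and test into a 784 × 28000 matrix.

```r
# Toy stand-in for the data: 3 "images" with 4 pixels each, values in 0..255,
# one image per row as in train.csv (hypothetical numbers)
x <- matrix(c(  0, 255, 128,  64,
              255,   0,  32,  16,
               10,  20,  30,  40), nrow = 3, byrow = TRUE)

# Divide by 255 to map the pixel range [0, 255] onto [0, 1] ...
scaled <- x / 255
# ... then transpose so that each COLUMN is one image (pixels x examples)
scaled_t <- t(scaled)

dim(x)           # [1] 3 4  (examples x pixels)
dim(scaled_t)    # [1] 4 3  (pixels x examples)
range(scaled_t)  # [1] 0 1
```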

Label distribution

```
> table(train.y)
train.y
   0    1    2    3    4    5    6    7    8    9 
4132 4684 4177 4351 4072 3795 4137 4401 4063 4188 
```

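The classes are fairly balanced: the counts differ only modestly, from 3795 images of the digit 5 to 4684 images of the digit 1.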
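The balance claim can be quantified directly from the counts printed above. This base-R sketch simply copies those counts into a named vector; every digit makes up roughly 9% to 11% of the 42000 training images, and the largest class is only about 1.23 times the smallest.

```r
# Label counts copied from the table(train.y) output above
counts <- c(`0` = 4132, `1` = 4684, `2` = 4177, `3` = 4351, `4` = 4072,
            `5` = 3795, `6` = 4137, `7` = 4401, `8` = 4063, `9` = 4188)

sum(counts)                   # total number of training images: 42000
round(prop.table(counts), 3)  # share of each digit
names(which.max(counts))      # most frequent digit: "1"
names(which.min(counts))      # least frequent digit: "5"
max(counts) / min(counts)     # imbalance ratio, about 1.23
```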

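Build the network: a multilayer perceptron with two ReLU hidden layers (128 and 64 units) and a 10-unit softmax output layer.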

```r
# input
data <- mx.symbol.Variable("data")
# first layer: fully connected, 128 hidden units
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
# ReLU activation
act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
# second layer: fully connected, 64 hidden units
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
# ReLU activation
act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
# third (output) layer: fully connected, 10 units, one per digit
fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
# softmax output layer, named "sm"
softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")
```
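To make the layer shapes concrete, here is a plain base-R sketch (not MXNet code) of the forward pass of the same 784-128-64-10 network on a small batch stored one image per column. Each fully connected layer computes W %*% x + b, ReLU is max(0, z), and softmax turns the 10 outputs into probabilities. The weights are random stand-ins for trained ones, drawn from U(-0.07, 0.07) on the assumption that this is what mx.init.uniform(0.07) in the training step does; biases are simply set to zero here.

```r
set.seed(1)

relu <- function(z) pmax(z, 0)

# Column-wise softmax; subtracting each column's max avoids overflow in exp()
softmax_cols <- function(z) {
  z <- sweep(z, 2, apply(z, 2, max))
  e <- exp(z)
  sweep(e, 2, colSums(e), "/")
}

# Hypothetical random weights in [-0.07, 0.07]
init <- function(n_out, n_in) matrix(runif(n_out * n_in, -0.07, 0.07), n_out, n_in)
W1 <- init(128, 784); b1 <- rep(0, 128)
W2 <- init(64, 128);  b2 <- rep(0, 64)
W3 <- init(10, 64);   b3 <- rep(0, 10)

# A fake batch of 5 images, one per column (784 x 5), values in [0, 1]
x <- matrix(runif(784 * 5), nrow = 784)

h1 <- relu(W1 %*% x + b1)           # fc1 + relu1: 128 x 5
h2 <- relu(W2 %*% h1 + b2)          # fc2 + relu2:  64 x 5
p  <- softmax_cols(W3 %*% h2 + b3)  # fc3 + sm:     10 x 5

dim(p)      # [1] 10  5
colSums(p)  # each column sums to 1
```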

Train the model on the CPU:

```r
# use the CPU
devices <- mx.cpu()
# fix the random seed
mx.set.seed(0)
# train the model
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
                                     ctx=devices, num.round=10, array.batch.size=100,
                                     learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy,
                                     initializer=mx.init.uniform(0.07),
                                     epoch.end.callback=mx.callback.log.train.metric(100))
```

Training log:

```
Start training with 1 devices
[1] Train-accuracy=0.859832935560859
[2] Train-accuracy=0.957666666666668
[3] Train-accuracy=0.971023809523813
[4] Train-accuracy=0.977714285714289
[5] Train-accuracy=0.981571428571432
[6] Train-accuracy=0.986309523809527
[7] Train-accuracy=0.988952380952383
[8] Train-accuracy=0.990880952380956
[9] Train-accuracy=0.992142857142861
[10] Train-accuracy=0.991095238095241
```
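Two of the key settings above are learning.rate=0.07 and momentum=0.9. The base-R sketch below shows the classical SGD-with-momentum update that such settings typically control, applied to a toy one-dimensional loss f(w) = (w - 3)^2. It illustrates the technique only and is not MXNet's internal optimizer code, which may additionally apply weight decay or gradient rescaling. The last line is plain arithmetic: with array.batch.size=100 and 42000 training images, one epoch is 420 mini-batch updates.

```r
# Classical SGD with momentum on a toy loss f(w) = (w - 3)^2,
# whose gradient is 2 * (w - 3)
lr       <- 0.07  # same value as learning.rate above
momentum <- 0.9   # same value as momentum above

w <- 0  # initial parameter
v <- 0  # velocity: decaying accumulation of past updates
for (step in 1:300) {
  grad <- 2 * (w - 3)
  v <- momentum * v - lr * grad  # keep 90% of the previous velocity, add the new step
  w <- w + v                     # move the parameter along the velocity
}
w  # very close to the minimiser w = 3

# Mini-batch updates per epoch in the MNIST run
updates_per_epoch <- 42000 / 100
updates_per_epoch  # [1] 420
```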

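After 10 epochs the training accuracy is about 99.11% (the highest value, 99.21%, was reached in epoch 9). This is accuracy on the training data, so by itself it does not tell how well the model will score on the test set.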
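The accuracy metric is simply the fraction of examples whose predicted label equals the true label. A minimal base-R sketch with hypothetical labels, plus the arithmetic that converts the last logged value back into a count (roughly, since the training metric is accumulated over the mini-batches of the epoch while the weights are still changing):

```r
# Accuracy = share of examples whose predicted label equals the true label
accuracy <- function(pred, truth) mean(pred == truth)

# Hypothetical labels for 8 images
truth <- c(5, 0, 4, 1, 9, 2, 1, 3)
pred  <- c(5, 0, 4, 1, 4, 2, 1, 3)  # one mistake: a 9 read as a 4
accuracy(pred, truth)               # [1] 0.875

# The last logged value corresponds to about 41626 of the
# 42000 training images being classified correctly
round(0.991095238095241 * 42000)    # [1] 41626
```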

Predict on the test set

```
> preds <- predict(model, test)
> dim(preds)
[1]    10 28000
> pred.label <- max.col(t(preds)) - 1
```
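A toy base-R sketch of the argmax step, using a hypothetical 10 × 3 matrix in place of preds (one column of class probabilities per image). It also shows an equivalent column-wise formulation with apply() and which.max().

```r
# Hypothetical prediction matrix shaped like preds: 10 classes x 3 images,
# each column a probability distribution over the digits 0-9
preds_toy <- matrix(0.02, nrow = 10, ncol = 3)
preds_toy[4, 1]  <- 0.82  # image 1: highest probability in row 4  -> digit 3
preds_toy[1, 2]  <- 0.82  # image 2: highest probability in row 1  -> digit 0
preds_toy[10, 3] <- 0.82  # image 3: highest probability in row 10 -> digit 9

# max.col() finds the maximum of each ROW, so transpose to one row per image;
# subtract 1 because R indices start at 1 while digits start at 0
labels_a <- max.col(t(preds_toy)) - 1

# Equivalent column-wise formulation
labels_b <- apply(preds_toy, 2, which.max) - 1

labels_a  # [1] 3 0 9
labels_b  # [1] 3 0 9
```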

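Here preds holds one column of class probabilities per test image (10 × 28000). max.col(t(preds)) picks the most probable class for each image, and subtracting 1 turns R's 1-based index into a digit from 0 to 9. The distribution of the predicted classes: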

```
> table(pred.label)
pred.label
   0    1    2    3    4    5    6    7    8    9 
2816 3216 2753 2791 2709 2544 2762 2836 2780 2793 
```
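As a quick sanity check, the share of each predicted digit on the 28000 test images can be compared with the class shares in the training set. The counts below are copied from the two tables above. Similar shares do not prove the predictions are correct, but a large mismatch would hint at a bug such as an off-by-one label shift.

```r
# Counts copied from table(train.y) and table(pred.label) above
train_counts <- c(4132, 4684, 4177, 4351, 4072, 3795, 4137, 4401, 4063, 4188)
pred_counts  <- c(2816, 3216, 2753, 2791, 2709, 2544, 2762, 2836, 2780, 2793)
names(train_counts) <- 0:9
names(pred_counts)  <- 0:9

sum(pred_counts)  # [1] 28000 test images

# Side-by-side class shares, rounded to 3 decimals
round(rbind(train     = prop.table(train_counts),
            predicted = prop.table(pred_counts)), 3)

# Largest difference in class share: under half a percentage point
max(abs(prop.table(train_counts) - prop.table(pred_counts)))
```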

Build the submission with the ImageId and Label columns and write it to a CSV file:

```r
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
```
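Because test was transposed, ncol(test) is the number of test images, so ImageId runs from 1 to 28000. The sketch below writes a tiny submission with hypothetical labels to a temporary file and reads it back, to show the exact text that write.csv produces with row.names=FALSE and quote=FALSE: an "ImageId,Label" header followed by one unquoted row per image.

```r
# Hypothetical predicted labels standing in for pred.label
pred_toy <- c(2, 0, 9, 9, 3)

submission_toy <- data.frame(ImageId = seq_along(pred_toy), Label = pred_toy)

# Write to a temporary file instead of the working directory
path <- tempfile(fileext = ".csv")
write.csv(submission_toy, file = path, row.names = FALSE, quote = FALSE)

readLines(path)
# [1] "ImageId,Label" "1,2" "2,0" "3,9" "4,9" "5,3"
```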

The file submission.csv is written to your working directory; then submit it on Kaggle.

Log in to Kaggle and open https://www.kaggle.com/c/digit-recognizer/submissions/attach

Upload the file and submit it.

The result:
[Screenshot: submission result]

The complete code is as follows:

```r
setwd("F:\\迅雷下载\\mnist")
require(mxnet)
library(readr)
train <- read_csv('train.csv')
test <- read_csv('test.csv')
train <- data.matrix(train)
test <- data.matrix(test)
train.x <- train[,-1]
train.y <- train[,1]
# scale the data to [0,1]
train.x <- t(train.x/255)
test <- t(test/255)
table(train.y)
# build the network
data <- mx.symbol.Variable("data")
fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)
act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")
fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)
act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")
fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)
softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")
######## training
## cpu
devices <- mx.cpu()
mx.set.seed(0)
model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y,
                                     ctx=devices, num.round=10, array.batch.size=100,
                                     learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy,
                                     initializer=mx.init.uniform(0.07),
                                     epoch.end.callback=mx.callback.log.train.metric(100))
# prediction
preds <- predict(model, test)
dim(preds)
pred.label <- max.col(t(preds)) - 1
table(pred.label)
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)
write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
```