MXNet | 手写字MNIST识别比赛
来源:互联网 发布:win32系统编程电子书 编辑:程序博客网 时间:2024/05/01 23:32
MNIST手写字图片数据集由Yann LeCun创建,每条数据表示28*28像素的图片。它已经是用于衡量分类器在简单图片作为输入的标准数据集。神经网络是对于图片分类任务来说是强大的模型。这是一个在kaggle长期举办的比赛数据集。
比赛的官网:https://www.kaggle.com/c/digit-recognizer
若是下载数据集困难,可以去我的百度网盘下载:链接:http://pan.baidu.com/s/1sl50KjV 密码:ca56
读取数据集,这里用readr中的函数read_csv,读取速度快高效
setwd("F:\\迅雷下载\\mnist")require(mxnet)library(readr)train <- read_csv('train.csv')test <- read_csv('test.csv')
数据集:训练集和测试集
> train <- data.matrix(train)> test <- data.matrix(test)> train.x <- train[,-1]> train.y <- train[,1]> train <- data.matrix(train)> test <- data.matrix(test)> train.x <- train[,-1]> train.y <- train[,1]
数据放缩到[0,1]
> train.x <- t(train.x/255)> test <- t(test/255)
标签
> table(train.y)train.y 0 1 2 3 4 5 6 7 8 9 4132 4684 4177 4351 4072 3795 4137 4401 4063 4188
数据集还是比较平衡,不同之间的差异不大
构建网络
#定义> data <- mx.symbol.Variable("data")#第一层,全连接,隐藏节点128个> fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)#激活函数为relu> act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")#第二层,隐藏节点为64个> fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)#激活函数为relu> act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")#第三层,隐藏节点为10个> fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)##激活函数为sm,即softmax> softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")
训练,采用cpu的方式
#cpu>devices <- mx.cpu()#随机种子>mx.set.seed(0)#模型>model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y, ctx=devices, num.round=10, array.batch.size=100, learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy, initializer=mx.init.uniform(0.07), epoch.end.callback=mx.callback.log.train.metric(100))Start training with 1 devices[1] Train-accuracy=0.859832935560859[2] Train-accuracy=0.957666666666668[3] Train-accuracy=0.971023809523813[4] Train-accuracy=0.977714285714289[5] Train-accuracy=0.981571428571432[6] Train-accuracy=0.986309523809527[7] Train-accuracy=0.988952380952383[8] Train-accuracy=0.990880952380956[9] Train-accuracy=0.992142857142861[10] Train-accuracy=0.991095238095241
训练的精度为99.10%
预测
> preds <- predict(model, test)> dim(preds)[1] 10 28000> pred.label <- max.col(t(preds)) - 1
预测后的类别
> table(pred.label)pred.label 0 1 2 3 4 5 6 7 8 9 2816 3216 2753 2791 2709 2544 2762 2836 2780 2793
得到提交的数据集ID和label
submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
submission.csv文件在你的工作目录下,然后去kaggle提交下。
登陆kaggle,打开页面https://www.kaggle.com/c/digit-recognizer/submissions/attach
结果显示
下面给出完整的代码:
setwd("F:\\迅雷下载\\mnist")require(mxnet)library(readr)train <- read_csv('train.csv')test <- read_csv('test.csv')train <- data.matrix(train)test <- data.matrix(test)train.x <- train[,-1]train.y <- train[,1]# 数据放缩到[0,1]train.x <- t(train.x/255)test <- t(test/255)table(train.y)#构建网络data <- mx.symbol.Variable("data")fc1 <- mx.symbol.FullyConnected(data, name="fc1", num_hidden=128)act1 <- mx.symbol.Activation(fc1, name="relu1", act_type="relu")fc2 <- mx.symbol.FullyConnected(act1, name="fc2", num_hidden=64)act2 <- mx.symbol.Activation(fc2, name="relu2", act_type="relu")fc3 <- mx.symbol.FullyConnected(act2, name="fc3", num_hidden=10)softmax <- mx.symbol.SoftmaxOutput(fc3, name="sm")########训练##cpudevices <- mx.cpu()mx.set.seed(0)model <- mx.model.FeedForward.create(softmax, X=train.x, y=train.y, ctx=devices, num.round=10, array.batch.size=100, learning.rate=0.07, momentum=0.9, eval.metric=mx.metric.accuracy, initializer=mx.init.uniform(0.07), epoch.end.callback=mx.callback.log.train.metric(100))#预测preds <- predict(model, test)dim(preds)pred.label <- max.col(t(preds)) - 1table(pred.label)submission <- data.frame(ImageId=1:ncol(test), Label=pred.label)write.csv(submission, file='submission.csv', row.names=FALSE, quote=FALSE)
2 0
- MXNet | 手写字MNIST识别比赛
- Tensorflow | MNIST手写字识别
- MNIST手写字识别的TensorFlow实现
- 基于tensorflow的MNIST手写字识别
- tensorflow mnist数据集手写字识别
- MXNet | LeNet-5(卷积神经网络)用于手写字识别
- TensorFlow学习笔记(一)MNIST手写字识别
- caffe学习例子(一) mnist手写字识别
- 手写字识别C++
- 基于Opencv库中SVM模块的MNIST手写字识别数据库识别
- 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型
- 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型
- Python神经网络代码识别手写字的实现流程(一):加载mnist数据
- K-近邻:手写字识别
- 我的第一个svm程序:手写字识别
- 神经网络代码识别手写字(python3.4.3版本)
- Tensorflow框架下识别手写字神经网络代码
- svm手写字检测
- [干货]2017已来,最全面试总结——这些Android面试题你一定需要
- 51nod-1632B君的连通
- 【转】电视的未来
- 网站做登录认证怎么做?
- zephyer系统在STM32F411-Nucleo平台上运行和基于openOCD的裸机调试环境搭建
- MXNet | 手写字MNIST识别比赛
- 在chrome浏览器下 input的autocomplete="off"失效,导致的自动填充
- 【9923】对抗赛
- 【数学归纳法】【二分答案】17.1.24 T3 zhenhuan题解
- Java 项目、Node前端项目 gitignore文件
- spring装配Bean(基于xml)
- 2016工作生活记录
- 我的2016
- angularjs项目记录