89.89% on CIFAR-10 in PyTorch
The full code is available here; just clone it to your machine and it's ready to play with. As a former Torch7 user, I attempt to reproduce the results from the Torch7 post.
My friends Wu Jun and Zhang Yujing claimed Batch Normalization [1] is useless. I want to prove them wrong (and rub it in their faces), and CIFAR-10 is a nice playground to start.
CIFAR-10 contains 60,000 labeled 32x32 images spanning 10 classes; the training set has 50,000 images and the test set 10,000.
The dataset is quite small by today's standards, but it is still a good playground for machine learning algorithms. I only use horizontal flips to augment the data. You will need an NVIDIA GPU with at least 3 GB of memory.
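For reference, a minimal sketch of what the data pipeline might look like with torchvision. The normalization statistics and batch size here are my assumptions, not necessarily what the repo uses:

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Commonly used CIFAR-10 channel statistics; the repo may use different values.
normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                 std=(0.2470, 0.2435, 0.2616))

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # the only augmentation used here
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
```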
The post and the code consist of 2 parts/files:
- model definition
- training
The model: Vgg.py
It's a VGG16-like [2] network (not identical: I removed the first FC layer) with many 3x3 filters and padding 1, so the feature maps keep their size after each convolution; the size is only reduced by max-pooling. Weights of the convolutional layers are initialized MSR-style. Batch Normalization and Dropout are used together. A rough sketch of the architecture is shown below.
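The convolutional configuration in this sketch follows VGG16; the exact Dropout placement and the classifier head after removing the first FC layer are my assumptions, so consult Vgg.py in the repo for the real architecture:

```python
import torch.nn as nn

# VGG16 convolutional configuration: numbers are output channels, 'M' is max-pooling.
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
       512, 512, 512, 'M', 512, 512, 512, 'M']

class VGG(nn.Module):
    def __init__(self, num_classes=10, dropout=0.5):
        super().__init__()
        layers, in_ch = [], 3
        for v in cfg:
            if v == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # 3x3 convolution with padding 1 keeps the feature-map size unchanged.
                layers += [nn.Conv2d(in_ch, v, kernel_size=3, padding=1),
                           nn.BatchNorm2d(v),
                           nn.ReLU(inplace=True)]
                in_ch = v
        self.features = nn.Sequential(*layers)
        # With the first FC layer removed, the head is assumed to be a single
        # dropout + linear layer on the 512-d feature vector.
        self.classifier = nn.Sequential(nn.Dropout(dropout),
                                        nn.Linear(512, num_classes))
        # MSR-style (He) initialization for convolutional weights.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.features(x)       # 32x32 -> 1x1 after five poolings
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        return self.classifier(x)
```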
Training: train.py
That's it. You can start training:
python train.py
The parameters with which the model achieves the best performance are the defaults in the code. I used SGD (a little old-fashioned) with cross-entropy loss, a learning rate of 0.01, momentum 0.9, and weight decay 0.0005, dropping the learning rate every 25 epochs. After a few hours you will have the model. The accuracy record and the models at each checkpoint are saved in the 'save' folder.
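A minimal training loop matching those hyperparameters might look like this. The learning-rate decay factor (gamma) and the epoch count are my assumptions (the post only states the period of the schedule); the repo's train.py is authoritative:

```python
import os
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VGG().to(device)  # VGG from the model sketch above

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01,
                      momentum=0.9, weight_decay=0.0005)
# Drop the learning rate every 25 epochs; gamma=0.1 and the 100-epoch
# budget below are assumptions, not stated in the post.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.1)

os.makedirs('save', exist_ok=True)
for epoch in range(100):
    model.train()
    for inputs, targets in train_loader:  # train_loader from the data sketch
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
    scheduler.step()
    # Save a checkpoint every epoch, as the post describes.
    torch.save(model.state_dict(), 'save/vgg_epoch{}.pth'.format(epoch))
```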
How accuracy improves:
The best accuracy is 89.89%; removing BN or Dropout alone results in 88.67% and 88.73% accuracy, respectively. Batch Normalization can indeed accelerate deep network training. Removing both BN and Dropout results in 86.65% accuracy, and we can observe overfitting.
References
- [1] Sergey Ioffe, Christian Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. [arXiv]
- [2] K. Simonyan, A. Zisserman. Very Deep Convolutional Networks for Large-Scale Image Recognition. [arXiv]