PyTorch —— LeNet实现中的bug以及由此的小想法
来源:互联网 发布:网络教育68所学校 编辑:程序博客网 时间:2024/06/07 01:17
经典的LeNet,在PyTorch/examples/mnist 实现中有个小问题,在这里和大家分享一下。
是一个计算generalization error的问题。
计算generalization error时,原代码有一行我一直不理解,test_loss /= len(test_loader) # loss function already averages over batch size
。试了一下输出,就发现len(test_loader)
指的是进行一次data pass,会将所有的数据分割成多少份的mini-batch。然后配合原来的一行代码,在每次mini-batch计算loss的时候,test_loss += F.nll_loss(output, target).data[0]
,就可以推出原代码在求解generalization error的逻辑是这样的:
- 将一个data pass分成几个mini-batch
- 每一个mini-batch,
F.nll_loss(output, target).data[0]
的loss value并不是整个mini-batch的loss,而是average loss,有一个默认size_average=True
的参数(后面会用到) - 进行一次data pass之后,就可以将每一个mini-batch average loss求和
- 将loss之和再除以mini-batch的数量,就得到最后的data point average loss
所以到这里就能知道,这里有一个隐藏的bug:这里假设了我每一个mini-batch size是一样的,所以才能用这样求平均的方式。但实际上,最后一个mini-batch是很难正好“满上”的。
更为精确求解loss的方法是,每一个mini-batch loss不算平均,而直接求和。最后除以所有data point的个数。大概代码如下:
for each mini-batch: ... test_loss += F.nll_loss(output, target, size_average=False).data[0] ......test_loss /= len(test_loader.dataset)...
第一次给pub repo做pull request,感觉挺不错。PyTorch还有很多坑没填,大家感兴趣的可以慢慢填。我看到很多人直接把原bug照搬到自己的repo,也不知道有没有发现这个问题。其实只要深究test_loss /= len(test_loader)
这一行就能发现。
附:因为我pull request已经被merge到code base里,感兴趣的朋友可以移步到github. 类似的问题在mnist_hogwild
也有,成功merge github.
- PyTorch —— LeNet实现中的bug以及由此的小想法
- PyTorch(总)——PyTorch遇到令人迷人的BUG与记录
- web.xml中的jsp-config元素以及由此想到的
- keras的lenet实现
- PyTorch学习—PyTorch是什么?
- LeNet的C语言实现
- 数位dp回顾以及自己一些的小想法
- 关于leaflet的插件,leaflet_smoothmarkerbounce使用中的小bug,以及解决方法
- LeNet在caffe中的实现分析
- Twisted TimerService的使用(以及由此带来的诡异事件)
- 虚函数 以及 由此用到的函数指针
- C++ string::size_type 类型以及由此引发的思考
- 理解程序调用以及由此引出的缓冲区攻击问题
- 搜索的小想法
- VC6 的小BUG —— fmodf()
- pytorch+lstm实现的pos
- 几个小想法——不专心
- 小想法——生物计算机
- Linux_170709 守护进程u
- 面试题 23: 从上到下打印二叉树
- android的ScrollView的简单使用
- 【程序员面试宝典】栈和队列相关面试题
- 【CRM项目01】登陆功能实现
- PyTorch —— LeNet实现中的bug以及由此的小想法
- Python学习一——Python下载安装
- String类中replaceAll方法不能替换美元符号$的问题解决
- Servlet学习笔记 -- day02 Request
- 接口
- 走进Vue.js
- IDEA使用--字体、编码和基本设置
- 测sort与qsort耗时
- define与const 比较