PyTorch —— LeNet实现中的bug以及由此的小想法

来源:互联网 发布:网络教育68所学校 编辑:程序博客网 时间:2024/06/07 01:17

经典的LeNet,在PyTorch/examples/mnist 实现中有个小问题,在这里和大家分享一下。

是一个计算generalization error的问题。

计算generalization error时,原代码有一行我一直不理解,test_loss /= len(test_loader) # loss function already averages over batch size。试了一下输出,就发现len(test_loader)指的是进行一次data pass,会将所有的数据分割成多少份的mini-batch。然后配合原来的一行代码,在每次mini-batch计算loss的时候,test_loss += F.nll_loss(output, target).data[0],就可以推出原代码在求解generalization error的逻辑是这样的:

  1. 将一个data pass分成几个mini-batch
  2. 每一个mini-batch,F.nll_loss(output, target).data[0]的loss value并不是整个mini-batch的loss,而是average loss,有一个默认size_average=True的参数(后面会用到)
  3. 进行一次data pass之后,就可以将每一个mini-batch average loss求和
  4. 将loss之和再除以mini-batch的数量,就得到最后的data point average loss

所以到这里就能知道,这里有一个隐藏的bug:这里假设了我每一个mini-batch size是一样的,所以才能用这样求平均的方式。但实际上,最后一个mini-batch是很难正好“满上”的。

更为精确求解loss的方法是,每一个mini-batch loss不算平均,而直接求和。最后除以所有data point的个数。大概代码如下:

for each mini-batch:    ...    test_loss += F.nll_loss(output, target, size_average=False).data[0]    ......test_loss /= len(test_loader.dataset)...

第一次给pub repo做pull request,感觉挺不错。PyTorch还有很多坑没填,大家感兴趣的可以慢慢填。我看到很多人直接把原bug照搬到自己的repo,也不知道有没有发现这个问题。其实只要深究test_loss /= len(test_loader)这一行就能发现。

附:因为我pull request已经被merge到code base里,感兴趣的朋友可以移步到github. 类似的问题在mnist_hogwild也有,成功merge github.

原创粉丝点击