Distilling the Knowledge in a Neural Network 阅读笔记

来源:互联网 发布:c语言写的小软件代码 编辑:程序博客网 时间:2024/05/16 13:50
《Distilling the Knowledge in a Neural Network 》
模型蒸馏论文笔记:
参考资料:http://blog.csdn.net/zhongshaoyy/article/details/53582048
Cumbersome:笨重,巨大
(这里只是考虑分类)对于一个cumbersome的网络来说,目标是区分出大量的类别的目标。常见的训练目标就是最大化average log probability of the correct answer。但是,一个副作用是模型对于所有的不正确的分类结果都会赋予一个概率,即使这些概率非常小,有的还非常大。


当从一个大模型进行蒸馏得到一个小模型的时候,我们可以以同样的方式对小模型进行训练,使其有和大模型一样的泛化能力。(继承了泛化能力)
当一个大模型是多个小模型的集成的时候,比如是a large ensemble of different models的平均,这个模型的泛化能力通常就比较好,当我们以同样的训练方式从中蒸馏得到一个小的网络,与直接在原来的训练数据集上训练的一个小网络相比,通过蒸馏方式得到的小网络在测试集上的性能比后者更好。


我们将复杂模型转化为小模型时需s要注意保留模型的泛化能力,一种方法是利用由复杂模型产生的分类概率作为“软目标”来训练小模型。在转化阶段,我们可以用同样的训练集或者是另外的“转化”训练集。the same training set or separate “transfer” set.


当复杂模型是由简单模型集成而成时,我们可以用各自的概率分布的算数或者几何平均数作为“软目标”。 individual predictive distributions。
当soft targets有很高的entropy熵时,代表其具有很大的不确定度,反应在概率上就是某实例的类别概率比较分散,这样这些Soft targets就能比hard targets提供更大的信息,并且在训练的时候,between training cases 训练的cases之间(理解为不同的batch之间,或者不同的实例之间计算得到的梯度,概率分布越散,样本之间的差异越小),这样小模型就可以依赖更少的数据进行训练,并且可以设置更大的学习率。
 像MNIST这种任务,复杂模型可以给出很完美的结果,大部分信息分布在小概率的软目标中。
Caruana用logits(softmax层的输入)而不是softmax层的输出作为“软目标”。他们目标是是的复杂模型和小模型分别得到的logits的平方差最小。
Caruana 提出的方法:(感觉和stcaking集成学习方法有点类似)
For this transfer stage, we could use the same training set or a separate “transfer” set 


本文提出的方法:
temperature :参数T
第一步,提升softmax表达式中的调节参数T,使得复杂模型产生一个合适的“软目标”(suitably soft set of targets)  第二步,采用同样的T来训练小模型,使得它产生相匹配的“软目标”。
文中称Caruana 提出的匹配cumbersome的logits输出是一个特殊的蒸馏。
The transfer set that is used to train the small model could consist entirely of unlabeled data [1]or we could use the original training set:用于训练小模型,进行知识迁移的数据可以由无标签数据或者原始的训练集组成。
转化”训练集可以由未打标签的数据组成,也可以用原训练集。我们发现使用原训练集效果很好,特别是我们在目标函数中加了一项之后。这一项的目的是是的小模型在预测实际目标的同时尽量匹配“软目标”。
注意:小模型并不能完全无误的匹配“软目标”,而对于正确结果的出现错误的判定是有帮助的。
神经网络通常使用softmax函数预测类别概率,计算如下:


T就是参数:temperature。




MATLAB上测试了一下T值对概率分布的影响:
a=[10,10,11,15,12,13,12,11,11,9];
class=[1,2,3,4,5,6,7,8,9,10];
sum_a1=sum(exp(a));
sum_a2=sum(exp(a/10));
y1=exp(a)./sum_a1;
y2=exp(a/10)./sum_a2;
figure(1);
subplot(211);
plot(class,y1)
title("T==1")
xlabel("类别")
ylabel("softmax输出")
subplot(212);
plot(class,y2)
title("T==10")
xlabel("类别")
ylabel("softmax输出")
可以看出,当T值大于1时,概率分布越平滑。(我感觉这很像机器学习中的stacking集成学习思想,平滑概率分布曲线,相当于使基分类器的差距变大)


最简单的蒸馏形式:
a transfer set 经过with a high temperature in its softmax 处理,变成soft target distribution 








目标函数1:The first objective function is the cross entropy with the soft targets and this cross entropy is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the cumbersome model. 
目标函数2:The second objective function is the cross entropy with the correct labels. This is computed using exactly the same logits in softmax of the distilled model but at a temperature of 1. 


   对两个目标函数进行加权效果最好,We found that the best results were generally obtained by using a condiderably lower weight on the second objective function. 对第二个目标函数赋予相对更低的权值效果最好。


最终目标函数 = 较大权重 x目标函数1 + 较小权重 x目标函数2
论文中给出了交叉熵对Logits的求导公式:


现做一个推导:


上图表示了交叉熵的计算。假设输出单元为n


交叉熵:
计算偏导:
这里把C分成n项相加,依次计算各项对的偏导,然后最终的偏导等于各项相加。(也可以采取复合函数的链式求导法则计算)
第一步:当j==i时。


第二步:当j!=i时。




然后将两式相加可得:








此外,论文给出了证明,当T取值很大的时候,蒸馏优化的目标等价于Caruana提取的对logits的平方误差求最优化。Caruana的方法在于Matching logits ,论文中的2.1小节给出了Matching logits is a special case of distillation 的证明。


论文也讨论了T值对模型蒸馏的影响,此外,
一方面,这些很小的概率对应的logits对于原始的复杂模型来说是unconstrained,因为交叉熵的计算对于真实的分布标签是做了一个one-hot 编码,即只有真实标签对应的类别的logits受到了约束。因此这些logits可能非常noisy。但是另一方面来说,这些非常小的logits可能携带由cumbersome model学习到的知识,还是有点用的。
Which of these effects dominates is an empirical question. 哪个占据主导,这是一个经验的问题。
我们显示,当蒸馏模型太小而不能在繁琐的模型中捕获所有的知识时,中间温度最好地工作,这强烈地表明忽略大的负对数可能是有帮助的。

然后剩下的部分就是实验了,这里就不再讲了。


文章的公式和图片没有显示,可以从这里下载word文档。

链接: https://pan.baidu.com/s/1i5xiKV7 密码: 56mj

阅读全文
0 0
原创粉丝点击