蒸馏神经网络(Distill the Knowledge in a Neural Network)

来源:互联网 发布:2016网络最热门的游戏 编辑:程序博客网 时间:2024/04/29 19:53

本文是阅读Hinton 大神在2014年NIPS上一篇论文:蒸馏神经网络的笔记,特此说明。此文读起来很抽象,大篇的论述,鲜有公式和图表。但是鉴于和我的研究方向:神经网络的压缩十分相关,因此决定花气力好好理解一下。 


1、Introduction  

文章开篇用一个比喻来引入网络蒸馏:

    昆虫作为幼虫时擅于从环境中汲取能量,但是成长为成虫后确是擅于其他方面,比如迁徙和繁殖等。

同理神经网络训练阶段从大量数据中获取网络模型,训练阶段可以利用大量的计算资源且不需要实时响应。然而到达使用阶段,神经网络需要面临更加严格的要求包括计算资源限制,计算速度要求等等。

由昆虫的例子我们可以这样理解神经网络:一个复杂的网络结构模型是若干个单独模型组成的集合,或者是一些很强的约束条件下(比如dropout率很高)训练得到的一个很大的网络模型。一旦复杂网络模型训练完成,我们便可以用另一种训练方法:“蒸馏”,把我们需要配置在应用端的缩小模型从复杂模型中提取出来。

      “蒸馏”的难点在于如何缩减网络结构但是把网络中的知识保留下来。知识就是一幅将输入向量导引至输出向量的地图。做复杂网络的训练时,目标是将正确答案的概率最大化,但这引入了一个副作用:这种网络为所有错误答案分配了概率,即使这些概率非常小。 

      我们将复杂模型转化为小模型时需要注意保留模型的泛化能力,一种方法是利用由复杂模型产生的分类概率作为“软目标”来训练小模型。在转化阶段,我们可以用同样的训练集或者是另外的“转化”训练集。当复杂模型是由简单模型复合而成时,我们可以用各自的概率分布的代数或者几何平均数作为“软目标”。当“软目标的”熵值较高时,相对“硬目标”,它每次训练可以提供更多的信息和更小的梯度方差,因此小模型可以用更少的数据和更高的学习率进行训练。 

    像MNIST这种任务,复杂模型可以给出很完美的结果,大部分信息分布在小概率的软目标中。比如一张2的图片被认为是3的概率为0.000001,被认为是7的概率是0.000000001。Caruana用logits(softmax层的输入)而不是softmax层的输出作为“软目标”。他们目标是是的复杂模型和小模型分别得到的logits的平方差最小。而我们的“蒸馏法”:第一步,提升softmax表达式中的调节参数T,使得复杂模型产生一个合适的“软目标”  第二步,采用同样的T来训练小模型,使得它产生相匹配的“软目标”

   “转化”训练集可以由未打标签的数据组成,也可以用原训练集。我们发现使用原训练集效果很好,特别是我们在目标函数中加了一项之后。这一项的目的是是的小模型在预测实际目标的同时尽量匹配“软目标”。要注意的是,小模型并不能完全无误的匹配“软目标”,而正确结果的犯错方向是有帮助的。

2、Distillation 


    softmax层的公式如下: 


                                       

      T就是调节参数,一般设为1。T越大,分类的概率分布越“软” 

   “蒸馏”最简单的形式就是:以从复杂模型得到的“软目标”为目标(这时T比较大),用“转化”训练集训练小模型。训练小模型时T不变仍然较大,训练完之后T改为1。 

   当“转化”训练集中部分或者所有数据都有标签时,这种方式可以通过一起训练模型使得模型得到正确的标签来大大提升效果。一种实现方法是用正确标签来修正“软目标”,但是我们发现一种更好的方法是:对两个目标函数设置权重系数。第一个目标函数是“软目标”的交叉熵,这个交叉熵用开始的那个比较大的T来计算。第二个目标函数是正确标签的交叉熵,这个交叉熵用小模型softmax层的logits来计算且T等于1。我们发现当第二个目标函数权重较低时可以得到最好的结果 


   

3、Preliminary experiments on MNIST 


  我的理解:将迁移数据集中的3或者7、8去掉是为了证明小模型也能够从soft target中学得知识。 

4、Experiments on Speech Recognition 


5、Training ensembles of specialists on very big datasets 








1 0