Deep CORAL: Correlation Alignment for Deep Domain Adaptation(2016)

来源:互联网 发布:2017最污的网络词 编辑:程序博客网 时间:2024/06/05 15:36

Introduction

  • 作者引进了一个叫做CORAL的方法,通过对source domain和target domain进行线性变换来将他们各自的的二阶统计量对齐(对齐均值与协方差矩阵)。
  • 但是作者认为CORAL实际上需要先对图像进行特征提取,再进行线性转换,再训练一个SVM分类器,比较麻烦。
  • 作者对CORAL方法进行了拓展,通过在source domain和target domain之间建立一个最小化source domain和target domain数据之间相关性的损失函数来将它并入deep networks。
  • 作者将CORAL方法改进做非线性变换,并且直接作用于source domain和target domain。

Deep CORAL

  • 假设:target domain上没有标记的数据。
  • 第一个目标是:平衡一个大的具有很好泛化的数据域(a large generic domain,比如ImageNet)和source domain之间的deep feature(深度网络学到的特征)。可以通过将神经网络的参数用那个large generic domain预训练过的网络的参数初始化并微调来达到目的。
  • 第二个目标是:最小化source domain和target domain的deep feature的二阶统计量的差别。
  • 网络架构:

    在fc8这一层当中,作者引入了CORAL loss这一损失函数。(AlexNet架构)

CORAL loss

  • source domain:

    • 数据: DS={xi} xRd
    • 标签: LS={yi} i{1,...,L}
    •  nS个数据,维度为 d(像素数什么的)
  • target domain:

    • 数据: DT={ui} uRd
    •  nT个数据,维度为 d(像素数什么的)
  •  DijS(DijT)表示第 j维度下source(target) domain数据中的第 i个样本。用 CS(CT) 表示特征的协方差矩阵。

  • CORAL loss:

    lCORAL=CSCT2F4d2

    • 后一个是矩阵的Frobenius范数
  • 协方差矩阵计算:
    •  CS=1nS1(DTSDS(lTDS)T(lTDS)nS)
    •  CT=1nT1(DTTDT(lTDT)T(lTDT)nT)
    • 其中 l是一个所有元素为1的列向量
  • gradient:
    •  lCORALDijS=1d2(nS1)(DTS((lTDS)TlT)T(CSCT)nS)ij
    •  lCORALDijT=1d2(nS1)(DTT((lTDT)TlT)T(CSCT)nS)ij
    • 论文作者提到说这里使用的是批次协方差(? batch covariances)

End-to-end Domain Adaptation with CORAL Loss

  • 在减少分类损失(classification loss)的同时,引入CORAL loss作为正则项来减少过拟合的可能性(使得训练出来的结果invariant to the difference between source and target domain)
  • loss function:
    l=lCLASS+i=1tλilCORAL

    •  t是CORAL loss 层的数目
    •  λ是用于平衡分类准确度和域适应的一个参数(就是让 lCLASS lCORAL都不要太大)
    • 以上两个参数相互对抗最终达到一个平衡(最终的feature可以在target domain工作的很好)。

结论

  • 利用CORAL loss可以限制source domain和target domain之间数据的距离。(这个和利用核函数的MMD很像)
  • 一种新颖的正则约束项
阅读全文
0 0
原创粉丝点击