CRF as RNN的原理及Caffe实现

来源:互联网 发布:国内域名国外空间 编辑:程序博客网 时间:2024/05/01 01:35

CRF(Conditional Random Field)是图像分割中很常用的后处理算法。在《全卷积网络(FCN)与图像分割 》这篇博文中提到,FCN可以得到较好的分割结果,Chen, Liang-Chieh, et al. 2014在其基础上使用fully connected CRF得到了更好的效果,但是FCN的步骤和CRF的步骤是分开的。Zheng et al 2015将fully connected CRF表示成RNN的结构,与FCN连接在一起,可以同时训练FCN和CRF,使分割的准确率有了更多提高。

CRF as RNN的原理

CRF的能量函数包括一个数据项和平滑项。数据项是基于每个像素属于各个类别的概率,平滑项是基于像素之间的灰度值差异和空间距离。传统的CRF的平滑项只考虑相邻像素间的关联,而Fully connected CRF考虑了图像中任意两个像素之间的关联性。

E(x)=iψu(xi)+ijψp(xi,xj) 公式(1)

其中ψu(xi)是数据项, ψp(xi,xj)是能量项, 即像素对之间的能量,其定义为若干个高斯函数的和:

ψp(xi,xj)=μ(xi,xj)Km=1ω(m)k(m)(fi,fj) 公式(2)

高斯函数的定义为:

k(fi,fj)=ω(1)exp(|pipj|22θ2α|IiIj|22θ2β)+ω(2)exp(|pipj|22θ2γ) 公式(3)

k(fi,fj)由两部分组成,第一部分是Bilateral Filter, 其想法是空间上离得近并且灰度值接近的像素很有可能是属于同一个物体。第二部分是Spatial Kernel,起到空间平滑作用,可以去除掉分割结果中一些孤立的小区域。

条件随机场的概率函数为P(X=x|I)=1Zexp(E(x|I))。对公式(1)中的E(x)最小化对应着对后验概率P(X=x | I)的最大化,从而得到最优分割结果。

由于直接计算概率函数P(X)比较麻烦,可以通过一个比较方便计算的概率函数Q(X)来近似得到P(X)。 Q(X)=iQi(Xi)。为了让Q(X)最大限度接近P(X),可通过对它们的KL-divergence最小化得到。这个最小化过程的迭代步骤如下:
这里写图片描述 公式(4)

将该步骤中的各个操作拆分,可以得到如下的算法:
这里写图片描述

该算法的输入为: 初始势能Ui(l),对应公式(1)的第一项; m个滤波器k(m)(fi,fj); 各滤波器的权重ω(m), 不同类别之间的兼容性矩阵μ(l,l)。输出为Q_i即更新后的概率分布。

该算法的每一次迭代分为5个步骤:
1, 信息传递。即使用m个滤波器分别对每一个类别l的概率图Qi(l)进行滤波的过程。
2, 滤波结果加权相加。对每一个类别l的m个滤波结果根据权重ω(m)相加。
3,类别兼容性转换。对每一个类别l的概率图根据不同类别之间的兼容性矩阵μ(l,l)进行更新。
4, 加上数据项(一元项 Unary Potential)。
5, 归一化。对各像素所属不同类别l的概率归一化,这实际上是一个Softmax的过程。

Caffe实现

GitHub上可以找到CRF as RNN的源代码。主要有两个类:MultiStageMeanfieldLayer和MeanfieldIteration。其中MultiStageMeanfieldLayer将CRF的所有迭代步骤组装在一起形成一个Caffe层,迭代的每一步作为一个子层在MeanfieldIteration中实现。目前是基于cpu代码的实现,还没有在cuda上实现。

该实现中考虑了Bilateral filter 和Spatial filter两种滤波器,分别对应公式(3)中的两项。由于Fully connected CRF中的滤波操作要考虑整个图像的信息, 滤波器核的大小为整个图像的尺寸,因此滤波过程比较耗时,为了提高效率,Zheng et al使用了Andrew Adams et al所提出的基于Perutohedral Lttice的滤波实现: 实现Bilateral filter 的bilateral_lattices_和实现Spatial filter的spatial_lattice。

MeanfieldIteration的前向运算为:

/** * Forward pass during the inference. */template <typename Dtype>void MeanfieldIteration<Dtype>::Forward_cpu() {  //------------------------------- Softmax normalization--------------------  softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);  //-----------------------------------Message passing-----------------------  for (int n = 0; n < num_ /*number of images in the batch */; ++n) {  //Spatial filtering    Dtype* spatial_out_data = spatial_out_blob_.mutable_cpu_data() + spatial_out_blob_.offset(n);    const Dtype* prob_input_data = prob_.cpu_data() + prob_.offset(n);    spatial_lattice_->compute(spatial_out_data, prob_input_data, channels_, false);    // Spatial filtering, Pixel-wise normalization.    for (int channel_id = 0; channel_id < channels_; ++channel_id) {      caffe_mul(num_pixels_, spatial_norm_->cpu_data(),          spatial_out_data + channel_id * num_pixels_,          spatial_out_data + channel_id * num_pixels_);    }    // Bilateral filtering    Dtype* bilateral_out_data = bilateral_out_blob_.mutable_cpu_data() + bilateral_out_blob_.offset(n);    (*bilateral_lattices_)[n]->compute(bilateral_out_data, prob_input_data, channels_, false);    // Bilateral filtering, Pixel-wise normalization.    for (int channel_id = 0; channel_id < channels_; ++channel_id) {      caffe_mul(num_pixels_, bilateral_norms_->cpu_data() + bilateral_norms_->offset(n),          bilateral_out_data + channel_id * num_pixels_,          bilateral_out_data + channel_id * num_pixels_);    }  }  caffe_set(count_, Dtype(0.), message_passing_.mutable_cpu_data());  // Add the weighted output of spatial filtering  for (int n = 0; n < num_; ++n) {    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1.,        this->blobs_[0]->cpu_data(), spatial_out_blob_.cpu_data() + spatial_out_blob_.offset(n), (Dtype) 0.,        message_passing_.mutable_cpu_data() + message_passing_.offset(n));  }  // Add the weighted output of bilateral filtering  for (int n = 0; n < num_; ++n) {    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1.,        this->blobs_[1]->cpu_data(), bilateral_out_blob_.cpu_data() + bilateral_out_blob_.offset(n), (Dtype) 1.,        message_passing_.mutable_cpu_data() + message_passing_.offset(n));  }  //--------------------------- Compatibility multiplication ----------------  //Result from message passing needs to be multiplied with compatibility values.  for (int n = 0; n < num_; ++n) {    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_,        channels_, (Dtype) 1., this->blobs_[2]->cpu_data(),        message_passing_.cpu_data() + message_passing_.offset(n), (Dtype) 0.,        pairwise_.mutable_cpu_data() + pairwise_.offset(n));  }  //------------------------- Adding unaries, normalization is left to the next iteration --------------  // Add unary  sum_layer_->Forward(sum_bottom_vec_, sum_top_vec_);}

其中caffe_cpu_gemm是实现C=αAB+βC的运算函数。在Forward_cpu()中依次实现了上面提到的5个步骤。
void MeanfieldIteration::Backward_cpu()中倒序实现了上述个步骤的反向传播。

MultiStageMeanfieldLayer中为每一个迭代步骤分别创建了一个MeanfieldIteration层,其前向传播的代码为:

template <typename Dtype>void MultiStageMeanfieldLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,      const vector<Blob<Dtype>*>& top) {  split_layer_bottom_vec_[0] = bottom[0];  split_layer_->Forward(split_layer_bottom_vec_, split_layer_top_vec_);  // Initialize the bilateral lattices.  bilateral_lattices_.resize(num_);  for (int n = 0; n < num_; ++n) {    compute_bilateral_kernel(bottom[2], n, bilateral_kernel_buffer_.get());    bilateral_lattices_[n].reset(new ModifiedPermutohedral());    bilateral_lattices_[n]->init(bilateral_kernel_buffer_.get(), 5, num_pixels_);    // Calculate bilateral filter normalization factors.    Dtype* norm_output_data = bilateral_norms_.mutable_cpu_data() + bilateral_norms_.offset(n);    bilateral_lattices_[n]->compute(norm_output_data, norm_feed_.get(), 1);    for (int i = 0; i < num_pixels_; ++i) {      norm_output_data[i] = 1.f / (norm_output_data[i] + 1e-20f);    }  }  for (int i = 0; i < num_iterations_; ++i) {    meanfield_iterations_[i]->PrePass(this->blobs_, &bilateral_lattices_, &bilateral_norms_);    meanfield_iterations_[i]->Forward_cpu();  }}
1 0
原创粉丝点击