CRF as RNN的原理及Caffe实现
来源:互联网 发布:国内域名国外空间 编辑:程序博客网 时间:2024/05/01 01:35
CRF(Conditional Random Field)是图像分割中很常用的后处理算法。在《全卷积网络(FCN)与图像分割 》这篇博文中提到,FCN可以得到较好的分割结果,Chen, Liang-Chieh, et al. 2014在其基础上使用fully connected CRF得到了更好的效果,但是FCN的步骤和CRF的步骤是分开的。Zheng et al 2015将fully connected CRF表示成RNN的结构,与FCN连接在一起,可以同时训练FCN和CRF,使分割的准确率有了更多提高。
CRF as RNN的原理
CRF的能量函数包括一个数据项和平滑项。数据项是基于每个像素属于各个类别的概率,平滑项是基于像素之间的灰度值差异和空间距离。传统的CRF的平滑项只考虑相邻像素间的关联,而Fully connected CRF考虑了图像中任意两个像素之间的关联性。
其中
高斯函数的定义为:
条件随机场的概率函数为
由于直接计算概率函数P(X)比较麻烦,可以通过一个比较方便计算的概率函数Q(X)来近似得到P(X)。
公式(4)
将该步骤中的各个操作拆分,可以得到如下的算法:
该算法的输入为: 初始势能
该算法的每一次迭代分为5个步骤:
1, 信息传递。即使用m个滤波器分别对每一个类别l的概率图
2, 滤波结果加权相加。对每一个类别l的m个滤波结果根据权重
3,类别兼容性转换。对每一个类别l的概率图根据不同类别之间的兼容性矩阵
4, 加上数据项(一元项 Unary Potential)。
5, 归一化。对各像素所属不同类别l的概率归一化,这实际上是一个Softmax的过程。
Caffe实现
GitHub上可以找到CRF as RNN的源代码。主要有两个类:MultiStageMeanfieldLayer和MeanfieldIteration。其中MultiStageMeanfieldLayer将CRF的所有迭代步骤组装在一起形成一个Caffe层,迭代的每一步作为一个子层在MeanfieldIteration中实现。目前是基于cpu代码的实现,还没有在cuda上实现。
该实现中考虑了Bilateral filter 和Spatial filter两种滤波器,分别对应公式(3)中的两项。由于Fully connected CRF中的滤波操作要考虑整个图像的信息, 滤波器核的大小为整个图像的尺寸,因此滤波过程比较耗时,为了提高效率,Zheng et al使用了Andrew Adams et al所提出的基于Perutohedral Lttice的滤波实现: 实现Bilateral filter 的bilateral_lattices_和实现Spatial filter的spatial_lattice。
MeanfieldIteration的前向运算为:
/** * Forward pass during the inference. */template <typename Dtype>void MeanfieldIteration<Dtype>::Forward_cpu() { //------------------------------- Softmax normalization-------------------- softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_); //-----------------------------------Message passing----------------------- for (int n = 0; n < num_ /*number of images in the batch */; ++n) { //Spatial filtering Dtype* spatial_out_data = spatial_out_blob_.mutable_cpu_data() + spatial_out_blob_.offset(n); const Dtype* prob_input_data = prob_.cpu_data() + prob_.offset(n); spatial_lattice_->compute(spatial_out_data, prob_input_data, channels_, false); // Spatial filtering, Pixel-wise normalization. for (int channel_id = 0; channel_id < channels_; ++channel_id) { caffe_mul(num_pixels_, spatial_norm_->cpu_data(), spatial_out_data + channel_id * num_pixels_, spatial_out_data + channel_id * num_pixels_); } // Bilateral filtering Dtype* bilateral_out_data = bilateral_out_blob_.mutable_cpu_data() + bilateral_out_blob_.offset(n); (*bilateral_lattices_)[n]->compute(bilateral_out_data, prob_input_data, channels_, false); // Bilateral filtering, Pixel-wise normalization. for (int channel_id = 0; channel_id < channels_; ++channel_id) { caffe_mul(num_pixels_, bilateral_norms_->cpu_data() + bilateral_norms_->offset(n), bilateral_out_data + channel_id * num_pixels_, bilateral_out_data + channel_id * num_pixels_); } } caffe_set(count_, Dtype(0.), message_passing_.mutable_cpu_data()); // Add the weighted output of spatial filtering for (int n = 0; n < num_; ++n) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., this->blobs_[0]->cpu_data(), spatial_out_blob_.cpu_data() + spatial_out_blob_.offset(n), (Dtype) 0., message_passing_.mutable_cpu_data() + message_passing_.offset(n)); } // Add the weighted output of bilateral filtering for (int n = 0; n < num_; ++n) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., this->blobs_[1]->cpu_data(), bilateral_out_blob_.cpu_data() + bilateral_out_blob_.offset(n), (Dtype) 1., message_passing_.mutable_cpu_data() + message_passing_.offset(n)); } //--------------------------- Compatibility multiplication ---------------- //Result from message passing needs to be multiplied with compatibility values. for (int n = 0; n < num_; ++n) { caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels_, num_pixels_, channels_, (Dtype) 1., this->blobs_[2]->cpu_data(), message_passing_.cpu_data() + message_passing_.offset(n), (Dtype) 0., pairwise_.mutable_cpu_data() + pairwise_.offset(n)); } //------------------------- Adding unaries, normalization is left to the next iteration -------------- // Add unary sum_layer_->Forward(sum_bottom_vec_, sum_top_vec_);}
其中caffe_cpu_gemm是实现
void MeanfieldIteration::Backward_cpu()中倒序实现了上述个步骤的反向传播。
MultiStageMeanfieldLayer中为每一个迭代步骤分别创建了一个MeanfieldIteration层,其前向传播的代码为:
template <typename Dtype>void MultiStageMeanfieldLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { split_layer_bottom_vec_[0] = bottom[0]; split_layer_->Forward(split_layer_bottom_vec_, split_layer_top_vec_); // Initialize the bilateral lattices. bilateral_lattices_.resize(num_); for (int n = 0; n < num_; ++n) { compute_bilateral_kernel(bottom[2], n, bilateral_kernel_buffer_.get()); bilateral_lattices_[n].reset(new ModifiedPermutohedral()); bilateral_lattices_[n]->init(bilateral_kernel_buffer_.get(), 5, num_pixels_); // Calculate bilateral filter normalization factors. Dtype* norm_output_data = bilateral_norms_.mutable_cpu_data() + bilateral_norms_.offset(n); bilateral_lattices_[n]->compute(norm_output_data, norm_feed_.get(), 1); for (int i = 0; i < num_pixels_; ++i) { norm_output_data[i] = 1.f / (norm_output_data[i] + 1e-20f); } } for (int i = 0; i < num_iterations_; ++i) { meanfield_iterations_[i]->PrePass(this->blobs_, &bilateral_lattices_, &bilateral_norms_); meanfield_iterations_[i]->Forward_cpu(); }}
- CRF as RNN的原理及Caffe实现
- CRF as RNN 代码解读
- RGB图像语义分割前沿:CRF as RNN
- RNN的原理与TensorFlow代码实现
- RNN及LSTM的matlab实现
- 使用CRF++进行分词的原理和实现过程
- 深度学习(三十三)CRF as RNN语义分割-未完待续
- 深度学习(三十三)CRF as RNN语义分割-未完待续
- RNN递归神经网络的详细推导及C++实现
- TensorFlow (RNN)深度学习下 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 双向LSTM+CRF跑序列标注问题
- triplet loss的原理及caffe代码
- CRF分词的java实现
- CRF的开源实现
- crf的Python实现代码
- 第十一课 tensorflow RNN原理及解析
- TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载
- TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载
- caffe实现RNN(recursive Neural Network, recursive NN)
- 关于unity中BindChannels的理解
- 朴素贝叶斯法
- iOS面试题一
- 六步骤开发和发布自己的Android Studio类库
- angularjs1.0使用总结
- CRF as RNN的原理及Caffe实现
- 详解MySql优化步骤
- ActionListener的三种实现方法
- Oracle切换UNDO空间数据库后存储过程无法正常编译
- 处理一个关于binlog增量恢复很慢的问题
- valgrind 简介
- 文字渐渐显示效果
- redis总结
- 数据结构(王道)【线性表】【算法1.3】