代码阅读:R-FCN
来源:互联网 发布:淘宝实名认证手机 编辑:程序博客网 时间:2024/05/17 05:01
国际惯例:https://github.com/daijifeng001/R-FCN
这个 MATLAB 版本的代码中,RPN 部分是现成的,需要训练的只是 position-sensitive RoI pooling 那一块。我们也主要看这一块。
prototxt
# -------------- position-sensitive RoI pooling --------------
# Classification branch: PSROIPooling pools the "rfcn_cls" score maps over each
# RoI into output_dim channels x (group_size x group_size) bins; the AVE pooling
# layer that follows averages ("votes") the 7x7 bins into one score per channel.
layer {
  bottom: "rfcn_cls"
  bottom: "rois"
  top: "psroipooled_cls_rois"
  name: "psroipooled_cls_rois"
  type: "PSROIPooling"
  psroi_pooling_param {
    spatial_scale: 0.0625  # = 1/16; multiplies RoI coordinates to map them onto the feature map
    output_dim: 21         # presumably 20 object classes + background -- confirm against the dataset
    group_size: 7          # 7x7 position-sensitive bins per RoI
  }
}
layer {
  bottom: "psroipooled_cls_rois"
  top: "cls_score"
  name: "ave_cls_score_rois"
  type: "Pooling"
  pooling_param {
    pool: AVE
    kernel_size: 7  # equals group_size, so each RoI's 7x7 bins collapse to 1x1
    stride: 7
  }
}
# Box-regression branch: same scheme applied to the "rfcn_bbox" maps.
layer {
  bottom: "rfcn_bbox"
  bottom: "rois"
  top: "psroipooled_loc_rois"
  name: "psroipooled_loc_rois"
  type: "PSROIPooling"
  psroi_pooling_param {
    spatial_scale: 0.0625
    output_dim: 8          # presumably 4 box offsets x 2 (class-agnostic regression) -- confirm
    group_size: 7
  }
}
layer {
  bottom: "psroipooled_loc_rois"
  top: "bbox_pred"
  name: "ave_bbox_pred_rois"
  type: "Pooling"
  pooling_param {
    pool: AVE
    kernel_size: 7
    stride: 7
  }
}
PSROIPooling
这是作者自己加的一种pooling方法,我们来看怎么实现的吧。
caffe.proto 里添加了对应的 psroi_pooling_param 参数定义,包含 spatial_scale、output_dim、group_size 三个字段,即上面 prototxt 中用到的三个参数。
头文件,这个没啥好说的
/**
 * @brief Position-sensitive RoI pooling layer (R-FCN), registered as "PSROIPooling".
 *
 * Takes exactly two bottom blobs:
 *   - bottom[0]: the score maps to pool from;
 *   - bottom[1]: the RoIs to pool over.
 * Produces exactly one top blob holding the pooled values.
 */
template <typename Dtype>
class PSROIPoolingLayer : public Layer<Dtype> {
 public:
  explicit PSROIPoolingLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}

  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
                          const vector<Blob<Dtype>*>& top);
  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
                       const vector<Blob<Dtype>*>& top);

  virtual inline const char* type() const { return "PSROIPooling"; }

  // Blob-count contract: two inputs (score maps + RoIs), one output.
  virtual inline int MinBottomBlobs() const { return 2; }
  virtual inline int MaxBottomBlobs() const { return 2; }
  virtual inline int MinTopBlobs() const { return 1; }
  virtual inline int MaxTopBlobs() const { return 1; }

 protected:
  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                           const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
                           const vector<Blob<Dtype>*>& top);
  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
                            const vector<bool>& propagate_down,
                            const vector<Blob<Dtype>*>& bottom);
  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
                            const vector<bool>& propagate_down,
                            const vector<Blob<Dtype>*>& bottom);

  Dtype spatial_scale_;  // scale applied to RoI coordinates
  int output_dim_;       // number of output channels
  int group_size_;       // bins per side of the position-sensitive grid
  // Dimensions of the bottom[0] feature map.
  int channels_;
  int height_;
  int width_;
  // Output grid size per RoI.
  int pooled_height_;
  int pooled_width_;
  // Per output element: the input channel it was pooled from.
  Blob<int> mapping_channel_;
};
Forward_gpu:
/**
 * @brief GPU forward pass of position-sensitive RoI pooling.
 *
 * bottom[0]: position-sensitive score maps (e.g. "rfcn_cls" or "rfcn_bbox").
 * bottom[1]: RoIs, 5 values each: (batch_index, x1, y1, x2, y2).
 * top[0]:    pooled output; the kernel decodes each flat index as
 *            (roi, output channel, bin row, bin column).
 *
 * Also fills mapping_channel_ with the input channel each output element was
 * pooled from.
 */
template <typename Dtype>
void PSROIPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
                                           const vector<Blob<Dtype>*>& top) {
  const Dtype* bottom_data = bottom[0]->gpu_data();   // score maps
  const Dtype* bottom_rois = bottom[1]->gpu_data();   // RoI boxes
  Dtype* top_data = top[0]->mutable_gpu_data();
  int* mapping_channel_ptr = mapping_channel_.mutable_gpu_data();

  // Number of output elements = num_rois * output_dim * pooled_h * pooled_w;
  // the kernel processes each of them once.
  const int count = top[0]->count();

  // Reset outputs; -1 in mapping_channel means "no channel recorded".
  caffe_gpu_set(count, Dtype(0), top_data);
  caffe_gpu_set(count, -1, mapping_channel_ptr);

  // NOLINT_NEXT_LINE(whitespace/operators)
  PSROIPoolingForward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
      count, bottom_data, spatial_scale_, channels_, height_, width_,
      pooled_height_, pooled_width_, bottom_rois, output_dim_, group_size_,
      top_data, mapping_channel_ptr);
  CUDA_POST_KERNEL_CHECK;
}
PSROIPoolingForward:
/**
 * @brief Forward kernel of position-sensitive RoI pooling.
 *
 * Each iteration of CUDA_KERNEL_LOOP computes one output element
 * top_data[index]. The output is laid out as (n, ctop, ph, pw):
 *   index = ((n * output_dim + ctop) * pooled_height + ph) * pooled_width + pw
 * where n is the RoI, ctop the output channel and (ph, pw) the bin row/column.
 *
 * The value is the average of input channel
 *   c = (ctop * group_size + ph) * group_size + pw
 * over the bin's spatial window (0 for an empty bin). The channel used is
 * written to mapping_channel[index].
 *
 * @param nthreads      total number of output elements to compute
 * @param bottom_data   input score maps, indexed as (batch, channel, h, w)
 * @param spatial_scale factor mapping RoI coordinates onto the feature map
 * @param bottom_rois   RoIs, 5 values each: (batch_index, x1, y1, x2, y2)
 * @param top_data      pooled output
 * @param mapping_channel per-output record of the selected input channel
 */
template <typename Dtype>
__global__ void PSROIPoolingForward(
    const int nthreads,
    const Dtype* bottom_data,
    const Dtype spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const Dtype* bottom_rois,
    const int output_dim,
    const int group_size,
    Dtype* top_data,
    int* mapping_channel) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Decode the flat output index into (n, ctop, ph, pw).
    const int pw = index % pooled_width;                                   // bin column
    const int ph = (index / pooled_width) % pooled_height;                 // bin row
    const int ctop = (index / pooled_width / pooled_height) % output_dim;  // output channel
    const int n = index / pooled_width / pooled_height / output_dim;       // RoI index

    // Use a local pointer instead of advancing bottom_rois in place: the loop
    // may hand one thread several indices, and an in-place "+=" would then
    // accumulate offsets across iterations and read the wrong RoI.
    const Dtype* offset_bottom_rois = bottom_rois + n * 5;
    const int roi_batch_ind = static_cast<int>(offset_bottom_rois[0]);

    // Project the rounded box onto the feature map; +1 makes the end exclusive.
    const Dtype roi_start_w =
        static_cast<Dtype>(round(offset_bottom_rois[1])) * spatial_scale;
    const Dtype roi_start_h =
        static_cast<Dtype>(round(offset_bottom_rois[2])) * spatial_scale;
    const Dtype roi_end_w =
        static_cast<Dtype>(round(offset_bottom_rois[3]) + 1.) * spatial_scale;
    const Dtype roi_end_h =
        static_cast<Dtype>(round(offset_bottom_rois[4]) + 1.) * spatial_scale;

    // Force too-small RoIs to a minimum extent of 0.1 so bins never have zero size.
    const Dtype roi_width = max(roi_end_w - roi_start_w, static_cast<Dtype>(0.1));
    const Dtype roi_height = max(roi_end_h - roi_start_h, static_cast<Dtype>(0.1));

    // Height and width of one bin on the feature map.
    const Dtype bin_size_h = roi_height / static_cast<Dtype>(pooled_height);
    const Dtype bin_size_w = roi_width / static_cast<Dtype>(pooled_width);

    // Bin (ph, pw) covers [hstart, hend) x [wstart, wend): floor the start and
    // ceil the end so every cell the bin touches is included.
    int hstart = static_cast<int>(
        floor(static_cast<Dtype>(ph) * bin_size_h + roi_start_h));
    int wstart = static_cast<int>(
        floor(static_cast<Dtype>(pw) * bin_size_w + roi_start_w));
    int hend = static_cast<int>(
        ceil(static_cast<Dtype>(ph + 1) * bin_size_h + roi_start_h));
    int wend = static_cast<int>(
        ceil(static_cast<Dtype>(pw + 1) * bin_size_w + roi_start_w));

    // Clip to the input boundaries.
    hstart = min(max(hstart, 0), height);
    hend = min(max(hend, 0), height);
    wstart = min(max(wstart, 0), width);
    wend = min(max(wend, 0), width);
    const bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Position-sensitive channel selection. Using (ph, pw) directly as the group
    // coordinates assumes pooled_height == pooled_width == group_size
    // -- NOTE(review): confirm Reshape enforces this.
    const int gw = pw;
    const int gh = ph;
    const int c = (ctop * group_size + gh) * group_size + gw;

    // Average channel c of image roi_batch_ind over the bin. A local pointer is
    // used for the same reason as offset_bottom_rois above.
    const Dtype* offset_bottom_data =
        bottom_data + (roi_batch_ind * channels + c) * height * width;
    Dtype out_sum = 0;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        out_sum += offset_bottom_data[h * width + w];
      }
    }
    const Dtype bin_area = static_cast<Dtype>((hend - hstart) * (wend - wstart));
    top_data[index] = is_empty ? static_cast<Dtype>(0) : out_sum / bin_area;

    // Record which input channel fed this output element.
    mapping_channel[index] = c;
  }
}
阅读全文
0 0
- 代码阅读:R-FCN
- R-FCN阅读笔记
- R-CNN系列阅读笔记(5): R-FCN
- R-FCN算法及Caffe代码详解
- r-fcn
- R-FCN
- R-FCN
- R-FCN
- R-FCN
- 我读R-FCN
- R-FCN源代码解读
- R-FCN网络
- r-fcn论文
- R-FCN论文翻译
- R-FCN解读
- 【转】R-FCN
- FCN-阅读笔记-理解
- [论文阅读]R-FCN: Object Detection via Region-based Fully Convolutional Networks
- JEESZ-Redis分布式缓存安装和使用
- JavaWeb+Ueditor上传图片到项目外资源文件
- 递归与分治策略-2.11循环赛日程表
- gdb的多线程调试
- 浅谈JVM--《深入理解Java虚拟机》小小总结
- 代码阅读:R-FCN
- DataOutputStream 类 和DatainputStream类 的主要方法简单介绍,及代码演示(转)
- 681
- c++之static
- Vue.js实现可编辑的表格
- 解读今年的 Google IO 2017
- SQL查询50例
- easyui 自定义表单内容验证(汉字、字母、数字、邮箱、电话、邮编、身份证号等)
- OkHttp3的使用