本文件主要定义了 ROIPoolingLayer 类（对每个 RoI 在特征图上做最大池化的层）。
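ROIPoolingLayer 的受保护（protected）数据成员如下：
- int channels_、int height_、int width_：输入特征图 bottom[0] 的通道数、高度和宽度（在 Reshape 中设置）；
- int pooled_height_、int pooled_width_：每个 RoI 池化后输出网格的高和宽（来自 pooled_h / pooled_w 参数）；
- Dtype spatial_scale_：把 RoI 坐标映射到特征图坐标时所乘的缩放系数；
- Blob<int> max_idx_：记录每个输出位置所取最大值在其输入通道平面内的下标（空 bin 为 -1），供反向传播使用。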
LayerSetUp:
// Reads the ROI pooling hyper-parameters from the layer configuration:
// the output grid size (pooled_h x pooled_w) and spatial_scale, the factor
// applied to ROI coordinates before they index the input feature map.
//
// bottom/top are not used here; shape-dependent setup happens in Reshape().
// Aborts (via CHECK) if pooled_h or pooled_w is not strictly positive.
template <typename Dtype>
void ROIPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
                                        const vector<Blob<Dtype>*>& top) {
  // Bind by const reference instead of copying the protobuf message:
  // protobuf-generated accessors for embedded messages return a const
  // reference into layer_param_, which outlives this function.
  const ROIPoolingParameter& roi_pool_param =
      this->layer_param_.roi_pooling_param();
  CHECK_GT(roi_pool_param.pooled_h(), 0) << "pooled_h must be > 0";
  CHECK_GT(roi_pool_param.pooled_w(), 0) << "pooled_w must be > 0";
  pooled_height_ = roi_pool_param.pooled_h();
  pooled_width_ = roi_pool_param.pooled_w();
  spatial_scale_ = roi_pool_param.spatial_scale();
  LOG(INFO) << "Spatial scale: " << spatial_scale_;
}
Reshape:
// Derives all shapes from the current inputs.
//   bottom[0]: feature map; its channels/height/width are cached in
//              channels_/height_/width_ for the forward and backward passes.
//   bottom[1]: ROIs; one output sample is produced per ROI, so its num()
//              becomes the leading dimension of the outputs.
// Both top[0] and the argmax buffer max_idx_ are reshaped to
// (num_rois, channels_, pooled_height_, pooled_width_).
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
                                     const vector<Blob<Dtype>*>& top) {
  const Blob<Dtype>* const features = bottom[0];
  channels_ = features->channels();
  height_ = features->height();
  width_ = features->width();

  const int num_rois = bottom[1]->num();
  top[0]->Reshape(num_rois, channels_, pooled_height_, pooled_width_);
  max_idx_.Reshape(num_rois, channels_, pooled_height_, pooled_width_);
}
Forward_cpu:
// CPU forward pass of max ROI pooling.
//
// Inputs:
//   bottom[0] - feature map; indexed as (batch, channels_, height_, width_).
//   bottom[1] - ROIs, one per num() index. The first five values of each
//               ROI row are read as [batch_index, x1, y1, x2, y2]; the four
//               coordinates are multiplied by spatial_scale_ and rounded.
// Outputs:
//   top[0]    - (num_rois, channels_, pooled_height_, pooled_width_) maxima.
//   max_idx_  - same shape; for each output, the flat index h * width_ + w of
//               the winning element inside its channel plane, or -1 when the
//               bin is empty (its output value is then 0).
// Aborts (via CHECK) if an ROI's batch index is outside [0, batch size).
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
                                         const vector<Blob<Dtype>*>& top) {
  const Dtype* const features = bottom[0]->cpu_data();
  const Dtype* const rois = bottom[1]->cpu_data();
  const int roi_count = bottom[1]->num();
  const int batch_size = bottom[0]->num();
  const int out_count = top[0]->count();

  // Seed outputs with the lowest finite value so the first visited element
  // of a non-empty bin always wins; argmax starts as "no element".
  Dtype* const out = top[0]->mutable_cpu_data();
  caffe_set(out_count, Dtype(-FLT_MAX), out);
  int* const argmax = max_idx_.mutable_cpu_data();
  caffe_set(out_count, -1, argmax);

  for (int n = 0; n < roi_count; ++n) {
    const Dtype* const roi = rois + bottom[1]->offset(n);
    const int roi_batch_ind = roi[0];
    const int roi_start_w = round(roi[1] * spatial_scale_);
    const int roi_start_h = round(roi[2] * spatial_scale_);
    const int roi_end_w = round(roi[3] * spatial_scale_);
    const int roi_end_h = round(roi[4] * spatial_scale_);
    CHECK_GE(roi_batch_ind, 0);
    CHECK_LT(roi_batch_ind, batch_size);

    // Inverted/degenerate boxes are treated as at least 1x1.
    const int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    const int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    const Dtype bin_size_h =
        static_cast<Dtype>(roi_height) / static_cast<Dtype>(pooled_height_);
    const Dtype bin_size_w =
        static_cast<Dtype>(roi_width) / static_cast<Dtype>(pooled_width_);

    for (int c = 0; c < channels_; ++c) {
      // Channel plane of the source image, and of this ROI's outputs.
      const Dtype* const plane = features + bottom[0]->offset(roi_batch_ind, c);
      Dtype* const out_plane = out + top[0]->offset(n, c);
      int* const argmax_plane = argmax + max_idx_.offset(n, c);

      for (int ph = 0; ph < pooled_height_; ++ph) {
        // Rows covered by this bin row, shifted into the ROI and clipped.
        int row_lo = static_cast<int>(floor(static_cast<Dtype>(ph) * bin_size_h));
        int row_hi = static_cast<int>(ceil(static_cast<Dtype>(ph + 1) * bin_size_h));
        row_lo = min(max(row_lo + roi_start_h, 0), height_);
        row_hi = min(max(row_hi + roi_start_h, 0), height_);

        for (int pw = 0; pw < pooled_width_; ++pw) {
          // Columns covered by this bin, shifted into the ROI and clipped.
          int col_lo = static_cast<int>(floor(static_cast<Dtype>(pw) * bin_size_w));
          int col_hi = static_cast<int>(ceil(static_cast<Dtype>(pw + 1) * bin_size_w));
          col_lo = min(max(col_lo + roi_start_w, 0), width_);
          col_hi = min(max(col_hi + roi_start_w, 0), width_);

          const int bin = ph * pooled_width_ + pw;
          if (row_hi <= row_lo || col_hi <= col_lo) {
            // Bin falls entirely outside the feature map.
            out_plane[bin] = 0;
            argmax_plane[bin] = -1;
            continue;
          }

          Dtype best = out_plane[bin];
          int best_idx = argmax_plane[bin];
          for (int h = row_lo; h < row_hi; ++h) {
            for (int w = col_lo; w < col_hi; ++w) {
              const int idx = h * width_ + w;
              if (plane[idx] > best) {
                best = plane[idx];
                best_idx = idx;
              }
            }
          }
          out_plane[bin] = best;
          argmax_plane[bin] = best_idx;
        }
      }
    }
  }
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
Backward_gpu:
// GPU backward pass of max ROI pooling.
//
// Gradient flows only into the feature map bottom[0]; the ROI input
// bottom[1] is read (for box geometry) but receives no gradient. When
// propagate_down[0] is false nothing is done. Otherwise bottom[0]'s diff is
// zeroed and then filled by ROIPoolBackward. The kernel is launched over all
// bottom[0]->count() elements, and each element gathers the top gradients
// whose recorded argmax (max_idx_, from the forward pass) points at it.
template <typename Dtype>
void ROIPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[0]) {
    const Dtype* const rois_gpu = bottom[1]->gpu_data();
    const Dtype* const grad_out = top[0]->gpu_diff();
    Dtype* const grad_in = bottom[0]->mutable_gpu_diff();
    const int n_elements = bottom[0]->count();
    caffe_gpu_set(n_elements, Dtype(0.), grad_in);
    const int* const argmax_gpu = max_idx_.gpu_data();

    const int n_rois = top[0]->num();
    ROIPoolBackward<Dtype>
        <<<CAFFE_GET_BLOCKS(n_elements), CAFFE_CUDA_NUM_THREADS>>>(
            n_elements, grad_out, argmax_gpu, n_rois, spatial_scale_,
            channels_, height_, width_, pooled_height_, pooled_width_,
            grad_in, rois_gpu);
    CUDA_POST_KERNEL_CHECK;
  }
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
// Backward kernel for max ROI pooling, written as a gather.
//
// Each index in [0, nthreads) addresses one bottom element (n, c, h, w),
// decomposed with w fastest and n slowest. Its gradient is the sum of
// top_diff over every (roi, c, ph, pw) that satisfies all of the following:
//   - the ROI's batch index equals n,
//   - the ROI's scaled, rounded box contains (h, w) (bounds inclusive),
//   - the recorded argmax equals h * width + w.
// The sum overwrites bottom_diff[index].
//
// bottom_rois is read with a fixed stride of 5 values per ROI, interpreted
// as [batch_index, x1, y1, x2, y2]; coordinates are multiplied by
// spatial_scale. top_diff/argmax_data are indexed as
// (num_rois, channels, pooled_height, pooled_width).
template <typename Dtype>
__global__ void ROIPoolBackward(const int nthreads, const Dtype* top_diff,
    const int* argmax_data, const int num_rois, const Dtype spatial_scale,
    const int channels, const int height, const int width,
    const int pooled_height, const int pooled_width, Dtype* bottom_diff,
    const Dtype* bottom_rois) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Flat index -> (n, c, h, w).
    const int w = index % width;
    const int h = (index / width) % height;
    const int c = (index / width / height) % channels;
    const int n = index / width / height / channels;
    // Argmax values are plane-local, so compare against this in-plane index.
    const int target = h * width + w;

    Dtype gradient = 0;
    for (int r = 0; r < num_rois; ++r) {
      const Dtype* roi = bottom_rois + r * 5;
      const int roi_batch_ind = roi[0];
      if (roi_batch_ind != n) {
        continue;  // ROI belongs to a different image
      }

      const int x0 = round(roi[1] * spatial_scale);
      const int y0 = round(roi[2] * spatial_scale);
      const int x1 = round(roi[3] * spatial_scale);
      const int y1 = round(roi[4] * spatial_scale);
      if (w < x0 || w > x1 || h < y0 || h > y1) {
        continue;  // (h, w) lies outside this ROI's box
      }

      const int plane = (r * channels + c) * pooled_height * pooled_width;
      const Dtype* roi_top_diff = top_diff + plane;
      const int* roi_argmax = argmax_data + plane;

      // Same bin geometry as the forward pass (degenerate boxes >= 1x1).
      const int roi_width = max(x1 - x0 + 1, 1);
      const int roi_height = max(y1 - y0 + 1, 1);
      const Dtype bin_size_h =
          static_cast<Dtype>(roi_height) / static_cast<Dtype>(pooled_height);
      const Dtype bin_size_w =
          static_cast<Dtype>(roi_width) / static_cast<Dtype>(pooled_width);

      // Bins whose window may contain (h, w), clipped to the pooled grid.
      int ph_lo = floor(static_cast<Dtype>(h - y0) / bin_size_h);
      int ph_hi = ceil(static_cast<Dtype>(h - y0 + 1) / bin_size_h);
      int pw_lo = floor(static_cast<Dtype>(w - x0) / bin_size_w);
      int pw_hi = ceil(static_cast<Dtype>(w - x0 + 1) / bin_size_w);
      ph_lo = min(max(ph_lo, 0), pooled_height);
      ph_hi = min(max(ph_hi, 0), pooled_height);
      pw_lo = min(max(pw_lo, 0), pooled_width);
      pw_hi = min(max(pw_hi, 0), pooled_width);

      for (int ph = ph_lo; ph < ph_hi; ++ph) {
        for (int pw = pw_lo; pw < pw_hi; ++pw) {
          const int bin = ph * pooled_width + pw;
          if (roi_argmax[bin] == target) {
            gradient += roi_top_diff[bin];
          }
        }
      }
    }
    bottom_diff[index] = gradient;
  }
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75