Caffe代码阅读笔记(1)

来源:互联网 发布:jsp端口更改 编辑:程序博客网 时间:2024/04/30 11:53

想看看卷积层的具体实现细节。于是从main函数一路追到卷积,就像执行了N次递归一样,到现在终于递归到最底层函数。

im2col_cpu函数这段代码涉及到卷积层的实现,我读了很久才完全读懂。特记录如下:

template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
    const int height, const int width, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    Dtype* data_col) {
  const int output_h = (height + 2 * pad_h -
    (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int output_w = (width + 2 * pad_w -
    (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
  const int channel_size = height * width;
  for (int channel = channels; channel--; data_im += channel_size) {
    for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
      for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
        int input_row = -pad_h + kernel_row * dilation_h;
        for (int output_rows = output_h; output_rows; output_rows--) {
          if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
            for (int output_cols = output_w; output_cols; output_cols--) {
              *(data_col++) = 0;
            }
          } else {
            int input_col = -pad_w + kernel_col * dilation_w;
            for (int output_col = output_w; output_col; output_col--) {
              if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
                *(data_col++) = data_im[input_row * width + input_col];
              } else {
                *(data_col++) = 0;
              }
              input_col += stride_w;
            }
          }
          input_row += stride_h;
        }
      }
    }
  }
}

im2col函数的功能是把输入的多个通道的image“展开”成column。此处的column其实并不是真的是一列,而是与二维滤波kernel相对应的多张连续存储的以卷积输出的feature map的宽和高为大小的待参与二维卷积的输入图像数据。其中,滤波kernel的每一个系数对应一个二维输入图像数据的矩阵,这个二维矩阵是考虑到kernel_size、padding_size、stride、dilation这些因素后从输入图像(或输入feature map)中抽取出来的。例如:kernel大小3x3,意味着有9个系数,那么就有9个待参与滤波的输入图像数据二维矩阵。

其中output_w和output_h的计算我想了好久才想通。

以宽度为例。首先看一下几个参数:

kernel_w是指kernel的宽度。一般来说kernel宽度都是奇数(中间的那个点代表当前像素),那么我们可想而知,若不做padding、也先不考虑stride和dilation的话,滤波之后的图像宽度(output_w)应该等于width-(kernel_w-1),即左边去掉(kernel_w-1)/2个像素,右边也去掉(kernel_w-1)/2个像素。若kernel宽度不是奇数,那么左边去掉的像素个数等于kernel模板中当前像素左边的像素个数,那么右边去掉的像素个数等于kernel模板中当前像素右边的像素个数,两者加在一起还是等于kernel_w-1。

pad_w为扩边的像素个数,图像左右都要扩边pad_w个点。从代码来看,padding扩边时是在边上补零而不是像图像处理那样复制或镜像。pad_w一般等于(kernel_w-1)/2,假设kernel_w是奇数的话。

stride_w为相邻两次滤波当前像素的间隔。考虑stride时,输出图像宽度会变得比较复杂,后面详细讲。

dilation_w膨胀系数没有研究过,看代码应该是把滤波窗口放大dilation_w倍。

回过头来看output_w的代码:

const int output_w = (width + 2 * pad_w -  (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

首先考虑stride_w = 1时的情况:左右各扩边pad_w,因此图像宽度+2*pad_w,这个好理解。然后先考虑dilation_w=1的情况,那么公式就简化成了:

const int output_w = (width + 2 * pad_w -  ((kernel_w - 1) + 1))  + 1;

其中减去((kernel_w - 1)+1)之后再加1可以这样理解:一开始减去整个kernel_w,相当于多减了一个1,所以后面再加回来。

然后考虑stride_w大于1的情况:在除以stride_w前,width减去  (dilation_w * (kernel_w - 1) + 1)时仍然是多减了一个1(dilation相当于把kernel中除当前像素之外的宽度放大了N倍,放大之后再加1即等于膨胀之后的kernel窗口宽度)。但这不要紧:假设width' = (width + 2 * pad_w -  (dilation_w * (kernel_w - 1) + 1)),即分子部分。case1: 若width'除以stride_w能除尽,则相当于width'+1(即把那个多减掉的1加回来)除以stride_w的余数是1,表示跳过stride的最后一个像素点,且这个点属于滤波后的结果,因此width'/stride_w + 1正好等于跳点滤波后的输出像素点个数;case2: 若width' 除以stride_w除不尽,则说明跳过最后一个stride后至少还存在一个滤波后的像素点(不管是否把那个多减掉的1加回来都是一样),剩余的点由于会被stride跳过所有都是无效点,因此width'/stride_w + 1也等于跳点滤波后的输出像素点个数。

至此,output_w/output_h的计算公式就能完全解释清楚了。

最后再看一下这几层嵌套的for循环所代表的顺序:

最外层的循环表示输入数据的通道数。中间两层循环是以kernel_w和kernel_h为循环次数。最里面的两层循环是以output_w/output_h为循环次数。

为什么是这样的顺序?我的理解是:做卷积运算时,卷积模板以滑动窗的形式扫过输入图像的每一个像素,卷积模板的每个系数相对于模板中央(当前像素所在的位置)的偏移不同,则与其相乘的像素点在图像中位置偏移也不同。此代码的思想是,把卷积模板中每一个系数在做卷积运算时将会与之相乘的像素点从输入图像中抽取出来(考虑补边、stride、膨胀等不同参数)、以输出图像的尺寸组织在一起存成一个矩阵,有多少个系数就根据系数在卷积模板中的位置偏移抽取出多少个输入像素数据矩阵。当后续需要进行卷积滤波时,只需要同时从这些输入像素数据矩阵取一个数出来分别与对应的系数相乘即可。此种方法虽然会导致输入图像数据被复制很多份、导致内存占用较大,但也便于SIMD指令或GPU并行计算(对应同一个系数位置的多个数据同时与这个系数相乘)。

关于“抽取”的举例说明:

代码中:

int input_col = -pad_w + kernel_col * dilation_w; //表示的是加入系数位置偏移后的水平起始位置,若起始位置小于0或大于图像宽度,则表示位置在图像之外,故取值为0;否则即抽取对应位置的输入像素数据。

input_col += stride_w; //表示跳过一个stride,得到下一次抽取的位置。


至此,整段代码应该解释得比较清楚了。特记笔记于此以备忘。

0 0
原创粉丝点击