Caffe源码:conv_im2col

来源:互联网 发布:java 计算逻辑表达式 编辑:程序博客网 时间:2024/06/07 02:36

写在前头:一天下来到现在什么都没干,下午还来一烦心事,真心感到是浪费时间浪费生命,赶紧写篇博客压压惊。


因为现在在看caffe源码,目前正进行到卷积层的前向传播部分。因此本篇博客着重理解im2col.cpp中的im2col_cpu部分代码,在理解caffe中实现卷积的思想后Conv in caffe,对代码理解进行记录。
可以知道:卷积操作的输入featuremap(包括图像)是一个三维张量(CxHxW),C为通道数,Hfeaturemap的高,Wfeaturemap的宽。卷积核的参数为NxCxKxKN为卷积核的个数,C为通道数即要与输入featuremap一致,K为卷积核的高和宽.卷积的输出featuremap的大小则有卷积操作的参数pad、stride、dilation决定。
e.g:对单一通道的输入featuremap(28*28),pad=0,stride=1,dilation=1.

template <typename Dtype>void im2col_cpu(const Dtype* data_im, const int channels,    const int height, const int width, const int kernel_h, const int kernel_w,    const int pad_h, const int pad_w,    const int stride_h, const int stride_w,    const int dilation_h, const int dilation_w,    Dtype* data_col) {//输入各种参数  行优先储存  const int output_h = (height + 2 * pad_h -    (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;  const int output_w = (width + 2 * pad_w -    (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;//根据公式算出输出窗口的各个维度  const int channel_size = height * width;  for (int channel = channels; channel--; data_im += channel_size) {//依次将对每个通道进行操作     for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {//卷积核中某点的行      for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {//卷积核中某点的列        int input_row = -pad_h + kernel_row * dilation_h;// 卷积核初始卷积时位置中某点的行 用于定位输入图像中的行索引        for (int output_rows = output_h; output_rows; output_rows--) {//[0,24]在输入图像上行向的卷积次数:对于每一行,在下一个循环进行列循环 两个循环的次数即输出矩阵的一行output_h*output_w          if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {//该函数定义是:若a大于0且严格小于b,则返回真,否则返回假,该函数的作用是判断矩阵上某元的输出是否为pad的0            for (int output_cols = output_w; output_cols; output_cols--) {              *(data_col++) = 0;            }          } else {            int input_col = -pad_w + kernel_col * dilation_w;// 卷积核初始卷积时位置中某点的列 用于定位输入图像中的列索引            for (int output_col = output_w; output_col; output_col--) {//[0,24]在输入图像上列向的卷积次数:对于每一列,输出output_w              if (is_a_ge_zero_and_a_lt_b(input_col, width)) {                  *(data_col++) = data_im[input_row * width + input_col];// (input_row,input_col)既定位卷积核中的位置点,又用于索引输入图像的具体位置              } else {                *(data_col++) = 0;              }              input_col += stride_w;//由kernel_col初始化即卷积核中列位置,最多加output_w次,扫输入图像的一行,对输入图像逐渐向右扫  输入图像的列索引            }          }          input_row += stride_h;//由kernel_row初始化即卷积核中行位置,最多加output_h次,扫完输入图像的一行后,更新,对输入图像向下扫 输入图像的行索引         }      }//更新卷积核下一个点(input_row,input_col)    }  }}

这里写图片描述

原创粉丝点击