Caffe 代码解读之全连接层concat layer
来源:互联网 发布:会读字的软件 编辑:程序博客网 时间:2024/06/06 11:43
今天,我们看一下caffe的拼接层,即将两个或多个layer进行拼接。
首先,看一下caffe官方文档。
同其他layer一样,分为setup、reshape、Forward_cpu、Backward_cpu。
//concat_layer 用来实现两个或者多个blob的链接,即多输入一输出//支持在num 维度上的链接(concat_dim = 0 : (n1+n2+...+nk)∗c∗h∗w )//和channel维度上的链接(concat_dim = 1 : n∗(c1+c2+...+ck)∗h∗w)。//axis ,dim :0 为 num 维度链接,1 为 channel 维度链接//这里需要给出axis或concat_dimtemplate <typename Dtype>void ConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { const ConcatParameter& concat_param = this->layer_param_.concat_param(); CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim())) << "Either axis or concat_dim should be specified; not both.";}template <typename Dtype>void ConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { //获取axis,确定拼接哪一维度 const int num_axes = bottom[0]->num_axes(); const ConcatParameter& concat_param = this->layer_param_.concat_param(); //以下都在获取、判断axis的维度 if (concat_param.has_concat_dim()) { concat_axis_ = static_cast<int>(concat_param.concat_dim()); // Don't allow negative indexing for concat_dim, a uint32 -- almost // certainly unintended. CHECK_GE(concat_axis_, 0) << "casting concat_dim from uint32 to int32 " << "produced negative result; concat_dim must satisfy " << "0 <= concat_dim < " << kMaxBlobAxes; CHECK_LT(concat_axis_, num_axes) << "concat_dim out of range."; } else { concat_axis_ = bottom[0]->CanonicalAxisIndex(concat_param.axis()); } // Initialize with the first blob. //这里有一点需要解释,可以看到,bottom类型为 vector<Blob<Dtype>*>,这里只需要使用bottom[0] //给shape赋值就好,其实botom本身就是一个Blob的vector //比如:我要将两个layer拼接,那么久有bottom[0]以及bottom[1] vector<int> top_shape = bottom[0]->shape(); //concat_axis_ = 0 : num_concats_=num;concat_axis_ = 1 : num_concats_=num x channel; num_concats_ = bottom[0]->count(0, concat_axis_); //concat_axis_ = 0 : concat_input_size_=channel x height x width; //concat_axis_ = 1 : concat_input_size_=height x width; concat_input_size_ = bottom[0]->count(concat_axis_ + 1); int bottom_count_sum = bottom[0]->count(); //检测num_axes拼接的层是否相同,num_axes为维度信息 for (int i = 1; i < bottom.size(); ++i) { CHECK_EQ(num_axes, bottom[i]->num_axes()) << "All inputs must have the same #axes."; for (int j = 0; j < num_axes; ++j) { if (j == concat_axis_) { continue; } CHECK_EQ(top_shape[j], bottom[i]->shape(j)) << "All inputs must have the same shape, except at concat_axis."; } bottom_count_sum += bottom[i]->count(); top_shape[concat_axis_] += bottom[i]->shape(concat_axis_); } top[0]->Reshape(top_shape); CHECK_EQ(bottom_count_sum, top[0]->count());}
1、这里有一点需要解释,可以看到,bottom类型为 vector blob,这里只需要使用bottom[0]给shape赋值就好,其实bottom本身就是一个Blob的vector。
2、CHECK_**,这里给小白们解释一下,就是判断是否相等、小于、大于
3、 count,这看到有好多的count函数,这些函数在blob层实现,解释如下:
inline int count() const { return count_; } /** * @brief Compute the volume of a slice; i.e., the product of dimensions * among a range of axes. * * @param start_axis The first axis to include in the slice. * * @param end_axis The first axis to exclude from the slice. */ inline int count(int start_axis, int end_axis) const { CHECK_LE(start_axis, end_axis); CHECK_GE(start_axis, 0); CHECK_GE(end_axis, 0); CHECK_LE(start_axis, num_axes()); CHECK_LE(end_axis, num_axes()); int count = 1; for (int i = start_axis; i < end_axis; ++i) { count *= shape(i); } return count; } /** * @brief Compute the volume of a slice spanning from a particular first * axis to the final axis. * * @param start_axis The first axis to include in the slice. */ inline int count(int start_axis) const { return count(start_axis, num_axes()); }
前向传播就是layer的拼接
template <typename Dtype>void ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { Dtype* top_data = top[0]->mutable_cpu_data(); int offset_concat_axis = 0; const int top_concat_axis = top[0]->shape(concat_axis_); //遍历所有输入bottom for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->cpu_data(); const int bottom_concat_axis = bottom[i]->shape(concat_axis_); //把 各个bottom data 拷贝到输出 top data 的对应位置 for (int n = 0; n < num_concats_; ++n) { //case 0:num x channel x h x w;case 1: channel x h x w //case 0:bottom_data + n x num x channel x h x w ; //case 1:bottom_data + n x channel x h x w ; caffe_copy(bottom_concat_axis * concat_input_size_, bottom_data + n * bottom_concat_axis * concat_input_size_, top_data + (n * top_concat_axis + offset_concat_axis) * concat_input_size_); } offset_concat_axis += bottom_concat_axis; }}
反向传播,就是layer层之间diff和data的传播
//反向传播就是对每一个bottom的 diff 做和 data 相同的链接template <typename Dtype>void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) { const Dtype* top_diff = top[0]->cpu_diff(); int offset_concat_axis = 0; const int top_concat_axis = top[0]->shape(concat_axis_); for (int i = 0; i < bottom.size(); ++i) { if (!propagate_down[i]) { continue; } Dtype* bottom_diff = bottom[i]->mutable_cpu_diff(); const int bottom_concat_axis = bottom[i]->shape(concat_axis_); for (int n = 0; n < num_concats_; ++n) { caffe_copy(bottom_concat_axis * concat_input_size_, top_diff + (n * top_concat_axis + offset_concat_axis) * concat_input_size_, bottom_diff + n * bottom_concat_axis * concat_input_size_); } offset_concat_axis += bottom_concat_axis; }}
0 1
- Caffe 代码解读之全连接层concat layer
- Caffe 代码解读之全连接层 inner product layer
- caffe全连接层原理解读
- Caffe 代码解读之 softmax layer
- caffe之(四)全连接层
- [Caffe]: 关于concat layer
- caffe源码 全连接层
- 【Caffe代码解析】Layer网络层
- 【Caffe代码解析】Layer网络层
- caffe的concat层
- concat层 --caffe
- caffe源码解析之Layer层(1)
- caffe源码解析之Layer层(1)
- caffe layer层详解
- Caffe源码阅读(1) 全连接层
- Caffe源码阅读(1) 全连接层
- caffe层解读系列——slice和concat实现MultiTask
- caffe层解读系列——slice和concat实现MultiTask
- Myeclipse中启动tomcat 端口被占用问题
- C#中String与StringBuilder的区别
- 通过代理更新UITableHeaderFooterView某一行数据时,如何获取更新哪一行的
- 这才刚刚开始
- mysql中日期时间型解析
- Caffe 代码解读之全连接层concat layer
- oracle11g64位安装和32位plsql的安装使用
- c# lock (obj) 与 lock (this) 区别
- 能否始终保持如初学者般的热情、专注,决定了在做某件事时能走多远,能做多好。
- 第三周项目4:穷举法解决组合问题
- 剑指offer-复杂链表的复制
- Android Studio如何使用Git提交代码到GitHub和OsChina并解决冲突
- Django实战(20):分页(Pagination)
- Array--List--ArrayList 三者的区别(一)——引言篇