Softmax层解析

来源:互联网 发布:飞狐交易师数据下载 编辑:程序博客网 时间:2024/06/05 04:28

这里我们简单介绍一下Caffe是如何实现Softmax层的,通常我们使用的是SoftmaxWithLossLayer,这里我们仅仅讲讲Caffe的SoftmaxLayer

定义输入

在Caffe的世界里,每一层的输入叫做Bottom,输出叫做Top,而Caffe的Forward就是通过Bottom计算Top的过程,而Backward这是通过Top_diff计算Bottom_diff的过程。

我们定义Bottom:x = {xi } (i=1 … n),Top:z = {zi} (i=1…n)
根据Softmax的公式eijej,我们可以得到:
zi=exinjexj
根据公式,我们可以通过Bottom计算出Top,具体的可以参考Caffe的SoftmaxLayer的forward_cpu的源代码,这里就赶紧进入Backward环节喽~

对于Backward,我们知道Top_diff,它在数学意义上就是Lossz,而同理Bottom_diff,在数学意义上是Lossx,而Backward的目的就是计算Bottom_diff

Caffe源码分析

Caffe源代码比较清晰,主要就是三步骤:
第一步,计算eq1 = dot(top_diff, top_data):

 scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,          bottom_diff + i * dim + k, inner_num_,          top_data + i * dim + k, inner_num_);

第二步,计算eq2 = top_diff - eq1:

caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);

第三步,计算eq2 * top_data:

caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);

公式推导

那么大家有木有想过为什么bottom_diff = (top_diff - dot(top_diff, top_data))*top_data?
我们来手把手推推公式就了解了
这里写图片描述
全部Backward代码如下:

template <typename Dtype>void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,    const vector<bool>& propagate_down,    const vector<Blob<Dtype>*>& bottom) {  const Dtype* top_diff = top[0]->cpu_diff();  const Dtype* top_data = top[0]->cpu_data();  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();  Dtype* scale_data = scale_.mutable_cpu_data();  int channels = top[0]->shape(softmax_axis_);  int dim = top[0]->count() / outer_num_;  caffe_copy(top[0]->count(), top_diff, bottom_diff);  for (int i = 0; i < outer_num_; ++i) {    // compute dot(top_diff, top_data) and subtract them from the bottom diff    for (int k = 0; k < inner_num_; ++k) {      scale_data[k] = caffe_cpu_strided_dot<Dtype>(channels,          bottom_diff + i * dim + k, inner_num_,          top_data + i * dim + k, inner_num_);    }    // subtraction    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, channels, inner_num_, 1,        -1., sum_multiplier_.cpu_data(), scale_data, 1., bottom_diff + i * dim);  }  // elementwise multiplication  caffe_mul(top[0]->count(), bottom_diff, top_data, bottom_diff);}
原创粉丝点击