DeepLearning-Xavier在caffe中的实现

来源:互联网 发布:大唐电信数据所地址 编辑:程序博客网 时间:2024/05/16 19:22

我们来看一下caffe中具体是怎样实现的,代码位于include/caffe/filler.hpp文件中。

template <typename Dtype>class XavierFiller : public Filler<Dtype> { public:  explicit XavierFiller(const FillerParameter& param)      : Filler<Dtype>(param) {}  virtual void Fill(Blob<Dtype>* blob) {    CHECK(blob->count());    int fan_in = blob->count() / blob->num();    int fan_out = blob->count() / blob->channels();    Dtype n = fan_in;  // default to fan_in    if (this->filler_param_.variance_norm() ==        FillerParameter_VarianceNorm_AVERAGE) {      n = (fan_in + fan_out) / Dtype(2);    } else if (this->filler_param_.variance_norm() ==        FillerParameter_VarianceNorm_FAN_OUT) {      n = fan_out;    }    Dtype scale = sqrt(Dtype(3) / n);    caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,        blob->mutable_cpu_data());    CHECK_EQ(this->filler_param_.sparse(), -1)         << "Sparsity not supported by this Filler.";  }};
由上面可以看出,caffe的Xavier实现有三种选择:

(1) 默认情况,方差只考虑输入个数: 
这里写图片描述

(2) FillerParameter_VarianceNorm_FAN_OUT,方差只考虑输出个数: 
这里写图片描述

(3) FillerParameter_VarianceNorm_AVERAGE,方差同时考虑输入和输出个数: 
这里写图片描述