caffe代码阅读3:Filler的实现

来源:互联网 发布:红豆薏米粉的品牌 知乎 编辑:程序博客网 时间:2024/06/08 08:23

一、Filler的作用简介

Filler层的作用实际上就是根据proto中给出的参数对权重进行初始化,初始化的方式有很多种,分别为常量初始化(constant)、高斯分布初始化(gaussian)、positive_unitball初始化、均匀分布初始化(uniform)、xavier初始化、msra初始化、双线性初始化(bilinear)这么几种。

二、Filler类的详细介绍

首先了解一下Filler类的第一个函数:该函数把整个Filler类一下子就看明白了
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. Filler<Dtype>* GetFiller(const FillerParameter& param) {  
  3.   const std::string& type = param.type();  
  4.   if (type == "constant") {  
  5.     return new ConstantFiller<Dtype>(param);  
  6.   } else if (type == "gaussian") {  
  7.     return new GaussianFiller<Dtype>(param);  
  8.   } else if (type == "positive_unitball") {  
  9.     return new PositiveUnitballFiller<Dtype>(param);  
  10.   } else if (type == "uniform") {  
  11.     return new UniformFiller<Dtype>(param);  
  12.   } else if (type == "xavier") {  
  13.     return new XavierFiller<Dtype>(param);  
  14.   } else if (type == "msra") {  
  15.     return new MSRAFiller<Dtype>(param);  
  16.   } else if (type == "bilinear") {  
  17.     return new BilinearFiller<Dtype>(param);  
  18.   } else {  
  19.     CHECK(false) << "Unknown filler name: " << param.type();  
  20.   }  
  21.   return (Filler<Dtype>*)(NULL);  

根据给定的参数获取对应的Filler,由该段代码可以看出proto文件里面对于权重可以有哪些指定的初始化方式。


1)基类Filler

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. class Filler {  
  3.  public:  
  4.  // 构造函数  
  5.   explicit Filler(const FillerParameter& param) : filler_param_(param) {}  
  6.   // 析构函数,并且是虚函数  
  7.   virtual ~Filler() {}  
  8.   // 纯虚函数,继承的子类必须要实现  
  9.   virtual void Fill(Blob<Dtype>* blob) = 0;  
  10.  protected:  
  11.   FillerParameter filler_param_;  
  12. };  // class Filler  

2)继承Filler的类

2-1 常量初始化类

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. class ConstantFiller : public Filler<Dtype> {  
  3.  public:  
  4.   explicit ConstantFiller(const FillerParameter& param)  
  5.       : Filler<Dtype>(param) {}  
  6.   virtual void Fill(Blob<Dtype>* blob) {  
  7.     // 获取数据指针  
  8.     Dtype* data = blob->mutable_cpu_data();  
  9.     // 获取数据长度  
  10.     const int count = blob->count();  
  11.     // 获取常量初始化的常数值  
  12.     const Dtype value = this->filler_param_.value();  
  13.     CHECK(count);  
  14.     for (int i = 0; i < count; ++i) {  
  15.       data[i] = value;//对于每一个元素都初始化为常数值  
  16.     }  
  17.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  18.          << "Sparsity not supported by this Filler.";  
  19.   }  
  20. };  

2-2 均匀分布初始化类

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. class UniformFiller : public Filler<Dtype> {  
  3.  public:  
  4.   explicit UniformFiller(const FillerParameter& param)  
  5.       : Filler<Dtype>(param) {}  
  6.   virtual void Fill(Blob<Dtype>* blob) {  
  7.     // 检查blob中的元素是否为0  
  8.     CHECK(blob->count());  
  9.     // 调用caffe_rng_uniform进行初始化  
  10.     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),  
  11.         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());  
  12.     // 均匀分布初始化是不支持稀疏特性的  
  13.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  14.          << "Sparsity not supported by this Filler.";  
  15.   }  
  16. };  

2-3 高斯分布初始化类(支持稀疏特性)

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. class GaussianFiller : public Filler<Dtype> {  
  3.  public:  
  4.   explicit GaussianFiller(const FillerParameter& param)  
  5.       : Filler<Dtype>(param) {}  
  6.   virtual void Fill(Blob<Dtype>* blob) {  
  7.     Dtype* data = blob->mutable_cpu_data();  
  8.     CHECK(blob->count());  
  9.     // 调用caffe_rng_gaussian初始化、其中输入了高斯分布的均值和标准差  
  10.     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),  
  11.         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());  
  12.     int sparse = this->filler_param_.sparse();  
  13.     // 检查sparse > -1  
  14.     CHECK_GE(sparse, -1);  
  15.     if (sparse >= 0) {//  如果启用稀疏的话  
  16.       // Sparse initialization is implemented for "weight" blobs; i.e. matrices.  
  17.       // These have num == channels == 1; width is number of inputs; height is  
  18.       // number of outputs.  The 'sparse' variable specifies the mean number  
  19.       // of non-zero input weights for a given output.  
  20.       CHECK_GE(blob->num_axes(), 1);  
  21.       // 假设权重的形状是 输出单元个数 X输入单元个数  
  22.       // blob->shape(0) = 输出单元的个数  
  23.       const int num_outputs = blob->shape(0);  
  24.       // 不为0的概率 = 1/输出单元个数  
  25.       // 那么为0的概率= 1 - 1/输出单元个数  
  26.       Dtype non_zero_probability = Dtype(sparse) / Dtype(num_outputs);  
  27.       // 新建一个rand_vec,用户存放伯努利分布(二项分布)所生成的值  
  28.       rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));  
  29.       int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());  
  30.       caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);  
  31.       for (int i = 0; i < blob->count(); ++i) {  
  32.         data[i] *= mask[i];// 每一个数据元素都与生成的二项分布的样本值相乘  
  33.       }  
  34.     }  
  35.   }  
  36.   
  37.  protected:  
  38.   shared_ptr<SyncedMemory> rand_vec_;  
  39. };  

2-4 PositiveUnitballFiller初始化

不懂的可以看http://math.stackexchange.com/questions/520002/unit-ball-with-p-norm
相当于是一个单位球
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. // PositiveUnitballFiller首先用均匀分布填充W  
  2. // 然后将W中的元素按行求和,然后该行每一个的元素都除以该行的和  
  3. template <typename Dtype>  
  4. class PositiveUnitballFiller : public Filler<Dtype> {  
  5.  public:  
  6.   explicit PositiveUnitballFiller(const FillerParameter& param)  
  7.       : Filler<Dtype>(param) {}  
  8.   virtual void Fill(Blob<Dtype>* blob) {  
  9.     Dtype* data = blob->mutable_cpu_data();  
  10.     DCHECK(blob->count());// 我很奇怪为啥这里用DCHECK  
  11.     // 先填充均匀分布到权重  
  12.     caffe_rng_uniform<Dtype>(blob->count(), 0, 1, blob->mutable_cpu_data());  
  13.     // We expect the filler to not be called very frequently, so we will  
  14.     // just use a simple implementation  
  15.     // count / num = 输入的维度  
  16.     int dim = blob->count() / blob->num();  
  17.     CHECK(dim);// 检查输入维度是否小于0  
  18.     for (int i = 0; i < blob->num(); ++i) {// 遍历隐藏单元的个数(或者是输出单元的个数)  
  19.       Dtype sum = 0;  
  20.       for (int j = 0; j < dim; ++j) {  
  21.         sum += data[i * dim + j];//sum += data[i][j] 也就是说要按行求和  
  22.       }  
  23.       for (int j = 0; j < dim; ++j) {  
  24.         data[i * dim + j] /= sum;// 每一行都除以该行的和  
  25.       }  
  26.     }  
  27.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  28.          << "Sparsity not supported by this Filler.";  
  29.   }  
  30. };  

2-5 XavierFiller初始化(用于卷积核)

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. // 这里不明白的就是shape (num, a, b, c) where a * b * c = fan_in and num * b * c = fan_out  
  2.  // 扇入和扇出的定义了  
  3. // 感谢王峰,后来才知道b*c=kernel size  
  4. // a是输入的channel  
  5. // num是输出的channel  
  6. template <typename Dtype>  
  7. class XavierFiller : public Filler<Dtype> {  
  8.  public:  
  9.   explicit XavierFiller(const FillerParameter& param)  
  10.       : Filler<Dtype>(param) {}  
  11.   virtual void Fill(Blob<Dtype>* blob) {  
  12.     CHECK(blob->count());  
  13.     int fan_in = blob->count() / blob->num();  
  14.     int fan_out = blob->count() / blob->channels();  
  15.     Dtype n = fan_in;  // default to fan_in  
  16.     if (this->filler_param_.variance_norm() ==// 如果参数里面定义了方差归一化则n = 扇入+扇出  
  17.         FillerParameter_VarianceNorm_AVERAGE) {  
  18.       n = (fan_in + fan_out) / Dtype(2);  
  19.     } else if (this->filler_param_.variance_norm() ==  
  20.         FillerParameter_VarianceNorm_FAN_OUT) {  
  21.       n = fan_out;  
  22.     }  
  23.     Dtype scale = sqrt(Dtype(3) / n);// scale = \frac{sqrt{3}}{n}  
  24.     // 然后用[-scale,scale]的均匀分布初始化  
  25.     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,  
  26.         blob->mutable_cpu_data());  
  27.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  28.          << "Sparsity not supported by this Filler.";  
  29.   }  
  30. };  

2-6 MSRAFiller初始化方式(用于卷积核)

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. class MSRAFiller : public Filler<Dtype> {  
  3.  public:  
  4.   explicit MSRAFiller(const FillerParameter& param)  
  5.       : Filler<Dtype>(param) {}  
  6.   virtual void Fill(Blob<Dtype>* blob) {  
  7.     CHECK(blob->count());  
  8.     int fan_in = blob->count() / blob->num();  
  9.     int fan_out = blob->count() / blob->channels();  
  10.     Dtype n = fan_in;  // default to fan_in  
  11.     if (this->filler_param_.variance_norm() ==  
  12.         FillerParameter_VarianceNorm_AVERAGE) {  
  13.       n = (fan_in + fan_out) / Dtype(2);  
  14.     } else if (this->filler_param_.variance_norm() ==  
  15.         FillerParameter_VarianceNorm_FAN_OUT) {  
  16.       n = fan_out;  
  17.     }  
  18.     // 标准差是\sqrt{\frac{2}{n}}  
  19.     Dtype std = sqrt(Dtype(2) / n);  
  20.     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(0), std,  
  21.         blob->mutable_cpu_data());  
  22.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  23.          << "Sparsity not supported by this Filler.";  
  24.   }  
  25. };  

2-7 BilinearFiller初始化(用户反卷积核)

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. // 反卷积所用的初始化,不支持稀疏特性  
  2.  // 没研究过。。。也不知道  
  3. template <typename Dtype>  
  4. class BilinearFiller : public Filler<Dtype> {  
  5.  public:  
  6.   explicit BilinearFiller(const FillerParameter& param)  
  7.       : Filler<Dtype>(param) {}  
  8.   virtual void Fill(Blob<Dtype>* blob) {  
  9.     CHECK_EQ(blob->num_axes(), 4) << "Blob must be 4 dim.";  
  10.     CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";  
  11.     Dtype* data = blob->mutable_cpu_data();  
  12.     // f是宽度除以2  
  13.     int f = ceil(blob->width() / 2.);  
  14.     // c的含义不明白  
  15.     float c = (2 * f - 1 - f % 2) / (2. * f);  
  16.     for (int i = 0; i < blob->count(); ++i) {  
  17.       float x = i % blob->width();// x表示列的索引  
  18.       float y = (i / blob->width()) % blob->height();// 行的索引%宽度  
  19.       data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));  
  20.     }  
  21.     CHECK_EQ(this->filler_param_.sparse(), -1)  
  22.          << "Sparsity not supported by this Filler.";  
  23.   }  
  24. };  

三、与Filler类相关类的介绍

因为Filler用到了关于随机数生成的一些方法,下面来看下math_function的相关实现:

(1) 高斯分布随机数的生成:

CPU上的实现(直接调用Boost的库了)
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. void caffe_rng_gaussian(const int n, const Dtype a,  
  3.                         const Dtype sigma, Dtype* r) {  
  4.   CHECK_GE(n, 0);  
  5.   CHECK(r);  
  6.   CHECK_GT(sigma, 0);  
  7.   // 直接调用boost中的正太分布了。  
  8.   boost::normal_distribution<Dtype> random_distribution(a, sigma);  
  9.   boost::variate_generator<caffe::rng_t*, boost::normal_distribution<Dtype> >  
  10.       variate_generator(caffe_rng(), random_distribution);  
  11.   for (int i = 0; i < n; ++i) {  
  12.     r[i] = variate_generator();  
  13.   }  
  14. }  
GPU的实现(直接调用CUDA的库了)
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <>  
  2. void caffe_gpu_rng_gaussian(const int n, const float mu, const float sigma,  
  3.                             float* r) {  
  4.   CURAND_CHECK(  
  5.       curandGenerateNormal(Caffe::curand_generator(), r, n, mu, sigma));  
  6. }  
  7.   
  8. template <>  
  9. void caffe_gpu_rng_gaussian(const int n, const double mu, const double sigma,  
  10.                             double* r) {  
  11.   CURAND_CHECK(  
  12.       curandGenerateNormalDouble(Caffe::curand_generator(), r, n, mu, sigma));  
  13. }  

(2)均匀分布随机数的生成:

CPU:
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. void caffe_rng_uniform(const int n, const Dtype a, const Dtype b, Dtype* r) {  
  3.   CHECK_GE(n, 0);  
  4.   CHECK(r);  
  5.   CHECK_LE(a, b);  
  6.   // 调用Boost的库  
  7.   boost::uniform_real<Dtype> random_distribution(a, caffe_nextafter<Dtype>(b));  
  8.   boost::variate_generator<caffe::rng_t*, boost::uniform_real<Dtype> >  
  9.       variate_generator(caffe_rng(), random_distribution);  
  10.   for (int i = 0; i < n; ++i) {  
  11.     r[i] = variate_generator();  
  12.   }  
  13. }  
GPU:
[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. void caffe_gpu_rng_uniform(const int n, unsigned int* r) {  
  2.   CURAND_CHECK(curandGenerate(Caffe::curand_generator(), r, n));  
  3. }  
  4.   
  5. template <>  
  6. void caffe_gpu_rng_uniform<float>(const int n, const float a, const float b,  
  7.                                   float* r) {  
  8.   CURAND_CHECK(curandGenerateUniform(Caffe::curand_generator(), r, n));  
  9.   const float range = b - a;  
  10.   if (range != static_cast<float>(1)) {  
  11.     caffe_gpu_scal(n, range, r);  
  12.   }  
  13.   if (a != static_cast<float>(0)) {  
  14.     caffe_gpu_add_scalar(n, a, r);  
  15.   }  
  16. }  
  17.   
  18. template <>  
  19. void caffe_gpu_rng_uniform<double>(const int n, const double a, const double b,  
  20.                                    double* r) {  
  21.   CURAND_CHECK(curandGenerateUniformDouble(Caffe::curand_generator(), r, n));  
  22.   const double range = b - a;  
  23.   if (range != static_cast<double>(1)) {  
  24.     caffe_gpu_scal(n, range, r);  
  25.   }  
  26.   if (a != static_cast<double>(0)) {  
  27.     caffe_gpu_add_scalar(n, a, r);  
  28.   }  
  29. }  

(3)伯努利分布(二项分布)随机数的生成(竟然没有GPU上的代码。。。)

[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. template <typename Dtype>  
  2. void caffe_rng_bernoulli(const int n, const Dtype p, int* r) {  
  3.   CHECK_GE(n, 0);  
  4.   CHECK(r);  
  5.   CHECK_GE(p, 0);  
  6.   CHECK_LE(p, 1);  
  7.   boost::bernoulli_distribution<Dtype> random_distribution(p);  
  8.   boost::variate_generator<caffe::rng_t*, boost::bernoulli_distribution<Dtype> >  
  9.       variate_generator(caffe_rng(), random_distribution);  
  10.   for (int i = 0; i < n; ++i) {  
  11.     r[i] = variate_generator();  
  12.   }  
  13. }  
  14. void caffe_rng_bernoulli(const int n, const Dtype p, unsigned int* r) {  
  15.   CHECK_GE(n, 0);  
  16.   CHECK(r);  
  17.   CHECK_GE(p, 0);  
  18.   CHECK_LE(p, 1);  
  19.   boost::bernoulli_distribution<Dtype> random_distribution(p);  
  20.   boost::variate_generator<caffe::rng_t*, boost::bernoulli_distribution<Dtype> >  
  21.       variate_generator(caffe_rng(), random_distribution);  
  22.   for (int i = 0; i < n; ++i) {  
  23.     r[i] = static_cast<unsigned int>(variate_generator());  
  24.   }  
  25. }  

四、总结

主要介绍了Filler中初始化权重各个算法的具体的实现,具体原理可以参考相关的论文。关于Filler其实没啥可以深挖的。已经被挖得差不多了。
1 0
原创粉丝点击