(Caffe,LeNet)权值更新(七)

来源:互联网 发布:社交网络 肖恩帕克 编辑:程序博客网 时间:2024/05/19 07:08

本文地址:http://blog.csdn.net/mounty_fsc/article/details/51588773

在Solver::ApplyUpdate()函数中,根据反向传播阶段计算的loss关于网络权值的偏导,使用配置的学习策略,更新网络权值从而完成本轮学习。

1 模型优化

1.1 损失函数

损失函数L(W)可由经验损失加正则化项得到,如下,其中X(i)为输入样本;fW为某样本的损失函数;N为mini-batch的样本数量;r(W)为以权值为λ的正则项。

L(W)1NNifW(X(i))+λr(W)

在caffe中,可以分为三个阶段:

  1. 前向计算阶段,这个阶段计算fW
  2. 反向传播阶段,这个阶段计算fW
  3. 权值更新阶段,这个阶段通过fW,r(W)等计算ΔW从而更新W

1.2 随机梯度下降

在lenet中,solver的类型为SGD(Stochastic gradient descent)

SGD通过以下公式对权值进行更新:

Wt+1=Wt+Vt+1
Vt+1=μVtαL(Wt)

其中,Wt+1为第t+1轮的权值;Vt+1为第t+1轮的更新(也可以写作ΔWt+1);μ为上一轮更新的权重;α为学习率;L(Wt)为loss对权值的求导

2 代码分析

2.1 ApplyUpdate

void SGDSolver<Dtype>::ApplyUpdate() {  // 获取该轮迭代的学习率(learning rate)  Dtype rate = GetLearningRate();  // 对每一层网络的权值进行更新  // 在lenet中,只有`conv1`,`conv2`,`ip1`,`ip2`四层有参数  // 每层分别有参数与偏置参数两项参数  // 因而`learnable_params_`的size为8.  for (int param_id = 0; param_id < this->net_->learnable_params().size();       ++param_id) {    // 归一化,iter_size为1不需要,因而lenet不需要。    // 此处的归一化内容很简单,仅仅是iter_size大于1时值再除以iter_size    Normalize(param_id);    // 正则化    Regularize(param_id);    // 计算更新值\delta w    ComputeUpdateValue(param_id, rate);  }  // 更新权值  this->net_->Update();}

说明:

  1. lenet中学习参数设置可从lenet_solver.prototxt中查到

    # The base learning rate, momentum and the weight decay of the network.base_lr: 0.01momentum: 0.9weight_decay: 0.0005# The learning rate policylr_policy: "inv"gamma: 0.0001power: 0.75
  2. 获取学习率函数ApplyUpdate代码此处不给出,查看注释(以及caffe.proto)可知有如下学习率获取策略。在Lenet中采用的是inv的策略,是一种没一轮迭代学习率都改变的策略。

      // The learning rate decay policy. The currently implemented learning rate  // policies are as follows:  //    - fixed: always return base_lr.  //    - step: return base_lr * gamma ^ (floor(iter / step))  //    - exp: return base_lr * gamma ^ iter  //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)  //    - multistep: similar to step but it allows non uniform steps defined by  //      stepvalue  //    - poly: the effective learning rate follows a polynomial decay, to be  //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)  //    - sigmoid: the effective learning rate follows a sigmod decay  //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))  //  // where base_lr, max_iter, gamma, step, stepvalue and power are defined  // in the solver parameter protocol buffer, and iter is the current iteration.

2.2 Regularize

该函数实际执行以下公式

losswij=decaywij+losswij

代码如下:

void SGDSolver<Dtype>::Regularize(int param_id) {  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  const vector<float>& net_params_weight_decay =      this->net_->params_weight_decay();  Dtype weight_decay = this->param_.weight_decay();  string regularization_type = this->param_.regularization_type();  // local_decay = 0.0005 in lenet  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];  ...      if (regularization_type == "L2") {        // axpy means ax_plus_y. i.e., y = a*x + y        caffe_axpy(net_params[param_id]->count(),            local_decay,            net_params[param_id]->cpu_data(),            net_params[param_id]->mutable_cpu_diff());      }   ...}

2.3 ComputeUpdateValue

该函数实际执行以下公式
vij=lr_ratelosswij+momentumvij
losswij=vij

代码如下:

void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  const vector<float>& net_params_lr = this->net_->params_lr();  // momentum = 0.9 in lenet  Dtype momentum = this->param_.momentum();  // local_rate = lr_mult * global_rate  // lr_mult为该层学习率乘子,在lenet_train_test.prototxt中设置  Dtype local_rate = rate * net_params_lr[param_id];  // Compute the update to history, then copy it to the parameter diff.  ...    // axpby means ax_plus_by. i.e., y = ax + by    // 计算新的权值更新变化值 \delta w,结果保存在历史权值变化中    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,              net_params[param_id]->cpu_diff(), momentum,              history_[param_id]->mutable_cpu_data());    // 从历史权值变化中把变化值 \delta w 保存到历史权值中diff中    caffe_copy(net_params[param_id]->count(),        history_[param_id]->cpu_data(),        net_params[param_id]->mutable_cpu_diff());   ... }

2.4 net_->Update

实际执行以下公式:
wij=wij+(1)losswij

caffe_axpy<Dtype>(count_, Dtype(-1),        static_cast<const Dtype*>(diff_->cpu_data()),        static_cast<Dtype*>(data_->mutable_cpu_data()));

参考文献:

[1]. http://caffe.berkeleyvision.org/tutorial/solver.html

0 0
原创粉丝点击