How Weights Are Updated in Caffe

Source: Internet  Editor: 程序博客网  Date: 2024/05/19 05:29

URL: http://blog.csdn.net/mounty_fsc/article/details/51588773


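(Caffe, LeNet) Weight Update (Part 7)

In Solver::ApplyUpdate() (implemented by SGDSolver for LeNet), the network weights are updated from the partial derivatives of the loss with respect to the weights, which were computed during backpropagation, according to the configured learning policy. This completes the current training iteration.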

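1 Model Optimization

1.1 Loss Function

The loss function $L(W)$ is the empirical loss plus a regularization term, as shown below, where $X^{(i)}$ is an input sample; $f_W$ is the loss on a single sample; $N$ is the number of samples in the mini-batch; and $r(W)$ is the regularization term, weighted by $\lambda$.

$$L(W) \approx \frac{1}{N}\sum_{i=1}^{N} f_W\left(X^{(i)}\right) + \lambda\, r(W)$$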
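
As a rough illustration of the formula above, the following sketch computes $L(W)$ from a list of per-sample losses and a weight vector. The function name `regularized_loss` and the choice $r(W) = \frac{1}{2}\sum w^2$ are assumptions made for this example; neither comes from the original post or from Caffe's API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of Caffe): mean per-sample loss plus
// lambda * r(W), with r(W) = 0.5 * sum(w^2) assumed for illustration.
double regularized_loss(const std::vector<double>& sample_losses,
                        const std::vector<double>& weights,
                        double lambda) {
  assert(!sample_losses.empty());  // N must be at least 1

  // Empirical term: (1/N) * sum_i f_W(X^(i)).
  double data_term = 0.0;
  for (double f : sample_losses) data_term += f;
  data_term /= static_cast<double>(sample_losses.size());

  // Regularization term r(W).
  double reg_term = 0.0;
  for (double w : weights) reg_term += 0.5 * w * w;

  return data_term + lambda * reg_term;
}
```

For instance, per-sample losses {1, 3}, a single weight 2, and lambda 0.5 give a data term of 2 and a regularization term of 2, so the total is 2 + 0.5 * 2 = 3.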

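In Caffe, training can be divided into three stages:

  1. Forward pass: computes $f_W$.
  2. Backward pass: computes $\nabla f_W$.
  3. Weight update: computes $\Delta W$ from $\nabla f_W$, $\nabla r(W)$, etc., and uses it to update $W$.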

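1.2 Stochastic Gradient Descent

In LeNet, the solver type is SGD (stochastic gradient descent).

SGD updates the weights with the following formulas:

$$W_{t+1} = W_t + V_{t+1}$$
$$V_{t+1} = \mu V_t - \alpha \nabla L(W_t)$$

where $W_{t+1}$ is the weights at iteration $t+1$; $V_{t+1}$ is the update at iteration $t+1$ (also written $\Delta W_{t+1}$); $\mu$ is the weight given to the previous update (the momentum); $\alpha$ is the learning rate; and $\nabla L(W_t)$ is the gradient of the loss with respect to the weights.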
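
The two formulas above can be sketched as a single step function. This is a minimal illustration on plain `std::vector`s; the name `sgd_momentum_step` is hypothetical and this is not how Caffe organizes the computation (Section 2 walks through the actual code).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One momentum-SGD step following the formulas above:
//   V_{t+1} = mu * V_t - alpha * grad L(W_t)
//   W_{t+1} = W_t + V_{t+1}
// Hypothetical helper for illustration; not Caffe's implementation.
void sgd_momentum_step(std::vector<double>& w, std::vector<double>& v,
                       const std::vector<double>& grad,
                       double alpha, double mu) {
  assert(w.size() == v.size() && w.size() == grad.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    v[i] = mu * v[i] - alpha * grad[i];  // new update V_{t+1}
    w[i] += v[i];                        // new weight W_{t+1}
  }
}
```

Starting from w = 1, v = 0 with gradient 2, alpha = 0.5 and mu = 0.5, one step gives v = -1 and w = 0. A second step with zero gradient still moves the weight, to w = -0.5, because of the momentum term.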

2 Code Analysis

2.1 ApplyUpdate

```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  // Get the learning rate for this iteration.
  Dtype rate = GetLearningRate();
  // Update the weights of every layer.
  // In LeNet only four layers have parameters: `conv1`, `conv2`, `ip1`, `ip2`.
  // Each of them has a weight parameter and a bias parameter,
  // so the size of `learnable_params_` is 8.
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    // Normalize; not needed when iter_size is 1, so LeNet does not need it.
    Normalize(param_id);
    // Regularize.
    Regularize(param_id);
    // Compute the update value \delta w.
    ComputeUpdateValue(param_id, rate);
  }
  // Update the weights.
  this->net_->Update();
}
```

Notes:

  1. The learning parameters for LeNet can be found in lenet_solver.prototxt:

    ```
    # The base learning rate, momentum and the weight decay of the network.
    base_lr: 0.01
    momentum: 0.9
    weight_decay: 0.0005
    # The learning rate policy
    lr_policy: "inv"
    gamma: 0.0001
    power: 0.75
    ```
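  2. The code of the learning-rate function GetLearningRate is not shown here. Its comments (and caffe.proto) list the learning rate policies below. LeNet uses the inv policy, under which the learning rate changes at every iteration.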

    ```cpp
    // The learning rate decay policy. The currently implemented learning rate
    // policies are as follows:
    //    - fixed: always return base_lr.
    //    - step: return base_lr * gamma ^ (floor(iter / step))
    //    - exp: return base_lr * gamma ^ iter
    //    - inv: return base_lr * (1 + gamma * iter) ^ (- power)
    //    - multistep: similar to step but it allows non uniform steps defined by
    //      stepvalue
    //    - poly: the effective learning rate follows a polynomial decay, to be
    //      zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
    //    - sigmoid: the effective learning rate follows a sigmod decay
    //      return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
    //
    // where base_lr, max_iter, gamma, step, stepvalue and power are defined
    // in the solver parameter protocol buffer, and iter is the current iteration.
    ```
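
To make the inv policy concrete, here is a small sketch that evaluates the formula quoted in the comment above. The free function `inv_learning_rate` is a hypothetical name used only for this example; in Caffe the computation happens inside GetLearningRate.

```cpp
#include <cassert>
#include <cmath>

// Learning rate under the "inv" policy quoted above:
//   base_lr * (1 + gamma * iter) ^ (-power)
// Hypothetical free function for illustration only.
double inv_learning_rate(double base_lr, double gamma, double power, int iter) {
  assert(iter >= 0);  // iteration counter cannot be negative
  return base_lr * std::pow(1.0 + gamma * iter, -power);
}
```

With the LeNet settings listed in note 1 (base_lr 0.01, gamma 0.0001, power 0.75), the rate is 0.01 at iteration 0 and has dropped to roughly 0.0059 by iteration 10000, since (1 + 1)^(-0.75) is about 0.59.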

2.2 Regularize

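This function effectively performs the following update, overwriting the gradient in place:

$$\frac{\partial\, \mathit{loss}}{\partial w_{ij}} \leftarrow \mathit{decay} \cdot w_{ij} + \frac{\partial\, \mathit{loss}}{\partial w_{ij}}$$

The code is as follows: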

```cpp
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();
  Dtype weight_decay = this->param_.weight_decay();
  string regularization_type = this->param_.regularization_type();
  // local_decay = 0.0005 in lenet
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
  ...
      if (regularization_type == "L2") {
        // axpy means ax_plus_y. i.e., y = a*x + y
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
      }
  ...
}
```
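
The L2 branch above reduces to one axpy call. The sketch below reproduces that arithmetic on `std::vector`s so it can be run without Caffe. Both `axpy` and `regularize_l2` are stand-ins written for this example; they are not Caffe's caffe_axpy, which operates on raw pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ stand-in for the axpy operation: y = a * x + y.
void axpy(double a, const std::vector<double>& x, std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] += a * x[i];
}

// L2 regularization step: diff = local_decay * data + diff,
// where data holds the weights and diff holds the gradient.
void regularize_l2(double local_decay, const std::vector<double>& data,
                   std::vector<double>& diff) {
  axpy(local_decay, data, diff);
}
```

For weights {2, -4}, gradient {1, 1} and a decay of 0.5, the gradient becomes {2, -1}: each entry is pushed in the direction of its own weight, so the later subtraction shrinks the weights toward zero.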

2.3 ComputeUpdateValue

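This function effectively performs the following updates:

$$v_{ij} \leftarrow \mathit{lr\_rate} \cdot \frac{\partial\, \mathit{loss}}{\partial w_{ij}} + \mathit{momentum} \cdot v_{ij}$$
$$\frac{\partial\, \mathit{loss}}{\partial w_{ij}} \leftarrow v_{ij}$$

Note that $v_{ij}$ here has the opposite sign to $V$ in Section 1.2; the minus sign is applied in net_->Update (Section 2.4).

The code is as follows: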

```cpp
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  // momentum = 0.9 in lenet
  Dtype momentum = this->param_.momentum();
  // local_rate = lr_mult * global_rate
  // lr_mult is the layer's learning rate multiplier, set in lenet_train_test.prototxt
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  ...
    // axpby means ax_plus_by. i.e., y = ax + by
    // Compute the new weight update \delta w and store it in history_
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    // Copy the update \delta w from history_ into the parameter's diff
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
  ...
}
```
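
The axpby-then-copy sequence above can be tried out on plain vectors. In this sketch `axpby` and `compute_update_value` are hypothetical stand-ins for caffe_cpu_axpby and the SGDSolver method, written only to show the arithmetic.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ stand-in for the axpby operation: y = a * x + b * y.
void axpby(double a, const std::vector<double>& x, double b,
           std::vector<double>& y) {
  assert(x.size() == y.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = a * x[i] + b * y[i];
}

// history = local_rate * diff + momentum * history, then diff = history.
void compute_update_value(double local_rate, double momentum,
                          std::vector<double>& diff,
                          std::vector<double>& history) {
  axpby(local_rate, diff, momentum, history);
  diff = history;  // copy the update into the parameter's diff
}
```

With a gradient of 2, a previous update of 1, local_rate 0.5 and momentum 0.5, the new update is 0.5 * 2 + 0.5 * 1 = 1.5, and that value ends up in both the history and the diff.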

2.4 net_->Update

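It effectively performs the following update:

$$w_{ij} \leftarrow w_{ij} + (-1) \cdot \frac{\partial\, \mathit{loss}}{\partial w_{ij}}$$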

```cpp
caffe_axpy<Dtype>(count_, Dtype(-1),
        static_cast<const Dtype*>(diff_->cpu_data()),
        static_cast<Dtype*>(data_->mutable_cpu_data()));
```
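
Putting Sections 2.2 to 2.4 together, one update of a single parameter can be sketched as below. `ToyParam` and `apply_update` are hypothetical names for this example, and the three Caffe calls are fused into one element-wise loop, which gives the same result because each operation acts on elements independently.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for one learnable parameter blob and its momentum history.
struct ToyParam {
  std::vector<double> data;     // weights w
  std::vector<double> diff;     // gradient d(loss)/dw, later the update v
  std::vector<double> history;  // previous update v
};

// One ApplyUpdate-style step on a single parameter (hypothetical sketch).
void apply_update(ToyParam& p, double local_rate, double local_decay,
                  double momentum) {
  assert(p.data.size() == p.diff.size() &&
         p.data.size() == p.history.size());
  for (std::size_t i = 0; i < p.data.size(); ++i) {
    // Regularize: diff = local_decay * data + diff
    p.diff[i] += local_decay * p.data[i];
    // ComputeUpdateValue: history = local_rate * diff + momentum * history
    p.history[i] = local_rate * p.diff[i] + momentum * p.history[i];
    p.diff[i] = p.history[i];
    // Update: data = data + (-1) * diff
    p.data[i] -= p.diff[i];
  }
}
```

For example, with weight 2, gradient 1, no previous update, and local_rate, local_decay and momentum all set to 0.5, the regularized gradient is 2, the update is 1, and the weight moves from 2 to 1.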

References:

[1]. http://caffe.berkeleyvision.org/tutorial/solver.html

