Caffe Source Code Analysis: solver.cpp

Source: Internet; Editor: 程序博客网; Time: 2024/06/06 01:13

Reposted from: http://blog.csdn.net/qq_16055159/article/details/45068147

Solver<Dtype>::Solver(const SolverParameter& param) 
Purpose: constructor 
Steps: initialize the two Net members, net_ and test_net_, then call Init() 
Input: param, of type SolverParameter 
Output: none

Solver<Dtype>::Solver(const string& param_file) 
Purpose: constructor 
Steps: initialize the two Net members, net_ and test_net_, then call Init() 
Input: param_file, of type string 
Output: none

void Solver<Dtype>::Init(const SolverParameter& param) 
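Purpose: initialize the networks 
Steps: 
1. Set the random seed 
2. Allocate a Net and initialize it with the constructor shown below, 
passing param_file = train_net_; net_ points to this object 
3. If a test_net is specified, allocate a second Net; test_net_ points to it 
Input: param, of type SolverParameter 
Output: none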

```cpp
Net<Dtype>::Net(const string& param_file) {
  NetParameter param;
  // Parse the text-format net definition; abort if it cannot be read.
  ReadNetParamsFromTextFileOrDie(param_file, &param);
  Init(param);
}
```
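The three Init() steps can be sketched with toy stand-ins. This is only an illustration under stated assumptions: ToyNet, ToySolverParam and ToySolver are hypothetical names, not Caffe classes, and the field names train_net and test_net are chosen to echo the text rather than copied from the real SolverParameter.

```cpp
#include <cassert>
#include <cstdlib>
#include <memory>
#include <string>

// Hypothetical stand-in for Caffe's Net: it only remembers its definition file.
struct ToyNet {
  explicit ToyNet(const std::string& param_file) : source(param_file) {}
  std::string source;
};

// Hypothetical stand-in for SolverParameter.
struct ToySolverParam {
  unsigned random_seed = 1;
  std::string train_net;  // path of the training net definition
  std::string test_net;   // empty means "no test net"
};

class ToySolver {
 public:
  // The constructor does nothing but delegate to Init().
  explicit ToySolver(const ToySolverParam& param) { Init(param); }

  const ToyNet* net() const { return net_.get(); }
  const ToyNet* test_net() const { return test_net_.get(); }

 private:
  void Init(const ToySolverParam& param) {
    assert(!param.train_net.empty());
    std::srand(param.random_seed);            // step 1: seed the RNG
    net_.reset(new ToyNet(param.train_net));  // step 2: training net
    if (!param.test_net.empty()) {            // step 3: optional test net
      test_net_.reset(new ToyNet(param.test_net));
    }
  }

  std::unique_ptr<ToyNet> net_;
  std::unique_ptr<ToyNet> test_net_;
};
```

Leaving test_net empty yields a solver whose test_net() is null, which mirrors the "if a test_net is specified" condition in step 3.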

void Solver<Dtype>::Solve(const char* resume_file) 
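Purpose: train the network 
Steps: 
1. Set Caffe's mode (GPU or CPU) 
2. If the mode is GPU and a GPU device ID is specified, select that GPU 
3. Set the current phase to TRAIN (the phase is either TRAIN or TEST) 
4. Call PreSolve() 
5. Call Restore(resume_file) to resume from a saved state (when resume_file is non-null) 
6. Run Test() once up front to check that there is enough memory 
7. For each training iteration (one pass through the whole network): while (iter_++ < param_.max_iter())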

  1. Compute the loss: loss = net_->ForwardBackward(bottom_vec), where:
```cpp
// *************** ForwardBackward() ***************
Dtype ForwardBackward(const vector<Blob<Dtype>*>& bottom) {
  Dtype loss;
  Forward(bottom, &loss);
  Backward();
  return loss;
}

// *************** Forward() ***************
const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
    const vector<Blob<Dtype>*>& bottom, Dtype* loss) {
  // Copy bottom to internal bottom
  for (int i = 0; i < bottom.size(); ++i) {
    net_input_blobs_[i]->CopyFrom(*bottom[i]);
  }
  return ForwardPrefilled(loss);
}

// *************** ForwardPrefilled() ***************
const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
  if (loss != NULL) {
    *loss = Dtype(0.);
  }
  for (int i = 0; i < layers_.size(); ++i) {
    // LOG(ERROR) << "Forwarding " << layer_names_[i];
    Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
    if (loss != NULL) {
      *loss += layer_loss;  // non-loss layers always return 0: return Dtype(0.);
    }
  }
  return net_output_blobs_;
}

// *************** Layer::Forward() ***************
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  switch (Caffe::mode()) {
  case Caffe::CPU:
    return Forward_cpu(bottom, top);  // virtual: each layer type computes it differently
  case Caffe::GPU:
    return Forward_gpu(bottom, top);
  default:
    LOG(FATAL) << "Unknown caffe mode.";
    return Dtype(0);
  }
}

// *************** Backward() ***************
void Net<Dtype>::Backward() {
  // Walk the layers in reverse order, skipping those that need no gradient.
  for (int i = layers_.size() - 1; i >= 0; --i) {
    if (layer_need_backward_[i]) {
      layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
    }
  }
}
```

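  2. Call ComputeUpdateValue() 
  3. Print the loss 
  4. Call Test() whenever the iteration count reaches a multiple of test_interval 
  5. Call Snapshot() whenever the iteration count reaches a multiple of snapshot 
8. After the loop ends, call Snapshot() once more 
Input: resume_file, of type const char* 
Output: none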
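The control flow of Solve() can be sketched as a small, self-contained loop. This is a simplified illustration, not Caffe's code: LoopParam, LoopStats and RunSolveLoop are hypothetical names, the train_step callback stands in for ForwardBackward() plus ComputeUpdateValue(), and the counters replace the real Test() and Snapshot() calls so the schedule can be checked.

```cpp
#include <cassert>
#include <functional>

// Hypothetical solver settings; the field names echo the accessors in the text.
struct LoopParam {
  int max_iter = 0;
  int test_interval = 0;  // 0 disables periodic testing
  int snapshot = 0;       // 0 disables periodic snapshots
};

// Counters recording what the loop did.
struct LoopStats {
  int train_steps = 0;
  int tests = 0;
  int snapshots = 0;
};

// train_step receives the 1-based iteration number and returns the loss.
inline LoopStats RunSolveLoop(const LoopParam& param,
                              const std::function<double(int)>& train_step) {
  assert(param.max_iter >= 0);
  LoopStats stats;
  ++stats.tests;  // the up-front Test() that checks memory before training
  int iter = 0;
  while (iter++ < param.max_iter) {
    train_step(iter);
    ++stats.train_steps;
    if (param.test_interval && iter % param.test_interval == 0) ++stats.tests;
    if (param.snapshot && iter % param.snapshot == 0) ++stats.snapshots;
  }
  ++stats.snapshots;  // the final Snapshot() after the loop
  return stats;
}
```

With max_iter = 10, test_interval = 5 and snapshot = 4, this runs 10 training steps, 3 tests (one up front, then at iterations 5 and 10) and 3 snapshots (iterations 4 and 8, plus the final one).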

void Solver<Dtype>::Test() 
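Purpose: test the network 
Input: none 
Output: none 
Steps: 
1. Set the current phase to TEST 
2. Have test_net_ take its trained parameters from net_, so that testing evaluates the weights being trained 
3. For each test iteration: for (int i = 0; i < param_.test_iter(); ++i)

  1. Use the statement below to assign net_output_blobs_ to result (result holds the blobs of all output layers) 
    and to obtain this iteration's iter_loss: 
    result = test_net_->Forward(bottom_vec, &iter_loss)
  2. On the first iteration (i == 0): 
    1. Read each output blob's data: result_vec = result[j]->cpu_data()
    2. Flatten each blob's data to one dimension and append it to a vector, test_score
  3. On every later iteration: 
    1. Use test_score[idx++] += result_vec[k] 
      rather than test_score.push_back(result_vec[k])
    2. This adds each output blob value to the entry at the same position: 
      test_score[idx++] += result_vec[k]
4. Print the test loss, if requested
5. Print test_score
6. Set the current phase back to TRAIN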
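The push_back-then-accumulate pattern in step 3 can be sketched on plain vectors. This is an illustrative reduction, not Caffe's code: AccumulateTestScores is a hypothetical helper, and each inner vector stands in for the flattened output blobs of one test iteration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Each inner vector is the flattened output of one test iteration
// (all output blobs concatenated), a stand-in for result[j]->cpu_data().
inline std::vector<double> AccumulateTestScores(
    const std::vector<std::vector<double> >& per_iter_outputs) {
  std::vector<double> test_score;
  for (std::size_t i = 0; i < per_iter_outputs.size(); ++i) {
    const std::vector<double>& result_vec = per_iter_outputs[i];
    if (i == 0) {
      // First iteration: create one slot per output value.
      for (std::size_t k = 0; k < result_vec.size(); ++k) {
        test_score.push_back(result_vec[k]);
      }
    } else {
      // Later iterations: add into the slot at the same position.
      assert(result_vec.size() == test_score.size());
      std::size_t idx = 0;
      for (std::size_t k = 0; k < result_vec.size(); ++k) {
        test_score[idx++] += result_vec[k];
      }
    }
  }
  return test_score;
}
```

The returned vector holds sums; dividing each entry by the number of test iterations gives a per-iteration average.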

void Solver<Dtype>::Snapshot() 
Purpose: write the current network state to a file (not important here) 
Input: none 
Output: none

void Solver<Dtype>::Restore(const char* state_file) 
Purpose: read a network state from a file so that training can resume from it (not important here) 
Input: the file name 
Output: none
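As a rough illustration of the save-and-resume round trip, here is a toy version that writes to and reads from a stream. Caffe itself serializes its state with protocol buffers; the plain-text format, the ToySolverState struct and both function names below are assumptions made only for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <ostream>
#include <sstream>
#include <vector>

// Hypothetical, simplified solver state: iteration counter plus one history buffer.
struct ToySolverState {
  int iter;
  std::vector<float> history;
};

// Write "iter count v0 v1 ..." as one text line.
inline void SnapshotSketch(const ToySolverState& state, std::ostream& out) {
  out << state.iter << ' ' << state.history.size();
  for (std::size_t i = 0; i < state.history.size(); ++i) {
    out << ' ' << state.history[i];
  }
  out << '\n';
}

// Read the same layout back; returns false if the stream is malformed.
inline bool RestoreSketch(std::istream& in, ToySolverState* state) {
  assert(state != NULL);
  std::size_t n = 0;
  if (!(in >> state->iter >> n)) return false;
  state->history.assign(n, 0.0f);
  for (std::size_t i = 0; i < n; ++i) {
    if (!(in >> state->history[i])) return false;
  }
  return true;
}
```

A std::stringstream is enough to try the round trip without touching the file system; default text formatting keeps only about six significant digits, so this layout is not suitable for real checkpoints.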

Dtype SGDSolver<Dtype>::GetLearningRate() 
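Purpose: get the learning rate 
Steps: 
1. Get the learning-rate policy: const string& lr_policy = this->param_.lr_policy() 
2. Branch on the policy (each policy is described in the source comments) 
3. Return the learning rate 
Input: none 
Output: rate, of type Dtype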
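The branch in step 2 can be sketched as follows. The policy names and formulas ("fixed", "step", "exp", "inv") are written from memory of Caffe's source comments and should be checked against the version you are reading; LrParam and GetLearningRateSketch are hypothetical names, and other policies are left out.

```cpp
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>

// Hypothetical parameter bundle; in Caffe these values come from SolverParameter.
struct LrParam {
  std::string lr_policy;
  double base_lr;
  double gamma;
  double power;
  int stepsize;
};

inline double GetLearningRateSketch(const LrParam& p, int iter) {
  assert(iter >= 0);
  if (p.lr_policy == "fixed") {
    // Always base_lr.
    return p.base_lr;
  } else if (p.lr_policy == "step") {
    // base_lr * gamma ^ floor(iter / stepsize); integer division gives the floor.
    assert(p.stepsize > 0);
    return p.base_lr * std::pow(p.gamma, iter / p.stepsize);
  } else if (p.lr_policy == "exp") {
    // base_lr * gamma ^ iter
    return p.base_lr * std::pow(p.gamma, iter);
  } else if (p.lr_policy == "inv") {
    // base_lr * (1 + gamma * iter) ^ (-power)
    return p.base_lr * std::pow(1.0 + p.gamma * iter, -p.power);
  }
  throw std::invalid_argument("unknown lr_policy: " + p.lr_policy);
}
```

Throwing on an unknown policy is a choice made for this sketch so that a misspelled policy name fails loudly.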

void SGDSolver<Dtype>::PreSolve() 
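Purpose: preparation before training 
Steps: 
1. Read the parameters of the training net net_ into net_params: net_params = this->net_->params(), 
where params_ is a vector of blob pointers 
2. Clear any leftover history values 
3. For each parameter blob of the network, push a blob of the same size onto history_ 
Input: none 
Output: none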
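A minimal sketch of the history allocation, assuming each parameter blob can be reduced to a flat array of floats. ToyBlob and PreSolveSketch are hypothetical names; the real code works on Blob objects with full shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// A parameter blob reduced to a flat array of values (hypothetical stand-in for Blob).
typedef std::vector<float> ToyBlob;

// Rebuild the history so that it holds one zero-filled buffer per parameter blob,
// each with the same element count as the blob it shadows.
inline void PreSolveSketch(const std::vector<ToyBlob>& net_params,
                           std::vector<ToyBlob>* history) {
  assert(history != NULL);
  history->clear();  // drop anything left over from an earlier run
  for (std::size_t i = 0; i < net_params.size(); ++i) {
    history->push_back(ToyBlob(net_params[i].size(), 0.0f));
  }
}
```

Keeping history index-aligned with net_params is what later lets the update step address both with the same param_id.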

void SGDSolver<Dtype>::ComputeUpdateValue() 
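Purpose: compute the update value using stochastic gradient descent 
Input: none 
Output: none 
Steps: 
1. (For all parameters) read the network parameters net_params, the per-parameter learning rates net_params_lr, 
and the weight decays net_params_weight_decay; read the learning rate rate 
2. Read the momentum and the weight decay from the solver parameters 
3. If running on the CPU: 
for each layer's parameter blob (indexed by param_id):

  1. Compute local_rate and local_decay
  2. Call caffe_cpu_axpby, caffe_axpy and caffe_copy: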
```cpp
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
                net_params[param_id]->cpu_diff(), momentum,
                history_[param_id]->mutable_cpu_data());
```
```cpp
caffe_axpy(net_params[param_id]->count(), local_decay * local_rate,
           net_params[param_id]->cpu_data(),
           history_[param_id]->mutable_cpu_data());
```
```cpp
void caffe_cpu_axpby<float>(const int N, const float alpha, const float* X,
                            const float beta, float* Y) {
  cblas_saxpby(N, alpha, X, 1, beta, Y, 1);
}

// where:
inline void cblas_saxpby(const int N, const float alpha, const float* X,
                         const int incX, const float beta, float* Y,
                         const int incY) {
  cblas_sscal(N, beta, Y, incY);            // Y = beta * Y
  cblas_saxpy(N, alpha, X, incX, Y, incY);  // Y = alpha * X + Y
}
```

caffe_cpu_axpby calls cblas_saxpby, which in turn calls cblas_sscal and cblas_saxpy, so it computes Y = alpha * X + beta * Y

```cpp
void caffe_axpy<float>(const int N, const float alpha, const float* X,
                       float* Y) {
  cblas_saxpy(N, alpha, X, 1, Y, 1);  // Y = alpha * X + Y
}
```

caffe_axpy calls cblas_saxpy directly, computing Y = alpha * X + Y 
So caffe_cpu_axpby takes one extra argument, beta, and makes one extra call, cblas_sscal(N, beta, Y, incY) 
4. The GPU branch works the same way
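Putting the three calls together, the CPU update for one parameter blob can be sketched without BLAS. This is a simplified reading of the steps above, not Caffe's code: the helper names and ToyParam are hypothetical, plain loops replace the cblas routines, and the final assignment stands in for caffe_copy. The assumption that the copied value is later subtracted from the weights is based on how the update is applied after ComputeUpdateValue() and is not shown in the text.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Y = alpha * X + beta * Y  (the role played by caffe_cpu_axpby).
inline void AxpbySketch(float alpha, const std::vector<float>& x, float beta,
                        std::vector<float>* y) {
  assert(y != NULL && x.size() == y->size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*y)[i] = alpha * x[i] + beta * (*y)[i];
  }
}

// Y = alpha * X + Y  (the role played by caffe_axpy).
inline void AxpySketch(float alpha, const std::vector<float>& x,
                       std::vector<float>* y) {
  AxpbySketch(alpha, x, 1.0f, y);
}

// One parameter blob: data holds the weights, diff the gradient.
struct ToyParam {
  std::vector<float> data;
  std::vector<float> diff;
};

// history = local_rate * diff + momentum * history
// history += local_decay * local_rate * data
// diff = history
inline void ComputeUpdateValueSketch(float local_rate, float local_decay,
                                     float momentum, ToyParam* param,
                                     std::vector<float>* history) {
  assert(param != NULL && history != NULL);
  AxpbySketch(local_rate, param->diff, momentum, history);
  if (local_decay != 0.0f) {
    AxpySketch(local_decay * local_rate, param->data, history);
  }
  param->diff = *history;  // stands in for caffe_copy
}
```

The momentum term is exactly the beta argument: scaling the old history by momentum before adding the new gradient is what distinguishes the axpby call from the plain axpy used for weight decay.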

void SGDSolver<Dtype>::SnapshotSolverState(SolverState* state) 

void SGDSolver<Dtype>::RestoreSolverState(const SolverState& state) 
