caffe loss

来源:互联网 发布:淘宝买枪须输入什么 编辑:程序博客网 时间:2024/05/17 21:57

最近需要修改caffe的loss, 所以重新回顾了以下前向反向的过程.当caffe train的时候,过程:
1.最先进入solver.cpp里面的Solver函数,根据最大迭代次数设置迭代多少次.,调用Step(maxiter_)
2.solver.cpp里面的Step函数,有个average_loss(从solver的param里面读取,我没有设置但是他默认是1,后面的updatesmoothloss有用到.losses_清空,smoothed_loss设置伪0.
接下来进入while大循环,根据iter,

     1. 每次开始循环前都clearparamdiffs()并且        执行下面的循环.iter_size()这里是1                `for (int i = 0; i < param_.iter_size(); ++i) {                      loss += net_->ForwardBackward();//进入net.hpp                }`     2. net.hpp, 分别执行forward(&loss) 和backward()         2.1. net.cpp的forward(&loss),调用ForwardFromTo(0, layers_.size() - 1),所有的层,并返回loss.             2.1.1.net.cpp的ForwardFromTo(start, end),for循环执行            ```             Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);            ```            也就是进入每一层单独的forward                 2.1.1.1. layer.hpp的Forward(bottom, top),真正每一层的forward,并返回每一层的loss, 以CPU为例,会执行                     2.1.1.1.1 Reshape                     2.1.1.1.2 Forward_cpu()                     2.1.1.1.3`    for (int top_id = 0; top_id < top.size(); ++top_id) {                                              if (!this->loss(top_id)) { continue; }                                              const int count = top[top_id]->count();                                              const Dtype* data = top[top_id]->cpu_data();                                              const Dtype* loss_weights = top[top_id]->cpu_diff();                                              loss += caffe_cpu_dot(count, data, loss_weights);                                }`                                只有有loss的层才会调用,将top的data和cpu_diff点乘??为什么?最终返回loss             2.1.2. 将返回的loss累加作为最终loss返回(net.cpp).         2.2.  net.cpp的forwardfromto累加loss,并返回最终俄     3. loss /=param_.iter_size(), 进入UpdateSmoothedLoss(loss, start_iter, average_loss);     4. ApplyUpdate() ,由于是虚函数所以这里会进入sgdsolver的ApplyUpdate     5. ClipGradients     6. 对于所有可学习的参数分别进行:         6.1.  Normalize         6.2.  Regularize         6.3.  可以添加自己的Regularize(例如Group Lasso)         6.4. ComputeUpdateValue     10. 在整个net的级别进行Update    从上面可以看到,可以在6.3部添加自己的loss(或者之前的叠加loss的地方,具体需求具体分析)
1 0