Caffe in Practice, a C++ Source Walkthrough (2): Into the Solver

Source: Internet  Published by: 利驰数据  Editor: 程序博客网  Time: 2024/04/30 15:07

Having seen how to get Caffe running, we now look at how it actually runs, i.e. what the Solve() function of the Solver class does. Readers interested in how the Solver initializes the network and its other parameters can dig into that on their own.

In the source code, Solve() comes in two parameterized forms:

```cpp
  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
```
These should be self-explanatory. Now look at the definition of Solve():

```cpp
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
  // Initialize to false every time we start solving.
  requested_early_exit_ = false;
  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    Restore(resume_file);
  }
```
The resume_file argument is used to resume an interrupted training run. Since the train() function in caffe.cpp has already handled this, there is no need to pass resume_file again here.
The Solver then goes straight into training mode via the Step() function. Its argument is the number of iterations to run: the difference between max_iter, defined in the solver prototxt, and the iter_ value loaded from resume_file.

```cpp
  // For a network that is trained by the solver, no bottom or top vecs
  // should be given, and we will just provide dummy vecs.
  int start_iter = iter_;
  Step(param_.max_iter() - iter_);
```
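The arithmetic is simple: when resuming from a snapshot, Step() only runs the iterations that remain. A minimal sketch of that calculation (remaining_iterations is a hypothetical helper for illustration, not a Caffe function):

```cpp
#include <cassert>

// Hypothetical helper (not in Caffe) showing the iteration count passed
// to Step(): max_iter from the solver prototxt minus the iter_ restored
// from resume_file. With no snapshot, iter_ is 0 and Step runs max_iter.
int remaining_iterations(int max_iter, int resumed_iter) {
  return max_iter - resumed_iter;  // mirrors Step(param_.max_iter() - iter_)
}
```

For example, resuming a 10000-iteration run from a 4000-iteration snapshot leaves 6000 iterations to execute.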

Before diving into Step(), let's continue reading Solve(). After training completes, Caffe saves the current model:

```cpp
  // If we haven't already, save a snapshot after optimization, unless
  // overridden by setting snapshot_after_train := false
  if (param_.snapshot_after_train()
      && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
    Snapshot();
  }
```
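The condition reads more easily as a small predicate. A sketch under illustrative names (this helper is not part of Caffe): take a final snapshot unless it is disabled, or a periodic snapshot was already written at exactly this iteration.

```cpp
#include <cassert>

// Illustrative predicate (not Caffe code) mirroring the condition above:
// `enabled` is snapshot_after_train, `interval` is param_.snapshot()
// (0 means periodic snapshots are off), `iter` is the final iter_.
// In the original, !param_.snapshot() corresponds to interval == 0.
bool needs_final_snapshot(bool enabled, int interval, int iter) {
  return enabled && (interval == 0 || iter % interval != 0);
}
```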
If a test net is provided in the solver prototxt, one more test pass runs after training completes (along with a forward pass on the train net to display the final loss):
```cpp
  // After the optimization is done, run an additional train and test pass to
  // display the train and test loss/outputs if appropriate (based on the
  // display and test_interval settings, respectively).  Unlike in the rest of
  // training, for the train net we only run a forward pass as we've already
  // updated the parameters "max_iter" times -- this final pass is only done to
  // display the loss, which is computed in the forward pass.
  if (param_.display() && iter_ % param_.display() == 0) {
    int average_loss = this->param_.average_loss();
    Dtype loss;
    net_->Forward(&loss);
    UpdateSmoothedLoss(loss, start_iter, average_loss);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }
  if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
    TestAll();
  }
```

Inside Step(), training iterates in a while loop. If a test network is configured, and the trigger conditions are met, each pass through the loop first tests the current network:

```cpp
  while (iter_ < stop_iter) {
    // zero-init the params
    net_->ClearParamDiffs();
    if (param_.test_interval() && iter_ % param_.test_interval() == 0
        && (iter_ > 0 || param_.test_initialization())) {
      if (Caffe::root_solver()) {
        TestAll();
      }
```
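The test trigger can be summarized the same way: a hypothetical predicate (the name is mine, not Caffe's) that fires every test_interval iterations, and at iteration 0 only when test_initialization is set.

```cpp
#include <cassert>

// Illustrative predicate (not Caffe code) for the test condition above:
// test every test_interval iterations (0 disables testing), skipping
// iteration 0 unless test_initialization is enabled.
bool should_test(int test_interval, int iter, bool test_initialization) {
  return test_interval != 0
      && iter % test_interval == 0
      && (iter > 0 || test_initialization);
}
```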
After testing, if training has not been terminated, it continues. The iter_size parameter here defaults to 1; it controls how often SGD updates the parameters: the network is updated once every iter_size forward/backward passes, so the total number of samples per update is batch_size * iter_size, where batch_size is defined in the train net's prototxt.

```cpp
    for (int i = 0; i < param_.iter_size(); ++i) {
      loss += net_->ForwardBackward();
    }
```
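This is gradient accumulation. A toy sketch of the idea (these names and structs are illustrative, not Caffe's): gradients sum across iter_size forward/backward passes and the parameters are updated only once afterwards, so each update effectively sees batch_size * iter_size samples.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in (not Caffe code) for a net whose parameter diffs
// accumulate across forward/backward calls, as Caffe's blobs do.
struct ToyNet {
  double grad = 0.0;               // stands in for a parameter diff blob
  double forward_backward(double g) {
    grad += g;                     // diffs accumulate until ApplyUpdate
    return g;                      // pretend the pass's loss equals g
  }
};

// Averages the loss over the accumulation window, mirroring how Step()
// divides the summed loss by iter_size before smoothing it.
double accumulate(ToyNet& net, const std::vector<double>& per_pass) {
  double loss = 0.0;
  for (double g : per_pass) loss += net.forward_backward(g);
  return loss / per_pass.size();
}

int effective_batch(int batch_size, int iter_size) {
  return batch_size * iter_size;   // samples seen per parameter update
}
```

With batch_size = 32 and iter_size = 4, each parameter update is computed from 128 samples, at the memory cost of only a 32-sample batch.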
Afterwards, ApplyUpdate() is called to update the weights and biases; the update methods will be covered later.
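As a preview, here is a simplified single-parameter sketch of the SGD-with-momentum scheme that SGDSolver's ApplyUpdate() implements blob-wise. This is an illustration only; the real code also normalizes by iter_size, applies weight decay, and scales the rate according to lr_policy.

```cpp
#include <cassert>

// Momentum buffer for one parameter (illustrative, not Caffe's blobs).
struct SgdState {
  double history = 0.0;
};

// One SGD-with-momentum step: blend the previous update into the new
// gradient, then subtract it from the parameter (as net_->Update() does
// with the accumulated diffs).
double sgd_update(double param, double diff, double lr, double momentum,
                  SgdState& state) {
  state.history = momentum * state.history + lr * diff;
  return param - state.history;
}
```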

The testing inside Step() is similar to test in caffe.cpp; its main purpose is to monitor the current training state of the network, which lets you terminate training early based on task status, e.g. once the test loss falls within a target range.
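One way to realize such early termination (this policy is my illustration; Caffe itself only sets requested_early_exit_ in response to external signals or callbacks) is a patience counter over the test loss:

```cpp
#include <cassert>

// Illustrative early-stopping policy, not part of Caffe: stop once the
// test loss has failed to improve for `patience` consecutive test passes.
struct EarlyStopper {
  double best;
  int patience;
  int bad_passes = 0;
  EarlyStopper(double initial_best, int patience)
      : best(initial_best), patience(patience) {}
  // Call after each test pass; returns true when training should stop.
  bool update(double test_loss) {
    if (test_loss < best) {
      best = test_loss;
      bad_passes = 0;
    } else {
      ++bad_passes;
    }
    return bad_passes >= patience;
  }
};
```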








