(Caffe,LeNet)网络训练流程(二)

来源:互联网 发布:linux 调整根目录大小 编辑:程序博客网 时间:2024/06/01 09:44

目录(?)[+]

  1. 程序入口
  2. Solver的创建
  3. SolverSolve函数
  4. SolverStep函数
    1. 1 SolverTestAll函数
    2. 2 NetForwardBackward函数
    3. 3 SolverApplyUpdate函数
  5. 训练完毕

本文地址:http://blog.csdn.net/mounty_fsc/article/details/51090114

在训练lenet的train_lenet.sh中内容为:

./build/tools/caffe train –solver=examples/mnist/lenet_solver.prototxt

由此可知,训练网咯模型是由tools/caffe.cpp生成的工具caffe在模式train下完成的。
初始化过程总的来说,从main()train()中创建Solver,在Solver中创建Net,在Net中创建Layer.

1 程序入口

  • 找到caffe.cppmain函数中,通过GetBrewFunction(caffe::string(argv[1]))()调用执行train()函数。
  • train中,通过参数-examples/mnist/lenet_solver.prototxtsolver参数读入solver_param中。
  • 随后注册并定义solver的指针(见第2节)

      shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param))
    • 1
    • 2
    • 1
    • 2
  • 调用solverSolver()方法。多个GPU涉及到GPU间带异步处理问题(见第3节)

    if (gpus.size() > 1) {    caffe::P2PSync<float> sync(solver, NULL, solver->param());    sync.run(gpus);} else {    LOG(INFO) << "Starting Optimization";    solver->Solve();}
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

2 Solver的创建

在1中,Solver的指针solver是通过SolverRegistry::CreateSolver创建的,CreateSolver函数中值得注意带是return registry[type](param)

  // Get a solver using a SolverParameter.  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {    const string& type = param.type();    CreatorRegistry& registry = Registry();    CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type        << " (known types: " << SolverTypeListString() << ")";    return registry[type](param);  }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

其中:

registry是一个map<string,Creator>: typedef std::map<string, Creator> CreatorRegistry
其中Creator是一个函数指针类型: typedef Solver<Dtype>* (*Creator)(const SolverParameter&)
registry[type]为一个函数指针变量,在Lenet5中,此处具体的值为 caffe::Creator_SGDSolver<float>(caffe::SolverParameter const&)
其中Creator_SGDSolver在以下宏中定义,
REGISTER_SOLVER_CLASS(SGD)
该宏完全展开得到的内容为:

template <typename Dtype>                                                    \  Solver<Dtype>* Creator_SGDSolver(                                       \      const SolverParameter& param)                                            \  {                                                                            \    return new SGDSolver<Dtype>(param);                                     \  }                                                                            \  static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);    \  static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

从上可以看出,registry[type](param)中实际上调用了SGDSolver带构造方法,事实上,网络是在SGDSolver的构造方法中初始化的。
SGDSolver的定义如下:

template <typename Dtype>class SGDSolver : public Solver<Dtype> { public:  explicit SGDSolver(const SolverParameter& param)      : Solver<Dtype>(param) { PreSolve(); }  explicit SGDSolver(const string& param_file)      : Solver<Dtype>(param_file) { PreSolve(); }......
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

SGDSolver继承与Solver<Dtype>,因而new SGDSolver<Dtype>(param)将执行Solver<Dtype>的构造函数,然后调用自身构造函数。整个网络带初始化即在这里面完成(详见本系列博文(三))。

3 Solver::Solve()函数

在这个函数里面,程序执行完网络的完整训练过程。
核心代码如下:

template <typename Dtype>void Solver<Dtype>::Solve(const char* resume_file) {  Step(param_.max_iter() - iter_);  //..    Snapshot();  //..  // some additional display   // ...}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

说明:

  1. 值得关注的代码是Step(),在该函数中,值得了param_.max_iter()轮迭代(10000)
  2. 在Snapshot()中序列化model到文件

4 Solver::Step()函数

template <typename Dtype>void Solver<Dtype>::Step(int iters) {  //10000轮迭代  while (iter_ < stop_iter) {    // 每隔500轮进行一次测试    if (param_.test_interval() && iter_ % param_.test_interval() == 0        && (iter_ > 0 || param_.test_initialization())        && Caffe::root_solver()) {      // 测试网络,实际是执行前向传播计算loss      TestAll();    }    // accumulate the loss and gradient    Dtype loss = 0;    for (int i = 0; i < param_.iter_size(); ++i) {      // 执行反向传播,前向计算损失loss,并计算loss关于权值的偏导      loss += net_->ForwardBackward(bottom_vec);    }    // 平滑loss,计算结果用于输出调试等    loss /= param_.iter_size();    // average the loss across iterations for smoothed reporting    UpdateSmoothedLoss(loss, start_iter, average_loss);    // 通过反向传播计算的偏导更新权值    ApplyUpdate();  }}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

4.1 Solver::TestAll()函数

TestAll()中,调用Test(test_net_id)对每个测试网络test_net(不是训练网络train_net)进行测试。在Lenet中,只有一个测试网络,所以只调用一次Test(0)进行测试。
Test()函数里面做了两件事:

  • 前向计算网络,得到网络损失,见 (Caffe,LeNet)前向计算(五)
  • 通过测试网络的第11层accuracy层,与第12层loss层结果统计accuracy与loss信息。

4.2 Net::ForwardBackward()函数

Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {    Dtype loss;    Forward(bottom, &loss);    Backward();    return loss;  }
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

说明:

  • 前向计算。计算网络损失loss,参考 (Caffe,LeNet)前向计算(五)
  • 反向传播。计算loss关于网络权值的偏导,参考 (Caffe,LeNet)反向传播(六)

4.3 Solver::ApplyUpdate()函数

根据反向传播阶段计算的loss关于网络权值的偏导,使用配置的学习策略,更新网络权值从而完成本轮学习。详见 (Caffe,LeNet)权值更新(七)

5 训练完毕

至此,网络训练优化完成。在第3部分solve()函数中,最后对训练网络与测试网络再执行一轮额外的前行计算求得loss,以进行测试。

阅读全文
0 0