caffe源码学习中-tools/caffe.cpp

来源:互联网 发布:淘宝良心的文具店 编辑:程序博客网 时间:2024/06/05 20:10

声明:参考文章

  • Caffe代码解析(3)
  • caffe运行的大概流程

参考文章已经写的很好,自己做了一下整理和某些地方添加了自己的理解。

感谢博主的无私分享。


DEFINE_string(gpu, "", "Optional; run in GPU mode on given device IDs separated by ','." "Use '-gpu all' to run on all available GPUs. The effective training " "batch size is multiplied by the number of devices.");//各种宏定义:使用方式为DEFINE_xxx(name, default_value, instruction);这样就定义了一个xxx类型名为FLAGS_name的标志,如果用户没有在Command Line中提供其值,那么会默认为default_value,instruction是这个标志含义的说明。因此,上面的代码定义了一个string类型的名为FLAGS_gpu的标志,如果在Command Line中用户没有提供值,那么会默认为空字符串,根据说明可以得知这个标志是提供给用户来指定caffe将使用的GPU的。其余的定义也是类似的。typedef int (*BrewFunction)();//声明指向函数的指针类型BrewFunctiontypedef std::map BrewMap;//声明键值对分别为字符串和函数指针的BrewMap类型BrewMap g_brew_map;//声明BrewMap对象g_brew_map//RegisterBrewFunction这个宏在每一个实现主要功能的函数之后将这个函数的名字和其对应的函数指针添加到了g_brew_map中#define RegisterBrewFunction(func) \namespace { \class __Registerer_##func { \ public: /* NOLINT */ \ __Registerer_##func() { \ g_brew_map[#func] = &func; \ } \}; \__Registerer_##func g_registerer_##func; \}//通过GetBrewFunction得到了我们需要调用的那个函数的函数指针,并完成了调用。static BrewFunction GetBrewFunction(const caffe::string& name) { if (g_brew_map.count(name)) { return g_brew_map[name]; } else { LOG(ERROR) << "Available caffe actions:"; for (BrewMap::iterator it = g_brew_map.begin(); it != g_brew_map.end(); ++it) { LOG(ERROR) << "\t" << it->first; } LOG(FATAL) << "Unknown action: " << name; return NULL; // not reachable, just to suppress old compiler warnings. }}

caffe::SolverAction::Enum GetRequestedAction(
    const std::string& flag_value) {
  if (flag_value == "stop") {
    return caffe::SolverAction::STOP;
  }
  if (flag_value == "snapshot") {
    return caffe::SolverAction::SNAPSHOT;
  }
  if (flag_value == "none") {
    return caffe::SolverAction::NONE;
  }
  LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified";
}

//其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。

//train函数int train() { CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";//使用了glog的CHECK_GT宏(含义为check greater than),检查FLAGS_solver的size是否大于0,应该提供对应的solver定义文件的路径CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) << "Give a snapshot to resume training or weights to finetune " "but not both.";//确保用户没有同时提供snapshot和weights参数,这两个参数都是继续之前的训练或者进行fine-tuning的 vector stages = get_stages_from_flags(); caffe::SolverParameter solver_param; caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//SolverParameter的声明和解析代码,SolverParameter是通过Google Protocol Buffer自动生成的一个类//SolverParameter是通过ReadSolverParamsFromTextFileOrDie解析的//函数的实现在/src/caffe/util/upgrade_proto.cpp里:先后调用了两个函数,首先是ReadProtoFromTextFile,这个函数的作用是从param_file这个路径去读取solver的定义,并将文件中的内容解析存到param这个指针指向的对象,具体的实现在/src/caffe/util/io.cpp的开始(打开了文件,并且读取到了一个FileInputStream的指针中,然后通过protobuf的TextFormat::Parse函数完成了解析)。然后UpgradeSolverAsNeeded完成了新老版本caffe.proto的兼容处理,主要的问题就是在旧版本中Solver的type是enum类型,而新版本的变为了string solver_param.mutable_train_state()->set_level(FLAGS_level); for (int i = 0; i < stages.size(); i++) { solver_param.mutable_train_state()->add_stage(stages[i]); } // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. if (FLAGS_gpu.size() == 0 && solver_param.has_solver_mode() && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { if (solver_param.has_device_id()) { FLAGS_gpu = "" + boost::lexical_cast(solver_param.device_id()); } else { // Set default GPU if unspecified FLAGS_gpu = "" + boost::lexical_cast(0); } }//首先是判断用户在Command Line中是否输入了gpu相关的参数//如果没有但是用户在solver的prototxt定义中提供了相关的参数,那就把相关的参数放到FLAGS_gpu中//如果用户仅仅是选择了在solver的prototxt定义中选择了GPU模式,但是没有指明具体的gpu_id,那么就默认设置为0。 vector gpus; get_gpus(&gpus); if (gpus.size() == 0) { LOG(INFO) << "Use CPU."; Caffe::set_mode(Caffe::CPU); } else { ostringstream s; for (int i = 0; i < gpus.size(); ++i) { s << (i ? ", " : "") << gpus[i]; } LOG(INFO) << "Using GPUs " << s.str();#ifndef CPU_ONLY cudaDeviceProp device_prop; for (int i = 0; i < gpus.size(); ++i) { cudaGetDeviceProperties(&device_prop, gpus[i]); LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name; }#endif solver_param.set_device_id(gpus[0]); Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); Caffe::set_solver_count(gpus.size()); }//通过一个get_gpus的函数,将存放在FLAGS_gpu中的string转成了一个vector,并完成了具体的设置。 caffe::SignalHandler signal_handler( GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect)); //FLAGS_sigint_effect和FLAGS_sighup_effect是通过gflags定义和解析的两个Command Line Interface的输入参数,分别对应遇到sigint和sighup信号的处理方式,如果用户不设定,sigint的默认值为”stop”,sighup的默认值为”snapshot”。GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler类型的对象signal_handler。

//总结起来,我们通过定义一个SignalHandler类型的对象,告知系统在遇到系统信号的时候回调handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,这个函数实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回我们之前设置的处理方式(SolverAction::Enum类型)。剩余的具体处理再交给Solver的其它函数,

shared_ptr<caffe::Solver > solver(caffe::SolverRegistry::CreateSolver(solver_param));//声明并通过SolverRegistry这个类中的静态函数CreateSolver初始化了一个指向Solver类型的shared_ptr,创建solver对象。 solver->SetActionFunction(signal_handler.GetActionFunction());//通过这个shared_ptr指明了在遇到系统信号(用户按了ctrl+c或者关闭了当前的terminal)时的处理方式。 if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) { CopyLayers(solver.get(), FLAGS_weights); }//判断了一下用户是否定义了snapshot或者weights这两个参数中的一个//如果定义了则需要通过Solver提供的接口从snapshot或者weights文件中去读取已经训练好的网络的参数 LOG(INFO) << "Starting Optimization"; if (gpus.size() > 1) {#ifdef USE_NCCL caffe::NCCL nccl(solver); nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);//如果用户设置了要使用多个gpu,那么要声明一个P2PSync类型的对象,并通过这个对象来完成多gpu的计算#else LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";#endif } else { solver->Solve();//而如果是只使用单个gpu,那么就通过Solver的Solve()开始具体的优化过程 } LOG(INFO) << "Optimization Done."; return 0;}RegisterBrewFunction(train);int main(int argc, char** argv) { * caffe::GlobalInit(&argc, &argv);//调用src/caffe/common.cpp中的GlobalInit(int* pargc, char*** pargv) 函数:解析命令行的参数* if (argc == 2) {#ifdef WITH_PYTHON_LAYER try {#endif* return GetBrewFunction(caffe::string(argv[1]))();//通过GetBrewFunction得到了我们需要调用的那个函数的函数指针,并完成了调用:eg.train()函数流程大概就是新建一个Solver对象,然后调用Solver类的构造函数,然后在Solver的构造函数中又会新建Net类实例,在Net类的构造函数中又会新建各个Layer的实例,一直具体到设置每个Blob,大概就介绍完了网络初始化的工作*#ifdef WITH_PYTHON_LAYER } catch (bp::error_already_set) { PyErr_Print(); return 1; }#endif } else { gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe"); }}

0 0
原创粉丝点击