Caffe的Command Line Interfaces解析

来源:互联网 发布:购买淘宝店铺价格 编辑:程序博客网 时间:2024/05/16 05:21

1.Google Flags

caffe的Command Line Interfaces一共提供了四个功能:train/test/time/device_query,而Interfaces的输入除了这四种功能还可以输入诸如-solver/-weights/-snapshot/-gpu等参数。这些参数的解析是通过Google Flags工具来完成的。

//caffe/tools/caffe.cppDEFINE_string(gpu, "",    "Optional; run in GPU mode on given device IDs separated by ','."    "Use '-gpu all' to run on all available GPUs. The effective training "    "batch size is multiplied by the number of devices."    );

这个宏的使用方式为DEFINE_xxx(name, default_value, instruction);。定义了一个xxx类型名为FLAGS_name的标志,如果用户没有在Command Line中提供其值,那么会默认为default_value,instruction是这个标志含义的说明。因此,上面的代码定义了一个string类型的名为FLAGS_gpu的标志,如果在Command Line中用户没有提供值,那么会默认为空字符串,根据说明可以得知这个标志是提供给用户来指定caffe将使用的GPU的。
解析这些标志的代码在caffe.cpp中的main()中调用了/caffe/src/common.cpp中的GlobalInit(&argc, &argv)函数:

void GlobalInit(int* pargc, char*** pargv) {   // Google flags.   ::gflags::ParseCommandLineFlags(pargc, pargv, true);   // Google logging.   ::google::InitGoogleLogging(*(pargv)[0]);   // Provide a backtrace on segfault.   ::google::InstallFailureSignalHandler();}

::gflags::ParseCommandLineFlags(pargc, pargv, true)函数是Google Flags用来解析输入的参数的,前两个参数分别是指向main()的argc和argv的指针,第三个参数为true,表示在解析完所有的标志之后将这些标志从argv中清除,因此在解析完成之后,argc的值为2,argv[0]为main,argv[1]为train/test/time/device_query中的一个。

2.Register Brew Function

Caffe在Command Line Interfaces中一共提供了4种功能:train/test/time/device_query,分别对应着四个函数,这四个函数的调用是通过一个叫做g_brew_map的全局变量来完成。

// A simple registry for caffe commands.typedef int (*BrewFunction)();typedef std::map<caffe::string, BrewFunction> BrewMap;BrewMap g_brew_map;

g_brew_map是一个key为string类型,value为BrewFunction类型的一个map类型的全局变量,BrewFunction是一个函数指针类型,指向的是参数为空,返回值为int的函数,也就是train/test/time/device_query这四个函数的类型。
在train等四个函数实现的后面都紧跟着宏的调用:RegisterBrewFunction(train),其作用是定义了一个名为__Register_train的类,在定义完这个类之后,定义了一个这个类的变量,调用构造函数并在g_brew_map中添加了key为”train”,value为指向train函数的指针的一个元素。

#define RegisterBrewFunction(func) \  namespace { \  class __Registerer_##func { \   public: /* NOLINT */ \    __Registerer_##func() { \      g_brew_map[#func] = &func; \    } \  }; \  __Registerer_##func g_registerer_##func; \}

在完成初始化(GlobalInit)之后,在main()函数中是通过以下方式进行调用RegisterBrewFunction函数:

// main()中的调用代码return GetBrewFunction(caffe::string(argv[1]))();

例如以train函数为例,在Command Line中输入了caffe train <args>,经过Google Flags的解析argv[1]=train,因此,在GetBrewFunction中会通过g_brew_map返回一个指向train函数的函数指针,最后在main函数中就通过这个返回的函数指针完成了对train函数的调用。
BrewFunction中的具体实现

// BrewFunction的具体实现static BrewFunction GetBrewFunction(const caffe::string& name) {   if (g_brew_map.count(name))    {     return g_brew_map[name];   }    else    {     LOG(ERROR) << "Available caffe actions:";     for (BrewMap::iterator it = g_brew_map.begin();it != g_brew_map.end(); ++it)      {       LOG(ERROR) << "\t" << it->first;     }     LOG(FATAL) << "Unknown action: " << name;     return NULL;  // not reachable, just to suppress old compiler warnings.   }}

RegisterBrewFunction这个宏在每一个实现主要功能的函数之后将这个函数的名字和其对应的函数指针添加到了g_brew_map中,然后在main函数中,通过GetBrewFunction得到了所需要调用的函数的指针,并完成调用。

3.train()函数

 CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())     << "Give a snapshot to resume training or weights to finetune "     "but not both.";

第一行使用了glog的CHECK_GT宏(含义为check greater than),检查FLAGS_solver的size是否大于0,如果小于或等于0则输出提示:”Need a solver definition to train”。
FLAGS_solver是最开始通过DEFINE_string定义的标志,如果希望训练一个模型,那么首先提供对应的定义solver文件的路径,该行代码的目的是确保提供了这样的路径。与第一行代码类似,第二行代码是确保用户没有同时提供snapshot和weights参数,这两个参数都是继续之前的训练或者进行fine-tuning的,如果同时指明了这两个标志,则不知道到应该从哪个路径的文件中读入模型的相关参数更为合适。

caffe::SolverParameter solver_param;caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

SolverParameter是通过Google Protocol Buffer自动生成的一个类,是通过ReadSolverParamsFromTextFileOrDie来完成解析的,这个函数的实现在/caffe/src/caffe/util/upgrade_proto.cpp中。

// Read parameters from a file into a SolverParameter proto message.void ReadSolverParamsFromTextFileOrDie(const string& param_file,SolverParameter* param) {  CHECK(ReadProtoFromTextFile(param_file, param))      << "Failed to parse SolverParameter file: " << param_file;  UpgradeSolverAsNeeded(param_file, param);}
  • 上述函数中首先调用ReadProtoFromTextFile,其作用是从param_file这个路径去读取solver的定义,并将文件中的内容解析存到param这个指针指向的对象,具体的实现在/caffe/src/caffe/util/io.cpp的开始:
   bool ReadProtoFromTextFile(const char* filename, Message* proto)    {          int fd = open(filename, O_RDONLY);       CHECK_NE(fd, -1) << "File not found: " << filename;          FileInputStream* input = new FileInputStream(fd);          bool success = google::protobuf::TextFormat::Parse(input, proto);          delete input;       close(fd);       return success;  }  
  • 其次调用UpgradeSolverAsNeeded完成了新老版本caffe.proto的兼容处理:
// Check for deprecations and upgrade the SolverParameter as needed.bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {   bool success = true;   // Try to upgrade old style solver_type enum fields into new string type   if (SolverNeedsTypeUpgrade(*param))    {     LOG(INFO) << "Attempting to upgrade input file specified using deprecated "               << "'solver_type' field (enum)': " << param_file;     if (!UpgradeSolverType(param))      {       success = false;       LOG(ERROR) << "Warning: had one or more problems upgrading "                  << "SolverType (see above).";     }     else     {       LOG(INFO) << "Successfully upgraded file specified using deprecated "                 << "'solver_type' field (enum) to 'type' field (string).";       LOG(WARNING) << "Note that future Caffe releases will only support "                    << "'type' field (string) for a solver's type.";     }   }   return success;}

主要的问题就是在旧版本中Solver的type是enum类型,而新版本的变为了string。

 // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. if (FLAGS_gpu.size() == 0     && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU)  {     if (solver_param.has_device_id())      {         FLAGS_gpu = ""  +             boost::lexical_cast<string>(solver_param.device_id());     }     else      {  // Set default GPU if unspecified         FLAGS_gpu = "" + boost::lexical_cast<string>(0);     } } vector<int> gpus; get_gpus(&gpus); if (gpus.size() == 0)  {   LOG(INFO) << "Use CPU.";   Caffe::set_mode(Caffe::CPU); }  else  {   ostringstream s;   for (int i = 0; i < gpus.size(); ++i)    {     s << (i ? ", " : "") << gpus[i];   }   LOG(INFO) << "Using GPUs " << s.str();   solver_param.set_device_id(gpus[0]);   Caffe::SetDevice(gpus[0]);   Caffe::set_mode(Caffe::GPU);   Caffe::set_solver_count(gpus.size()); }

get_gpus函数的定义如下:

// Parse GPU ids or use all available devicesstatic void get_gpus(vector<int>* gpus) {  if (FLAGS_gpu == "all")   {    int count = 0;    #ifndef CPU_ONLY        CUDA_CHECK(cudaGetDeviceCount(&count));    #else        NO_GPU;    #endif    for (int i = 0; i < count; ++i)     {      gpus->push_back(i);    }  }   else if (FLAGS_gpu.size())   {    vector<string> strings;    boost::split(strings, FLAGS_gpu, boost::is_any_of(","));    for (int i = 0; i < strings.size(); ++i)     {      gpus->push_back(boost::lexical_cast<int>(strings[i]));    }  }   else   {    CHECK_EQ(gpus->size(), 0);  }}

以上代码根据用户的设置来选择caffe工作的模式(GPU或CPU)以及使用哪些GPU(caffe已经支持了多GPU同时工作)。首先是判断用户在Command Line中是否输入了gpu相关的参数,如果没有(FLAGS_gpu.size()==0)但是用户在solver的prototxt定义中提供了相关的参数,那就把相关的参数放到FLAGS_gpu中,如果用户仅仅是选择了在solver的prototxt定义中选择了GPU模式,但是没有指明具体的gpu_id,那么就默认设置为0。

caffe::SignalHandler signal_handler(       GetRequestedAction(FLAGS_sigint_effect),       GetRequestedAction(FLAGS_sighup_effect));shared_ptr<caffe::Solver<float>>     solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));solver->SetActionFunction(signal_handler.GetActionFunction());

以上代码通过SolverRegistry初始化了一个指向Solver类型的shared_ptr。并通过这个shared_ptr指明了在遇到系统信号(用户按了ctrl+c或者关闭了当前的terminal)时的处理方式。

if (FLAGS_snapshot.size()) {   LOG(INFO) << "Resuming from " << FLAGS_snapshot;   solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) {   CopyLayers(solver.get(), FLAGS_weights);}

上述代码判断了一下用户是否定义了snapshot或者weights这两个参数中的一个,如果定义了则需要通过Solver提供的接口从snapshot或者weights文件中去读取已经训练好的网络的参数。

if (gpus.size() > 1) {  caffe::P2PSync<float> sync(solver, NULL, solver->param());  sync.run(gpus);} else {  LOG(INFO) << "Starting Optimization";  solver->Solve();}LOG(INFO) << "Optimization Done.";return 0;

最后,如果用户设置了要使用多个gpu,那么要声明一个P2PSync类型的对象,并通过这个对象来完成多gpu的计算,这一部分的代码,这一系列的文章会暂时先不涉及。而如果是只使用单个gpu,那么就通过Solver的Solve()开始具体的优化过程。在优化结束之后,函数将0值返回给main函数,整个train过程到这里也就结束了。

0 0
原创粉丝点击