An Analysis of Caffe's Command Line Interface

Source: Internet  Editor: 程序博客网  Date: 2024/05/16 05:21

1. Google Flags

Caffe's command line interface provides four actions: train, test, time, and device_query. Besides the action, the interface accepts flags such as -solver, -weights, -snapshot, and -gpu. These flags are parsed with the Google Flags (gflags) library.

// caffe/tools/caffe.cpp
DEFINE_string(gpu, "",
    "Optional; run in GPU mode on given device IDs separated by ','."
    "Use '-gpu all' to run on all available GPUs. The effective training "
    "batch size is multiplied by the number of devices.");

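The macro is used as DEFINE_xxx(name, default_value, instruction);. It defines a flag of type xxx that the code accesses as FLAGS_name. If the user does not supply a value on the command line, the flag takes default_value; instruction is the help text that describes the flag. The code above therefore defines a string flag named FLAGS_gpu that defaults to the empty string, and its help text shows that it lets the user choose which GPUs Caffe will use.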
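To make the DEFINE_xxx pattern concrete, here is a minimal standard-library sketch of a macro that creates a variable with a FLAGS-style name and records it in a registry. It is not how gflags is implemented; TOY_DEFINE_string, TOY_FLAGS_*, ToyFlagRegistry, and ToyFlagRegisterer are hypothetical names invented for this illustration.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy registry: flag name -> pointer to the variable that stores its value.
// Assumption: this only mirrors the visible behavior, not gflags internals.
inline std::map<std::string, std::string*>& ToyFlagRegistry() {
  static std::map<std::string, std::string*> registry;
  return registry;
}

// Helper whose constructor runs before main() and records the flag.
struct ToyFlagRegisterer {
  ToyFlagRegisterer(const char* name, std::string* storage) {
    assert(storage != nullptr);
    ToyFlagRegistry()[name] = storage;
  }
};

// TOY_DEFINE_string(gpu, "", "...") creates the variable TOY_FLAGS_gpu,
// initialized to the default value, and registers it under the name "gpu".
// The help text is accepted but ignored in this sketch.
#define TOY_DEFINE_string(name, default_value, help)  \
  std::string TOY_FLAGS_##name = default_value;       \
  static ToyFlagRegisterer toy_registerer_##name(#name, &TOY_FLAGS_##name)

TOY_DEFINE_string(gpu, "", "Device IDs separated by ',' or 'all'.");
TOY_DEFINE_string(solver, "", "Path to the solver definition file.");
```

A parser can then look a flag up by name in the registry and write the parsed value through the stored pointer, which is why the rest of the program can simply read TOY_FLAGS_gpu.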
The flags are parsed when main() in caffe.cpp calls GlobalInit(&argc, &argv), which is defined in /caffe/src/caffe/common.cpp:

void GlobalInit(int* pargc, char*** pargv) {
  // Google flags.
  ::gflags::ParseCommandLineFlags(pargc, pargv, true);
  // Google logging.
  ::google::InitGoogleLogging(*(pargv)[0]);
  // Provide a backtrace on segfault.
  ::google::InstallFailureSignalHandler();
}

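::gflags::ParseCommandLineFlags(pargc, pargv, true) is the gflags function that parses the command line. Its first two arguments are pointers to main()'s argc and argv. The third argument, true, tells gflags to remove the flags from argv once they have been parsed. So after parsing a command such as caffe train <args>, argc is 2, argv[0] is the program name, and argv[1] is one of train/test/time/device_query.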
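The effect of removing flags from the argument list can be sketched with the standard library alone. The function below is a simplified illustration, not the gflags parser: ToyParseAndStrip is a hypothetical name, and the sketch assumes every flag takes a value (real gflags also handles boolean flags and other forms).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Splits args into flags and leftover positional arguments.
// Accepts "-name=value", "--name=value", "-name value" and "--name value".
// args[0] is treated as the program name and is always kept.
std::vector<std::string> ToyParseAndStrip(
    const std::vector<std::string>& args,
    std::map<std::string, std::string>* flags) {
  assert(flags != nullptr);
  std::vector<std::string> remaining;
  for (std::size_t i = 0; i < args.size(); ++i) {
    const std::string& arg = args[i];
    if (i == 0 || arg.size() < 2 || arg[0] != '-') {
      remaining.push_back(arg);  // program name or positional argument
      continue;
    }
    // Drop one or two leading dashes.
    const std::string body = arg.substr(arg[1] == '-' ? 2 : 1);
    const std::size_t eq = body.find('=');
    if (eq != std::string::npos) {
      (*flags)[body.substr(0, eq)] = body.substr(eq + 1);
    } else if (i + 1 < args.size()) {
      (*flags)[body] = args[++i];  // the value is the next argument
    } else {
      (*flags)[body] = "";  // flag given without a value
    }
  }
  return remaining;
}
```

For the arguments caffe train -solver solver.prototxt --gpu=0,1, the returned list holds only caffe and train, which corresponds to argc being 2 after parsing.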

2. Register Brew Function

Each of the four actions that Caffe's command line interface provides (train/test/time/device_query) is implemented by one function. These four functions are dispatched through a global variable named g_brew_map.

// A simple registry for caffe commands.
typedef int (*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}


// The call in main()
return GetBrewFunction(caffe::string(argv[1]))();

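Take train as an example. When the user types caffe train <args>, argv[1] equals train after Google Flags has parsed the command line. GetBrewFunction therefore looks up "train" in g_brew_map and returns a pointer to the train function, and main() calls train through that pointer.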

// Implementation of GetBrewFunction
static BrewFunction GetBrewFunction(const caffe::string& name) {
  if (g_brew_map.count(name)) {
    return g_brew_map[name];
  } else {
    LOG(ERROR) << "Available caffe actions:";
    for (BrewMap::iterator it = g_brew_map.begin();
         it != g_brew_map.end(); ++it) {
      LOG(ERROR) << "\t" << it->first;
    }
    LOG(FATAL) << "Unknown action: " << name;
    return NULL;  // not reachable, just to suppress old compiler warnings.
  }
}
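The registration trick relies on a global object whose constructor runs before main() and inserts the function into the map. The following standard-library sketch reproduces the pattern without glog or Caffe. All Toy-prefixed names are hypothetical, the two action functions are empty stand-ins, and an unknown name returns a null pointer here instead of logging a fatal error.

```cpp
#include <cassert>
#include <map>
#include <string>

typedef int (*ToyBrewFunction)();
typedef std::map<std::string, ToyBrewFunction> ToyBrewMap;
// Defined before any registerer object in this file, so it is
// constructed before the registerers use it.
ToyBrewMap g_toy_brew_map;

// Defines a class whose constructor stores &func under the key "func",
// then creates one global instance of that class.
#define ToyRegisterBrewFunction(func)            \
  namespace {                                    \
  class ToyRegisterer_##func {                   \
   public:                                       \
    ToyRegisterer_##func() {                     \
      assert(g_toy_brew_map.count(#func) == 0);  \
      g_toy_brew_map[#func] = &func;             \
    }                                            \
  };                                             \
  ToyRegisterer_##func g_toy_registerer_##func;  \
  }

// Stand-in actions; the real ones live in caffe.cpp.
int train() { return 0; }
ToyRegisterBrewFunction(train)

int device_query() { return 7; }
ToyRegisterBrewFunction(device_query)

// Returns the registered function, or nullptr for an unknown action.
ToyBrewFunction ToyGetBrewFunction(const std::string& name) {
  ToyBrewMap::const_iterator it = g_toy_brew_map.find(name);
  return it == g_toy_brew_map.end() ? nullptr : it->second;
}
```

Because the key is produced by the # stringizing operator, the registered name always matches the function name, so adding a new action only requires writing the function and one macro line.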



The snippets that follow are taken from the train() function in caffe.cpp.

CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
    << "Give a snapshot to resume training or weights to finetune "
    "but not both.";

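The first line uses glog's CHECK_GT macro ("check greater than") to verify that FLAGS_solver is not empty. If the check fails, the program logs "Need a solver definition to train." and aborts. The second check ensures that -snapshot and -weights are not both given.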
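The two conditions can be restated as a plain function, which makes the allowed flag combinations easy to see. This is only a sketch: ToyValidateTrainFlags is a hypothetical name, and it returns the message instead of aborting the process the way a failed glog CHECK does.

```cpp
#include <cassert>
#include <string>

// Returns an empty string when the flag combination is valid, otherwise
// the message that the corresponding check would report.
std::string ToyValidateTrainFlags(const std::string& solver,
                                  const std::string& snapshot,
                                  const std::string& weights) {
  if (solver.empty()) {
    return "Need a solver definition to train.";
  }
  if (!snapshot.empty() && !weights.empty()) {
    return "Give a snapshot to resume training or weights to finetune "
           "but not both.";
  }
  return "";
}
```

A solver with neither -snapshot nor -weights passes, as does a solver with exactly one of them; only a missing solver or both options together is rejected.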

caffe::SolverParameter solver_param;
caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

SolverParameter is a class generated automatically by Google Protocol Buffers. It is filled in by ReadSolverParamsFromTextFileOrDie, which is implemented in /caffe/src/caffe/util/upgrade_proto.cpp.

// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
                                       SolverParameter* param) {
  CHECK(ReadProtoFromTextFile(param_file, param))
      << "Failed to parse SolverParameter file: " << param_file;
  UpgradeSolverAsNeeded(param_file, param);
}
  • The function first calls ReadProtoFromTextFile, which reads the solver definition from the path param_file and parses the file contents into the object that param points to. It is implemented near the top of /caffe/src/caffe/util/io.cpp:
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}
  • It then calls UpgradeSolverAsNeeded, which handles compatibility between old and new versions of caffe.proto:
// Check for deprecations and upgrade the SolverParameter as needed.
bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
  bool success = true;
  // Try to upgrade old style solver_type enum fields into new string type
  if (SolverNeedsTypeUpgrade(*param)) {
    LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
              << "'solver_type' field (enum)': " << param_file;
    if (!UpgradeSolverType(param)) {
      success = false;
      LOG(ERROR) << "Warning: had one or more problems upgrading "
                 << "SolverType (see above).";
    } else {
      LOG(INFO) << "Successfully upgraded file specified using deprecated "
                << "'solver_type' field (enum) to 'type' field (string).";
      LOG(WARNING) << "Note that future Caffe releases will only support "
                   << "'type' field (string) for a solver's type.";
    }
  }
  return success;
}
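The upgrade step converts a deprecated enum field into the newer string field. The sketch below shows the general idea with a hand-written struct in place of the generated SolverParameter class. Everything prefixed with Toy is hypothetical, the three enum values and their string names are illustrative assumptions, and rejecting a message that sets both fields is an assumption about sensible behavior rather than a statement about Caffe's code.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the protobuf-generated SolverParameter.
enum ToySolverType { TOY_SGD, TOY_NESTEROV, TOY_ADAGRAD };

struct ToySolverParam {
  bool has_solver_type = false;         // is the deprecated enum field set?
  ToySolverType solver_type = TOY_SGD;  // deprecated enum field
  std::string type;                     // newer string field
};

// An upgrade is needed only when the deprecated field is present.
bool ToySolverNeedsTypeUpgrade(const ToySolverParam& param) {
  return param.has_solver_type;
}

// Moves the enum value into the string field and clears the enum field.
// Returns false when the message cannot be upgraded unambiguously.
bool ToyUpgradeSolverType(ToySolverParam* param) {
  assert(param != nullptr);
  if (!param->has_solver_type) return true;  // nothing to do
  if (!param->type.empty()) return false;    // both fields set: ambiguous
  std::string name;
  switch (param->solver_type) {
    case TOY_SGD:      name = "SGD";      break;
    case TOY_NESTEROV: name = "Nesterov"; break;
    case TOY_ADAGRAD:  name = "AdaGrad";  break;
    default:           return false;      // unknown enum value
  }
  param->type = name;
  param->has_solver_type = false;
  return true;
}
```

After a successful upgrade the rest of the program only has to read the string field, which is the purpose of running the upgrade immediately after parsing.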


// If the gpus flag is not provided, allow the mode and device to be set
// in the solver prototxt.
if (FLAGS_gpu.size() == 0
    && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {
  if (solver_param.has_device_id()) {
    FLAGS_gpu = "" +
        boost::lexical_cast<string>(solver_param.device_id());
  } else {  // Set default GPU if unspecified
    FLAGS_gpu = "" + boost::lexical_cast<string>(0);
  }
}

vector<int> gpus;
get_gpus(&gpus);
if (gpus.size() == 0) {
  LOG(INFO) << "Use CPU.";
  Caffe::set_mode(Caffe::CPU);
} else {
  ostringstream s;
  for (int i = 0; i < gpus.size(); ++i) {
    s << (i ? ", " : "") << gpus[i];
  }
  LOG(INFO) << "Using GPUs " << s.str();
  solver_param.set_device_id(gpus[0]);
  Caffe::SetDevice(gpus[0]);
  Caffe::set_mode(Caffe::GPU);
  Caffe::set_solver_count(gpus.size());
}


// Parse GPU ids or use all available devices
static void get_gpus(vector<int>* gpus) {
  if (FLAGS_gpu == "all") {
    int count = 0;
#ifndef CPU_ONLY
    CUDA_CHECK(cudaGetDeviceCount(&count));
#else
    NO_GPU;
#endif
    for (int i = 0; i < count; ++i) {
      gpus->push_back(i);
    }
  } else if (FLAGS_gpu.size()) {
    vector<string> strings;
    boost::split(strings, FLAGS_gpu, boost::is_any_of(","));
    for (int i = 0; i < strings.size(); ++i) {
      gpus->push_back(boost::lexical_cast<int>(strings[i]));
    }
  } else {
    CHECK_EQ(gpus->size(), 0);
  }
}

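The code above selects the mode Caffe runs in (GPU or CPU) and which GPUs to use (Caffe supports running on several GPUs at once). It first checks whether the user passed a gpu flag on the command line. If not (FLAGS_gpu.size() == 0) but the solver prototxt selects GPU mode, the device_id given in the prototxt is copied into FLAGS_gpu. If the prototxt selects GPU mode without specifying a device_id, the default is 0. get_gpus then turns FLAGS_gpu into a list of device IDs, and an empty list means CPU mode.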
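The selection rules can be checked in isolation with a standard-library sketch that needs neither CUDA nor Boost. ToyResolveGpuFlag and ToyParseGpus are hypothetical names, and the device_count parameter stands in for the value that cudaGetDeviceCount would report.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Applies the fallback: when no gpu flag was given but the solver file asks
// for GPU mode, use the solver's device_id, or 0 if it has none.
std::string ToyResolveGpuFlag(const std::string& gpu_flag,
                              bool solver_wants_gpu,
                              bool has_device_id,
                              int device_id) {
  if (!gpu_flag.empty() || !solver_wants_gpu) return gpu_flag;
  return std::to_string(has_device_id ? device_id : 0);
}

// Turns the flag value into a list of device IDs.
// "all" -> 0..device_count-1, "0,2" -> {0, 2}, "" -> empty list (CPU mode).
std::vector<int> ToyParseGpus(const std::string& gpu_flag, int device_count) {
  assert(device_count >= 0);
  std::vector<int> gpus;
  if (gpu_flag == "all") {
    for (int i = 0; i < device_count; ++i) gpus.push_back(i);
  } else if (!gpu_flag.empty()) {
    std::istringstream stream(gpu_flag);
    std::string token;
    while (std::getline(stream, token, ',')) {
      gpus.push_back(std::stoi(token));  // throws on a non-numeric token
    }
  }
  return gpus;
}
```

Keeping the fallback separate from the parsing mirrors the order in the original code: the flag string is settled first, and only then converted into device IDs.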

caffe::SignalHandler signal_handler(
    GetRequestedAction(FLAGS_sigint_effect),
    GetRequestedAction(FLAGS_sighup_effect));

shared_ptr<caffe::Solver<float> >
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

solver->SetActionFunction(signal_handler.GetActionFunction());


if (FLAGS_snapshot.size()) {
  LOG(INFO) << "Resuming from " << FLAGS_snapshot;
  solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
  CopyLayers(solver.get(), FLAGS_weights);
}


if (gpus.size() > 1) {
  caffe::P2PSync<float> sync(solver, NULL, solver->param());
} else {
  LOG(INFO) << "Starting Optimization";
  solver->Solve();
}
LOG(INFO) << "Optimization Done.";
return 0;

