Caffe中Solver解析
来源:互联网 发布:co2激光打标机软件图 编辑:程序博客网 时间:2024/05/16 15:05
1.Solver的初始化
shared_ptr<caffe::Solver<float>> solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>
的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。
具体步骤:
(1)SolverRegistry::CreateSolver(solver_param)
。
(2)通过static的g_registry_[type]获取type对应的Solver的Creator函数指针。
(3)调用Creator函数。
(4)new SGDSolver<Dtype>(solver_param)
创建solver。
SolverRegistry类源码:
class SolverRegistry { public: typedef Solver<Dtype>* (*Creator)(const SolverParameter&); typedef std::map<string, Creator> CreatorRegistry; static CreatorRegistry& Registry() { static CreatorRegistry* g_registry_ = new CreatorRegistry(); return *g_registry_; } static void AddCreator(const string& type, Creator creator) { CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 0) << "Solver type " << type << " already registered."; registry[type] = creator; } static Solver<Dtype>* CreateSolver(const SolverParameter& param) { const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type << " (known types: " << SolverTypeListString() << ")"; return registry[type](param); } static vector<string> SolverTypeList() { CreatorRegistry& registry = Registry(); vector<string> solver_types; for (typename CreatorRegistry::iterator iter = registry.begin();iter != registry.end(); ++iter) { solver_types.push_back(iter->first); } return solver_types; } private: SolverRegistry() {} static string SolverTypeListString() { vector<string> solver_types = SolverTypeList(); string solver_types_str; for (vector<string>::iterator iter = solver_types.begin();iter != solver_types.end(); ++iter) { if (iter != solver_types.begin()) { solver_types_str += ", "; } solver_types_str += *iter; } return solver_types_str; }};
SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。
CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*
。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*
返回。
Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,而且在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。
Register的具体步骤:
(1)Registry_Solver_Class(SGD)。
(2)定义Creator函数,Registry_Solver_Creator。
(3)定义SolverRegistry<float>
类型的static变量,定义SolverRegistry<double>
类型的static变量。
(4)SolverRegistry::AddCreator将定义的Creator函数指针添加到static的变量g_registry_(map)中。
SolverRegisterer源码:
template <typename Dtype>class SolverRegisterer { public: SolverRegisterer(const string& type, Solver<Dtype>* (*creator)(const SolverParameter&));};#define REGISTER_SOLVER_CREATOR(type, creator) \ static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \ static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \#define REGISTER_SOLVER_CLASS(type) \ template <typename Dtype> \ Solver<Dtype>* Creator_##type##Solver( \ const SolverParameter& param) \ { \ return new type##Solver<Dtype>(param); \ } \ REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)}#endif
在sgd_solver.cpp文件末尾有REGISTER_SOLVER_CLASS(SGD)
,使用REGISTER_SOLVER_CLASS宏定义一个名为Creator_SGDSolver的函数,即为Creator类型的指针函数,在Creator_SGDSolver函数中调用了SGDSolver的构造函数,并返回所构造的指针变量。Creator类型的指针函数的作用:构造一个对应类型的Solver对象,将其指针返回,然后在REGISTER_SOLVER_CLASS宏里又调用了REGISTER_SOLVER_CREATOR宏,该宏调用相对应(分别定义了SolverRegisterer类模板的float和double类型的static变量)的构造函数。在SolverRegisterer的构造函数中调用了SolverRegistry类的AddCreator函数,其功能将刚才定义的Creator_SGDSolver函数的指针存到g_registry所指向的map中。类似地,所有的Solver对应的cpp文件的末尾都调用了REGISTER_SOLVER_CLASS宏来完成注册,在所有的Solver都注册之后就可以通过g_registry得到对应的Creator函数的指针,并通过调用这个Creator函数来构造对应的Solver。
2.SIGINT和SIGHUP信号的处理
Caffe在train或者test的过程中都有可能会遇到系统信号(用户按下ctrl+c或者关掉了控制的terminal),可以通过对sigint_effect和sighup_effect来设置遇到系统信号的时候希望进行的处理方式:
caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT
在caffe.cpp中定义了一个GetRequesedAction函数来将设置的string类型的标志转变为枚举类型的变量:
caffe::SolverAction::Enum GetRequestedAction(const std::string& flag_value) { if (flag_value == "stop") { return caffe::SolverAction::STOP; } if (flag_value == "snapshot") { return caffe::SolverAction::SNAPSHOT; } if (flag_value == "none") { return caffe::SolverAction::NONE; } LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; } // SolverAction::Enum的定义 namespace SolverAction { enum Enum { NONE = 0, // Take no special action. STOP = 1, // Stop training. snapshot_after_train controls whether a // snapshot is created. SNAPSHOT = 2 // Take a snapshot, and keep training. };}
其中SolverAction::Enum的定义在solver.hpp中,这是一个定义为枚举类型的数据类型,只有三个可能的值,分别对应了三种处理系统信号的方式:NONE(忽略信号什么都不做)/STOP(停止训练)/SNAPSHOT(保存当前的训练状态,继续训练)。在caffe.cpp中的train函数里Solver设置如何处理系统信号的代码为:
caffe::SignalHandler signal_handler( GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect) ); solver->SetActionFunction(signal_handler.GetActionFunction());
通过gflags定义和解析的两个Command Line Interface的输入参数,FLAGS_sigint_effect和FLAGS_sighup_effect分别对应遇到sigint和sighup信号的处理方式,如果用户不设定,sigint的默认值为stop,sighup的默认值为snapshot。GetRequestedAction函数会将string类型的FLAGS_xx转为SolverAction::Enum类型,并用来定义一个SignalHandler类型的对象signal_handler。这部分代码都依赖于SignalHandler这个类的接口:
// header fileclass SignalHandler { public: // Contructor. Specify what action to take when a signal is received. SignalHandler(SolverAction::Enum SIGINT_action, SolverAction::Enum SIGHUP_action); ~SignalHandler(); ActionCallback GetActionFunction(); private: SolverAction::Enum CheckForSignals() const; SolverAction::Enum SIGINT_action_; SolverAction::Enum SIGHUP_action_;};
// source fileSignalHandler::SignalHandler(SolverAction::Enum SIGINT_action, SolverAction::Enum SIGHUP_action): SIGINT_action_(SIGINT_action),SIGHUP_action_(SIGHUP_action) { HookupHandler();}void HookupHandler() { if (already_hooked_up) { LOG(FATAL) << "Tried to hookup signal handlers more than once."; } already_hooked_up = true; struct sigaction sa; sa.sa_handler = &handle_signal; // ...}static volatile sig_atomic_t got_sigint = false;static volatile sig_atomic_t got_sighup = false;void handle_signal(int signal) { switch (signal) { case SIGHUP: got_sighup = true; break; case SIGINT: got_sigint = true; break; }}ActionCallback SignalHandler::GetActionFunction() { return boost::bind(&SignalHandler::CheckForSignals, this);}SolverAction::Enum SignalHandler::CheckForSignals() const { if (GotSIGHUP()) { return SIGHUP_action_;} if (GotSIGINT()) { return SIGINT_action_;} return SolverAction::NONE;}bool GotSIGINT() { bool result = got_sigint; got_sigint = false; return result;}bool GotSIGHUP() { bool result = got_sighup; got_sighup = false; return result;}// ActionCallback的含义typedef boost::function<SolverAction::Enum()> ActionCallback;
- SignalHandler类有两个数据成员,都是SolverAction::Enum类型的,分别对应sigint和sighup信号,在构造函数中,用解析FLAGS_xx得到的结果分别给两个成员赋值,然后调用了HookupHandler函数,这个函数的主要作用是定义了一个sigaction类型(应该是系统级别的代码)的对象sa,然后通过sa.sa_handler= &handle_signal来设置,当有遇到系统信号时,调用handle_signal函数来处理,即判断一下当前的信号是什么类型,如果是sigint就将全局的static变量got_sigint变为true,sighup的处理类似。
- 在根据用户设置(或者默认值)的参数定义了signal_handler之后,solver通过SetActionFunction来设置了如何处理系统信号。这个函数的输入为signal_handler的GetActionFunction的返回值,根据上述的代码可以看到,GetActionFunction会返回signal_handler对象的CheckForSignals函数的地址(boost::bind的具体使用请参考boost官方文档)。而在Solver的SetActionFunction函数中只是简单的把Solver的一个成员action_request_function_赋值为输入参数的值,以当前的例子来说就是,solver对象的action_request_function_指向了signal_handler对象的CheckForSignals函数的地址。其中的ActionCallback是一个函数指针类型,指向了参数为空,返回值为SolverAction::Enum类型的函数(boost::function具体用法参考官方文档)。
- 总之,通过定义一个SignalHandler类型的对象,告知系统在遇到系统信号的时候回调handle_signal函数来改变全局变量got_sigint和got_sighup的值,然后通过Solver的接口设置了其遇到系统函数将调用signal_handler的Check函数,实际上就是去判断当前是否遇到了系统信号,如果遇到某个类型的信号,就返回设置的处理方式(SolverAction::Enum类型)。
3.Solver::Solve()具体实现
Solver::Solve源码分析:
void Solver<Dtype>::Solve(const char* resume_file) { // 检查当前是否是root_solver(多GPU模式下,只有root_solver才运行这一部分的代码) CHECK(Caffe::root_solver()); // 输出learning policy(更新学习率的策略) LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); // requested_early_exit_初始值为false,此时不要求在优化结束前退出 requested_early_exit_ = false; // 判断指针resume_file是否NULL,如果不是则从resume_file存储的路径里读取之前训练的状态 if (resume_file) { LOG(INFO) << "Restoring previous solver status from " << resume_file; Restore(resume_file); } // 调用了Step函数,其执行了实际的逐步的迭代过程 Step(param_.max_iter() - iter_); // 迭代结束或者遇到系统信号提前结束后,判断是否需要在训练结束之后snapshot,可在solver.prototxt里设置 if (param_.snapshot_after_train() && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { Snapshot(); } // 如果在Step函数的迭代过程中遇到了系统信号,且处理方式设置为STOP,requested_early_exit_会被修改为true,迭代提前结束,并输出相关信息 if (requested_early_exit_) { LOG(INFO) << "Optimization stopped early."; return; } // 判断是否需要输出最后的loss if (param_.display() && iter_ % param_.display() == 0) { Dtype loss; net_->ForwardPrefilled(&loss); LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; } // 判断是否需要最后Test if (param_.test_interval() && iter_ % param_.test_interval() == 0) { TestAll(); } LOG(INFO) << "Optimization Done."; }
Solver::Step函数源码解析:
template <typename Dtype> void Solver<Dtype>::Step(int iters) { vector<Blob<Dtype>*> bottom_vec; // 设置开始的迭代次数(如果是从之前的snapshot恢复的,那iter_等于snapshot时的迭代次数)和结束的迭代次数 const int start_iter = iter_; const int stop_iter = iter_ + iters; // 输出的loss为前average_loss次loss的平均值,在solver.prototxt里设置,默认为1, // losses存储前average_loss个loss,smoothed_loss为最后要输出的均值 int average_loss = this->param_.average_loss(); vector<Dtype> losses; Dtype smoothed_loss = 0; // 迭代 while (iter_ < stop_iter) { // 清空上一次所有参数的梯度 net_->ClearParamDiffs(); // 判断是否需要测试 if (param_.test_interval() && iter_ % param_.test_interval() == 0 && (iter_ > 0 || param_.test_initialization()) && Caffe::root_solver()) { TestAll(); // 判断是否需要提前结束迭代 if (requested_early_exit_) { break; } } for (int i = 0; i < callbacks_.size(); ++i) { callbacks_[i]->on_start(); } // 判断当前迭代次数是否需要显示loss等信息 const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); Dtype loss = 0; // iter_size在solver.prototxt中设置,实际上的batch_size=iter_size * batch_size(网络中定义的), // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,这个loss是通过调用Net::ForwardBackward函数得到的 // 在GPU的显存不够的时候设置,比如把batch_size设置为128,但是会out_of_memory // 借助这个方法,可以设置batch_size=32,iter_size=4,那实际上每次迭代还是处理了128个数据 for (int i = 0; i < param_.iter_size(); ++i) { loss += net_->ForwardBackward(bottom_vec); } loss /= param_.iter_size(); // 计算要输出的smoothed_loss,如果losses里还没有存够average_loss个loss则将当前的loss插入,如果已经存够了,则将之前的替换掉 if (losses.size() < average_loss) { losses.push_back(loss); int size = losses.size(); smoothed_loss = (smoothed_loss * (size - 1) + loss) / size; } else { int idx = (iter_ - start_iter) % average_loss; smoothed_loss += (loss - losses[idx]) / average_loss; losses[idx] = loss; } // 输出当前迭代的信息 if (display) { LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ << ", loss = " << smoothed_loss; const vector<Blob<Dtype>*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { const Dtype* result_vec = result[j]->cpu_data(); const string& output_name = net_->blob_names()[net_->output_blob_indices()[j]]; const Dtype loss_weight = net_->blob_loss_weights()[net_->output_blob_indices()[j]]; for (int k = 0; k < result[j]->count(); ++k) { ostringstream loss_msg_stream; if (loss_weight) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * result_vec[k] << " loss)"; } LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" << score_index++ << ": " << output_name << " = " << result_vec[k] << loss_msg_stream.str(); } } } for (int i = 0; i < callbacks_.size(); ++i) { callbacks_[i]->on_gradients_ready(); } // 执行梯度的更新,其在基类Solver中没有实现,但会调用每个子类的实现 ApplyUpdate(); // 迭代次数加1 ++iter_; // 调用GetRequestedAction,实际是通过action_request_function_函数指针调用已设置好(通过SetRequestedAction)的signal_handler的CheckForSignals函数,这个函数的作用是会根据是否遇到系统信号以及信号的类型和用户设置(或者默认)的方式返回处理的方式 SolverAction::Enum request = GetRequestedAction(); // 判断当前迭代是否需要snapshot,如果request==SNAPSHOT则也需要 if ((param_.snapshot() && iter_ % param_.snapshot() == 0 && Caffe::root_solver()) || (request == SolverAction::SNAPSHOT)) { Snapshot(); } // 如果request为STOP则修改requested_early_exit_为true,会提前结束迭代 if (SolverAction::STOP == request) { requested_early_exit_ = true; break; } } }
每一组网络中的参数的更新都是在不同类型的Solver实现各自的ApplyUpdate函数中完成的,以最常用的SGD为例子来分析这个函数具体的功能:
SGDSolver::ApplyUpdate源码分析:
template <typename Dtype>void SGDSolver<Dtype>::ApplyUpdate() { CHECK(Caffe::root_solver()); // GetLearningRate根据设置的lr_policy来计算当前迭代的learning rate的值 Dtype rate = GetLearningRate(); // 判断是否需要输出当前的learning rate if (this->param_.display() && this->iter_ % this->param_.display() == 0) { LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; } // 避免梯度爆炸,如果梯度的二范数超过了某个数值则进行scale操作,将梯度减小 ClipGradients(); // 对所有可更新的网络参数进行操作 for (int param_id = 0; param_id < this->net_->learnable_params().size();++param_id) { // 将第param_id个参数的梯度除以iter_size,其作用是保证实际的batch_size=iter_size * batch_size Normalize(param_id); // 将正则化部分的梯度存入到每个参数的梯度中 Regularize(param_id); // 计算SGD算法的梯度(momentum等) ComputeUpdateValue(param_id, rate); } // 调用Net::Update更新所有的参数 this->net_->Update(); }
Normalize函数源码分析:
template <typename Dtype> void SGDSolver<Dtype>::Normalize(int param_id) { // 如果iter_size的值为1,则不需要任何处理直接return if (this->param_.iter_size() == 1) { return; } // 通过net_返回所有可以学习的参数,是一个vector<shared_ptr<Blob<Dtype>>> const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); // 要乘以的系数等于1/iter_size const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); switch (Caffe::mode()) { case Caffe::CPU: { // caffe_scal在/CAFFE_ROOT/src/caffe/util/math_functions.cpp中 // 是blas的scale函数的一个封装,第一个参数是数据的个数,第二个参数是乘以的系数, // 第三个参数是数据的指针 caffe_scal(net_params[param_id]->count(), accum_normalization, net_params[param_id]->mutable_cpu_diff()); break; } case Caffe::GPU: { // GPU代码略 } } }
Regularize函数源码分析:
template <typename Dtype> void SGDSolver<Dtype>::Regularize(int param_id) { // 获取所有可以学习的参数的vector const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); // 获取所有的参数对应的weight_decay的vector const vector<float>& net_params_weight_decay = this->net_->params_weight_decay(); // 模型整体的weight_decay数值 Dtype weight_decay = this->param_.weight_decay(); // 获取正则化的类型:L1 或 L2 string regularization_type = this->param_.regularization_type(); // 实际的weight_decay等于整体模型的数值乘以具体每个参数的数值 Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; switch (Caffe::mode()) { case Caffe::CPU: { // 如果weight_decay不为0,则计算 if (local_decay) { if (regularization_type == "L2") { // L2的梯度为diff_ = weight_decay*data_ + diff_ // caffe_axpy的功能是 y = a*x + y // 第一个参数是数据的个数,第二个是上式的a,第三个是x的指针,第四个是y的指针 caffe_axpy(net_params[param_id]->count(),local_decay, net_params[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); } else if (regularization_type == "L1") { // L1的梯度为diff_ = diff_ + sign(data_) // temp_ = sign(data_) caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), temp_[param_id]->mutable_cpu_data()); // 将temp_加到diff_中 diff_ = weight_decay*temp_ + diff_ caffe_axpy(net_params[param_id]->count(),local_decay, temp_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); } else { LOG(FATAL) << "Unknown regularization type: " << regularization_type; } } break; } // GPU代码略 }
ComputeUpdatedValue函数源码分析:
template <typename Dtype>void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { // 获取所有可以更新的参数的vector const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); // 获取所有参数对应的learning_rate的vector const vector<float>& net_params_lr = this->net_->params_lr(); // 获取momentum值 Dtype momentum = this->param_.momentum(); // 实际的learning_rate为全局的learning_rate乘以每个参数对应的learning_rate Dtype local_rate = rate * net_params_lr[param_id]; switch (Caffe::mode()) { case Caffe::CPU: { // 关于SGD的公式参考caffe官网tutorial的Solver部分 // history_存储了上一次的梯度,下面这个函数: // history_ = learning_rate*diff_ + momentum*history caffe_cpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum,history_[param_id]->mutable_cpu_data()); // 把当前的梯度拷贝给参数Blob的diff_ caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); break; } case Caffe::GPU: { // GPU代码略 } }
- Caffe中Solver解析
- caffe solver参数解析
- caffe中解析器solver中各参数的含义
- caffe中关于solver
- caffe源码解析之solver
- caffe中 solver.prototxt文件
- caffe中 solver.prototxt文件
- caffe源码解析 — solver.cpp
- caffe源码解析 — solver.cpp
- caffe源码解析 — solver.cpp
- Caffe源码解析10:Caffe的求解(Solver)
- Caffe Solver
- caffe中solver.prototxt参数说明
- caffe中solver配置文件的解读
- caffe中solver.prototx 中的参数
- Caffe源码中Solver文件分析
- caffe中solver.prototxt参数说明
- Caffe中求解器(Solver)介绍
- java身份证校验工具类
- Git取消合并(merge)、暂存修改(stash)、回退到某个版本(reset)的使用方法
- oracle单实例12.2.0.1安装
- 欢迎使用CSDN-markdown编辑器
- Okhttp3网络请求框架+MVP设计模式简单实战
- Caffe中Solver解析
- 51 nod 1283 最小周长
- thinkphp+js写的多文件上传
- 统计语言模型(下)
- Git重命名仓库、修改远程仓库地址、修改仓库配置
- UVA
- Android系列之初:学习心得
- 你必须了解的Session的本质
- 教你一天玩转JavaScript(八)——使用JavaScript完成省市联动的效果