【caffe源代码的梳理之三】Net

来源:互联网 发布:玩世不恭知乎 编辑:程序博客网 时间:2024/05/16 08:13

作者:JackGao24 博客园
作者:JackGao16 CSDN
文章链接:http://blog.csdn.net/u013108511/article/details/76636263
邮箱:gshuai16@mail.ustc.edu.cn

数据结构Net

1、Net

  Net在caffe中代表一个完整的CNN模型,它包含若干个Layer的实例。Net中既包含了Layer的对象,也包含了Blob的对象。其中Blob对象用于存放每个Layer的输入/输出的结果,layer则根据Net的描述对指定的Blob进行某些计算和处理(卷积、下采样、全连接、非线性变换、计算代价函数等)输出的结果存储到指定的输出Blob中,同名的Blob表示同一个Blob对象,同名的Layer表示同一个Layer对象,而同名的Blob和Layer之间不存在任何的关系。

2、Net数据结构的描述

caffe.proto中Net数据结构:

message NetParameter {  optional string name = 1; //网络的名称  //网络的输入Blob的名称,可以有过个Blob  repeated string input = 3;  //输入Blob的维度信息  repeated BlobShape input_shape = 8;  //旧版本的维度信息  repeated int32 input_dim = 4;  // 网络是否强制每个层执行反向传播的计算,如果设置为false,是否执行反向传播计算由网络结构和学习速率自动决定  optional bool force_backward = 5 [default = false];  //网络当前的状态  optional NetState state = 6;  //运行Net::Forward,Net::BackWard,Net::Update时是否打印结果的调试信息  optional bool debug_info = 7 [default = false];  //组成Net的所有层,每个层配置都包括连接属性和行为,由LayerParameter定义  repeated LayerParameter layer = 100;  // ID 100 so layers are printed last.  //已淘汰  repeated V1LayerParameter layers = 2;}

3、Net模板类分析

3.1、Init函数

//用NetParameter对象初始化Nettemplate <typename Dtype>  Net<Dtype>::Net(const NetParameter& param) {    Init(param);  } 

3.2、前向传播的函数

// 前向传播,以下相关的前向传播函数,内部最终均会调用ForwardFromTo函数  const vector<Blob<Dtype>*>& ForwardPrefilled(Dtype* loss = NULL);  //Net前向传播的几种形式 Dtype ForwardFromTo(int start, int end);  Dtype ForwardFrom(int start);  Dtype ForwardTo(int end);  //输入指定的Blob进行前向传播,返回输出Blob  const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom, Dtype* loss = NULL);  //指定序列化的输入BlobProtoVector进行前向传播,返回序列化的输出BlobProtoVector  string Forward(const string& input_blob_protos, Dtype* loss = NULL); 

3.3、反向传播的函数

//几种不同形式的Net反向传播,无需指定输入/输出Blob,因为在前向传播中已经建立连接void Backward();  void BackwardFromTo(int start, int end);  void BackwardFrom(int start);  void BackwardTo(int end);  

3.4、前向&反向传播

// 前向反向传播 ,输入为Bottom Blob,输出为loss Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {      Dtype loss;      Forward(bottom, &loss);      Backward();      return loss;  }  

3.5、其他的一些函数

//根据已经由Solver准备好的diff值更新网络权值  void Update();  //共享权值和偏置数据  void ShareWeights();    // 从另一个Net拷贝train layers  void ShareTrainedLayersWith(const Net* other);    // 从另一个Net拷贝train layers,加载已训练好的模型  void CopyTrainedLayersFrom(const NetParameter& param);  void CopyTrainedLayersFrom(const string trained_filename);  void CopyTrainedLayersFromBinaryProto(const string trained_filename);  void CopyTrainedLayersFromHDF5(const string trained_filename);  // 写Net到NetParameter  void ToProto(NetParameter* param, bool write_diff = false) const;  // 写Net weights到HDF5文件  void ToHDF5(const string& filename, bool write_diff = false) const;  // 获得Net名  inline const string& name() const { return name_; }  // 获得所有layer名  inline const vector<string>& layer_names() const { return layer_names_; }  // 获得blob名  inline const vector<string>& blob_names() const { return blob_names_; }    // 获得blob  inline const vector<shared_ptr<Blob<Dtype> > >& blobs() const { return blobs_; }  // 获得layer  inline const vector<shared_ptr<Layer<Dtype> > >& layers() const { return layers_; }  // 获得Net phase状态:train or test  inline Phase phase() const { return phase_; }  // 获得每一个layer的bottom vector  inline const vector<vector<Blob<Dtype>*> >& bottom_vecs() const { return bottom_vecs_; }   // 获得每一个layer的top vector  inline const vector<vector<Blob<Dtype>*> >& top_vecs() const { return top_vecs_; }  inline const vector<vector<bool> >& bottom_need_backward() const { return bottom_need_backward_; }  inline const vector<Dtype>& blob_loss_weights() const { return blob_loss_weights_; }  inline const vector<bool>& layer_need_backward() const { return layer_need_backward_; }  // 获得各种参数值  inline const vector<shared_ptr<Blob<Dtype> > >& params() const { return params_; }  inline const vector<Blob<Dtype>*>& learnable_params() const { return learnable_params_; }  inline const vector<float>& params_lr() const { return params_lr_; }  inline const vector<bool>& has_params_lr() const { return has_params_lr_; }  inline const vector<float>& params_weight_decay() const { return params_weight_decay_; }  inline const vector<bool>& has_params_decay() const { return has_params_decay_; }  const map<string, int>& param_names_index() const { return param_names_index_; }  inline const vector<int>& param_owners() const { return param_owners_; }  // input blob数目  inline int num_inputs() const { return net_input_blobs_.size(); }  // output blob数目  inline int num_outputs() const { return net_output_blobs_.size(); }  inline const vector<Blob<Dtype>*>& input_blobs() const { return net_input_blobs_; }  inline const vector<Blob<Dtype>*>& output_blobs() const { return net_output_blobs_; }  inline const vector<int>& input_blob_indices() const { return net_input_blob_indices_; }  inline const vector<int>& output_blob_indices() const { return net_output_blob_indices_; }  bool has_blob(const string& blob_name) const;  const shared_ptr<Blob<Dtype> > blob_by_name(const string& blob_name) const;  bool has_layer(const string& layer_name) const;  const shared_ptr<Layer<Dtype> > layer_by_name(const string& layer_name) const;  // 设置是否显示debug info  void set_debug_info(const bool value) { debug_info_ = value; }   // 移除指定的layers  static void FilterNet(const NetParameter& param, NetParameter* param_filtered);  static bool StateMeetsRule(const NetState& state, const NetStateRule& rule, const string& layer_name);   protected:  // 追加top blob  void AppendTop(const NetParameter& param, const int layer_id,                   const int top_id, set<string>* available_blobs,                   map<string, int>* blob_name_to_idx);  // 追加bottom blob  int AppendBottom(const NetParameter& param, const int layer_id,                     const int bottom_id, set<string>* available_blobs,                     map<string, int>* blob_name_to_idx);  // 追加blob参数  void AppendParam(const NetParameter& param, const int layer_id, const int param_id);  // 显示debug info  /// @brief Helper for displaying debug info in Forward about input Blobs.  void InputDebugInfo(const int layer_id);  /// @brief Helper for displaying debug info in Forward.  void ForwardDebugInfo(const int layer_id);  /// @brief Helper for displaying debug info in Backward.  void BackwardDebugInfo(const int layer_id);  /// @brief Helper for displaying debug info in Update.  void UpdateDebugInfo(const int param_id);  // Caffe中类的成员变量名都带有后缀"_",这样就容易区分临时变量和类成员变量  /// @brief The network name  string name_; // Net名  /// @brief The phase: TRAIN or TEST  Phase phase_; // Net Phase状态:train or test  /// @brief Individual layers in the net  vector<shared_ptr<Layer<Dtype> > > layers_; // layers  vector<string> layer_names_; // layers名  map<string, int> layer_names_index_; // layers 索引  vector<bool> layer_need_backward_; // 指定layers是否需要backward  vector<shared_ptr<Blob<Dtype> > > blobs_; // 存储每一个layer产生的中间结果  vector<string> blob_names_; // blobs名  map<string, int> blob_names_index_; // blobs 索引  vector<bool> blob_need_backward_; // 指定blobs是否需要backward  vector<vector<Blob<Dtype>*> > bottom_vecs_; // 存储每一个layer input bottom blobs 指针  vector<vector<int> > bottom_id_vecs_; // 存储每一个bottom blobs id  vector<vector<bool> > bottom_need_backward_; // 指定bottom blobs是否需要backward  vector<vector<Blob<Dtype>*> > top_vecs_; // 存储每一个layer output top blobs 指针  vector<vector<int> > top_id_vecs_; // 存储每一个layer output top blobs id  vector<Dtype> blob_loss_weights_; // layer 的loss函数值  vector<vector<int> > param_id_vecs_; //   vector<int> param_owners_;  vector<string> param_display_names_;  vector<pair<int, int> > param_layer_indices_;  map<string, int> param_names_index_;  vector<int> net_input_blob_indices_;  vector<int> net_output_blob_indices_;  vector<Blob<Dtype>*> net_input_blobs_;  vector<Blob<Dtype>*> net_output_blobs_;  vector<shared_ptr<Blob<Dtype> > > params_; //   vector<Blob<Dtype>*> learnable_params_;  vector<int> learnable_param_ids_;  vector<float> params_lr_;  vector<bool> has_params_lr_;  vector<float> params_weight_decay_;  vector<bool> has_params_decay_;  size_t memory_used_;  bool debug_info_; // 是否显示debug info  const Net* const root_net_;  // 禁止使用Net类的拷贝和赋值操作  DISABLE_COPY_AND_ASSIGN(Net);  };