Caffe BasePrefetchingDataLayer 学习笔记

来源:互联网 发布:大学生兼职数据调查 编辑:程序博客网 时间:2024/06/05 03:03

BasePrefetchingDataLayer

介绍:
这个层继承自BaseDataLayer和InternalThread这两个基类,利用内部线程预取数据(图片及标签),并在前向传播(Forward_cpu)时先等待该线程结束,再把预取到的数据拷贝到top blob中,随后重新启动预取线程。

/**
 * @brief Base class for data layers that prepare their batches on an
 *        internal thread.
 *
 * Inherits from BaseDataLayer (common data-layer setup) and InternalThread
 * (thread start/join helpers). Concrete data layers provide the actual
 * loading logic by overriding InternalThreadEntry().
 *
 * @tparam Dtype element type stored in the blobs.
 */
template <typename Dtype>
class BasePrefetchingDataLayer :
    public BaseDataLayer<Dtype>, public InternalThread {
 public:
  // Initialization is delegated to the BaseDataLayer constructor.
  explicit BasePrefetchingDataLayer(const LayerParameter& param)
      : BaseDataLayer<Dtype>(param) {}

  // LayerSetUp: implements common data layer setup functionality, and calls
  // DataLayerSetUp to do special data layer setup for individual layer types.
  // This method may not be overridden: derived layers are expected to
  // customize DataLayerSetUp instead of providing their own LayerSetUp.
  void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);
  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top);

  // Re-initializes the data transformer's random state and starts the
  // internal (prefetch) thread.
  virtual void CreatePrefetchThread();
  // Waits for the internal (prefetch) thread to exit. It does not fetch data
  // itself; the fetching is done by the thread's function below.
  virtual void JoinPrefetchThread();
  // The thread's function. Empty here; each concrete data layer type
  // overrides it with its own loading code.
  virtual void InternalThreadEntry() {}

 protected:
  Blob<Dtype> prefetch_data_;     // holds the prefetched data
  Blob<Dtype> prefetch_label_;    // holds the prefetched labels
  Blob<Dtype> transformed_data_;  // holds the transformed data
};
/**
 * @brief Runs the common BaseDataLayer setup, then launches the first
 *        prefetch thread.
 *
 * @param bottom input blobs, forwarded unchanged to BaseDataLayer::LayerSetUp.
 * @param top    output blobs, forwarded unchanged to BaseDataLayer::LayerSetUp.
 */
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::LayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  BaseDataLayer<Dtype>::LayerSetUp(bottom, top);
  // Now, start the prefetch thread. Before calling prefetch, we make two
  // cpu_data calls so that the prefetch thread does not accidentally make
  // simultaneous cudaMalloc calls when the main thread is running. In some
  // GPUs this seems to cause failures if we do not so.
  this->prefetch_data_.mutable_cpu_data();
  if (this->output_labels_) {
    this->prefetch_label_.mutable_cpu_data();
  }
  DLOG(INFO) << "Initializing prefetch";
  this->CreatePrefetchThread();
  DLOG(INFO) << "Prefetch initialized.";
}

/**
 * @brief Re-initializes the data transformer's random state and starts the
 *        internal thread; CHECK-fails if the thread cannot be started.
 */
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::CreatePrefetchThread() {
  this->data_transformer_->InitRand();
  CHECK(StartInternalThread()) << "Thread execution failed";
}

/**
 * @brief Waits for the internal thread to exit; CHECK-fails if
 *        WaitForInternalThreadToExit() reports failure.
 */
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::JoinPrefetchThread() {
  CHECK(WaitForInternalThreadToExit()) << "Thread joining failed";
}

/**
 * @brief Copies the prefetched batch into the top blobs and starts
 *        prefetching the next one.
 *
 * @param bottom unused by this function.
 * @param top    top[0] receives the data; top[1] receives the labels when
 *               output_labels_ is set.
 */
template <typename Dtype>
void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // First, join the thread. This only waits for the prefetch thread to
  // finish; prefetch_data_ / prefetch_label_ are presumably filled in by the
  // subclass's InternalThreadEntry() running on that thread.
  JoinPrefetchThread();
  DLOG(INFO) << "Thread joined";
  // Reshape to loaded data.
  top[0]->ReshapeLike(prefetch_data_);
  // Copy the data
  caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(),
             top[0]->mutable_cpu_data());
  DLOG(INFO) << "Prefetch copied";
  if (this->output_labels_) {
    // Reshape to loaded labels.
    top[1]->ReshapeLike(prefetch_label_);
    // Copy the labels.
    caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(),
               top[1]->mutable_cpu_data());
  }
  // Start a new prefetch thread
  DLOG(INFO) << "CreatePrefetchThread";
  CreatePrefetchThread();
}
0 0