代码笔记:caffe-reid中reid_data_layer源码解析

来源：互联网　编辑：程序博客网　时间：2024/06/06 12:39
#include <stdint.h>#include <cfloat>#include <vector>#include "caffe/data_transformer.hpp"#include "caffe/layers/reid_data_layer.hpp"#include "caffe/util/benchmark.hpp"#include <boost/thread.hpp>namespace caffe {template <typename Dtype>ReidDataLayer<Dtype>::~ReidDataLayer() {  this->StopInternalThread();}template <typename Dtype>unsigned int ReidDataLayer<Dtype>::RandRng() {  CHECK(prefetch_rng_);  caffe::rng_t *prefetch_rng =      static_cast<caffe::rng_t *>(prefetch_rng_->generator());  return (*prefetch_rng)();}(1)改写了ImageDataLayer中的DataLayerSetUP函数template <typename Dtype>void ReidDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,      const vector<Blob<Dtype>*>& top) {  DLOG(INFO) << "ReidDataLayer : DataLayerSetUp";  // Main Data Layer Set up  const int new_height = this->layer_param_.reid_data_param().new_height();  const int new_width  = this->layer_param_.reid_data_param().new_width();  const bool is_color  = this->layer_param_.reid_data_param().is_color();  CHECK((new_height == 0 && new_width == 0) ||      (new_height > 0 && new_width > 0)) << "Current implementation requires "      "new_height and new_width to be set at the same time.";  //读取图像文件和相应的label  // Read the file with filenames and labels  const string& source = this->layer_param_.reid_data_param().source();  LOG(INFO) << "Opening file " << source;  std::ifstream infile(source.c_str());  string line;  int mx_label = -1;  int mi_label = INT_MAX;  //按行读取,将行结果存为line  while (std::getline(infile, line)) {    size_t pos = line.find_last_of(' ');    int label = atoi(line.substr(pos + 1).c_str());    ///home/luoze/dataset/Market-1501-v15.09.15/bounding_box_train/0002_c1s1_000451_03.jpg 0    //以空格为分割点来分开line,前面为path,后面为标签    //vector<std::pair<std::string, int> > lines_;    this->lines_.push_back(std::make_pair(line.substr(0, pos), label));    mx_label = std::max(mx_label, label);    mi_label = std::min(mi_label, label);  }  //equal  CHECK_EQ(mi_label, 0);  this->label_set.clear();  
//vector<vector<size_t> > label_set;  //mx_label = 750  this->label_set.resize(mx_label+1);  //lines_.size()是样本个数  for (size_t index = 0; index < this->lines_.size(); index++) {    int label = this->lines_[index].second;    this->label_set[label].push_back(index);  }  for (size_t index = 0; index < this->label_set.size(); index++) {    CHECK_GT(this->label_set[index].size(), 0) << "label : " << index << " has no images";  }  CHECK(!lines_.empty()) << "File is empty";  infile.close();  LOG(INFO) << "A total of " << lines_.size() << " images. Label : [" << mi_label << ", " << mx_label << "]";  LOG(INFO) << "A total of " << label_set.size() << " persons";  //随机因子  const unsigned int prefetch_rng_seed = caffe_rng_rand();  prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));  //初始化为图片的总数  this->left_images = this->lines_.size();  this->pos_fraction = this->layer_param_.reid_data_param().pos_fraction();  this->neg_fraction = this->layer_param_.reid_data_param().pos_fraction();  CHECK_GT(lines_.size(), 0);  //开始根据path来取图了  //vector<cv::Mat> cv_imgs_;  this->cv_imgs_.clear();  for (size_t lines_id_ = 0; lines_id_ < this->lines_.size(); lines_id_++) {    cv::Mat cv_img = ReadImageToCVMat(lines_[lines_id_].first, new_height, new_width, is_color);    CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;    this->cv_imgs_.push_back(cv_img);  }  //使用Opencv来读进图像,然后使用它初始化相应的top blob  // Read an image, and use it to initialize the top blob.  cv::Mat cv_img = ReadImageToCVMat(lines_[0].first,                                    new_height, new_width, is_color);  CHECK(cv_img.data) << "Could not load " << lines_[0].first;  const int batch_size = this->layer_param_.reid_data_param().batch_size();  // Use data_transformer to infer the expected blob shape from datum.  
//top_shape 输出的形状  //使用data_transformer 来计算根据datum的期望blob的shape  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  vector<int> prefetch_top_shape = top_shape;  this->transformed_data_.Reshape(top_shape);  //首先reshape top[0],再根据batch的大小进行预取  // Reshape top[0] and prefetch_data according to the batch_size.  top_shape[0] = batch_size * 2;  prefetch_top_shape[0] = batch_size;  top[0]->Reshape(top_shape);  //top[1]->Reshape(top_shape);  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {    //同时预取了两组数据    this->prefetch_[i].data_.Reshape(prefetch_top_shape);    this->prefetch_[i].datap_.Reshape(prefetch_top_shape);  }  //256 3 277 277  LOG(INFO) << "output data size: " << top[0]->num() << ","      << top[0]->channels() << "," << top[0]->height() << ","      << top[0]->width();  //LOG(INFO) << "output data pair size: " << top[1]->num() << ","  //    << top[1]->channels() << "," << top[1]->height() << ","  //    << top[1]->width();  // label  if (this->output_labels_) {    vector<int> label_shape(1, batch_size*2);    top[1]->Reshape(label_shape);    vector<int> prefetch_label_shape(1, batch_size);    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {      //同时预取了3组数据      this->prefetch_[i].label_.Reshape(prefetch_label_shape);      this->prefetch_[i].labelp_.Reshape(prefetch_label_shape);      this->prefetch_[i].labeldif_.Reshape(prefetch_label_shape);    }    //256(256)    LOG(INFO) << "output label size : " << top[1]->shape_string();  }}(2// This function is called on prefetch threadtemplate<typename Dtype>void ReidDataLayer<Dtype>::load_batch(ReidBatch  <Dtype>* batch) {  CPUTimer batch_timer;  batch_timer.Start();  double read_time = 0;  double trans_time = 0;  CPUTimer timer;  CHECK(batch->data_.count());  CHECK(this->transformed_data_.count());  const int batch_size = this->layer_param_.reid_data_param().batch_size();  //完全随机的取值吗?  
//一组长度为batch_size的图像ID  const vector<size_t> batches = this->batch_ids();  //一组与batches对应类别的的长度为batch_size的图像ID  const vector<size_t> batches_pair = this->batch_pairs(batches);  CHECK_EQ(batches.size(), batch_size);  CHECK_EQ(batches_pair.size(), batch_size);  // Reshape according to the first image of each batch  // on single input batches allows for inputs of varying dimension.  cv::Mat cv_img = this->cv_imgs_[batches[0]];  CHECK(cv_img.data) << "Could not load " << this->lines_[batches[0]].first;  // Use data_transformer to infer the expected blob shape from a cv_img.  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  this->transformed_data_.Reshape(top_shape);  // Reshape batch according to the batch_size.  top_shape[0] = batch_size;  //ReidBatch<Dtype>* batch  batch->data_.Reshape(top_shape);  batch->datap_.Reshape(top_shape);  Dtype* prefetch_data = batch->data_.mutable_cpu_data();  Dtype* prefetch_datap = batch->datap_.mutable_cpu_data();  Dtype* prefetch_label = batch->label_.mutable_cpu_data();  Dtype* prefetch_labelp = batch->labelp_.mutable_cpu_data();  Dtype* prefetch_labeldif = batch->labeldif_.mutable_cpu_data();  for (int item_id = 0; item_id < batch_size; ++item_id) {    // get a blob    timer.Start();    //两张图ID    const size_t true_idx = batches[item_id];    const size_t pair_idx = batches_pair[item_id];    //两张图    cv::Mat cv_img_true = this->cv_imgs_[ true_idx ];    cv::Mat cv_img_pair = this->cv_imgs_[ pair_idx ];    CHECK(cv_img_true.data) << "Could not load " << this->lines_[true_idx].first;    CHECK(cv_img_pair.data) << "Could not load " << this->lines_[pair_idx].first;    read_time += timer.MicroSeconds();    timer.Start();    // Apply transformations (mirror, crop...) 
to the image    const int t_offset = batch->data_.offset(item_id);    this->transformed_data_.set_cpu_data(prefetch_data + t_offset);    this->data_transformer_->Transform(cv_img_true, &(this->transformed_data_));    // Pair Data    const int p_offset = batch->datap_.offset(item_id);    this->transformed_data_.set_cpu_data(prefetch_datap + p_offset);    this->data_transformer_->Transform(cv_img_pair, &(this->transformed_data_));    trans_time += timer.MicroSeconds();    CHECK_GE(lines_[true_idx].second, 0);    CHECK_GE(lines_[pair_idx].second, 0);    CHECK_LT(lines_[true_idx].second, this->label_set.size());    CHECK_LT(lines_[pair_idx].second, this->label_set.size());    prefetch_label[item_id]    = lines_[true_idx].second;    prefetch_labelp[item_id]   = lines_[pair_idx].second;    //labeldif变成后文最大的悬念之一    prefetch_labeldif[item_id] = lines_[true_idx].second == lines_[pair_idx].second;    DLOG(INFO) << "Idx : " << item_id << " : " << lines_[true_idx].second << " vs " << lines_[pair_idx].second << " ..=.. " << prefetch_labeldif[item_id];  }  batch_timer.Stop();  DLOG(INFO) << "Pair Idx : (" << batches[0] << "," << batches_pair[0] << ")";  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";}INSTANTIATE_CLASS(ReidDataLayer);REGISTER_LAYER_CLASS(ReidData);}  // namespace caffe
0 0
原创粉丝点击