SSD:Single Shot MultiBox Detector代码解读(二)

来源:互联网 发布:世界贸易组织数据库 编辑:程序博客网 时间:2024/06/03 11:51

SSD:Single Shot MultiBox Detector代码解读(一): http://blog.csdn.net/u011956147/article/details/73028773
SSD:Single Shot MultiBox Detector代码解读(二): http://blog.csdn.net/u011956147/article/details/73030116
SSD:Single Shot MultiBox Detector代码解读(三): http://blog.csdn.net/u011956147/article/details/73032867
SSD:Single Shot MultiBox Detector代码解读(四): http://blog.csdn.net/u011956147/article/details/73033170
SSD:Single Shot MultiBox Detector代码解读(五): http://blog.csdn.net/u011956147/article/details/73033282


看网上的说法,SSD代码有过更新,在这里,我采用的是最新版本的(2017/6/7)
主要粗略的分析下caffe版本的SSD代码,还有很多细节没有去仔细研究,希望辩证地看待,如果有什么问题和不同的见解可以提出来,大家一起进步,当然SSD还有比如TensorFlow版本的,网上也有教程,这里就不再说明了。

原作者的代码分散在很多地方,主要是include/caffe/layer,src/caffe/layer和src/caffe/util/目录下面。包括annotated_data_layer.hpp、detection_evaluate_layer.hpp、detection_output_layer.hpp、multibox_loss_layer.hpp、prior_box_layer.hpp和与之对应的.cpp文件和.cu文件,这里也只是分析.cpp文件。

跟训练有关的是annotated_data_layer、multibox_loss_layer、prior_box_layer以及 bbox_util。detection_evaluate_layer是验证模型效果用的、detection_output_layer是输出检测结果用的,在后续再继续补充。

首先,annotated_data_layer:主要就是把图片读进来,同时做一些数据增广(data augmentation),同时把gt box取出来保存。
代码如下(其中有相应注释):

#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV
#include <stdint.h>
#include <algorithm>
#include <map>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/layers/annotated_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/sampler.hpp"

namespace caffe {

// Constructor: forwards the layer parameter to the prefetching base class and
// to the DB reader that streams AnnotatedDatum records from the LMDB/LevelDB.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::AnnotatedDataLayer(const LayerParameter& param)
  : BasePrefetchingDataLayer<Dtype>(param),
    reader_(param) {
}

// Destructor: stops the background prefetch thread before members are torn down.
template <typename Dtype>
AnnotatedDataLayer<Dtype>::~AnnotatedDataLayer() {
  this->StopInternalThread();
}

// One-time setup: reads the layer's parameters, peeks at the first datum to
// size top[0] (data) and, when labels are emitted, sizes top[1] (annotations).
template <typename Dtype>
void AnnotatedDataLayer<Dtype>::DataLayerSetUp(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  // Batch size comes from data_param in the net prototxt.
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();
  // Collect the batch samplers used for data-augmentation cropping; each
  // sampler describes one way of drawing patch candidates from an image.
  for (int i = 0; i < anno_data_param.batch_sampler_size(); ++i) {
    batch_samplers_.push_back(anno_data_param.batch_sampler(i));
  }
  // Path of the label-map file (class name <-> label id mapping).
  label_map_file_ = anno_data_param.label_map_file();
  // Make sure dimension is consistent within batch.
  const TransformationParameter& transform_param =
    this->layer_param_.transform_param();
  if (transform_param.has_resize_param()) {
    if (transform_param.resize_param().resize_mode() ==
        ResizeParameter_Resize_mode_FIT_SMALL_SIZE) {
      // FIT_SMALL_SIZE lets each image keep its own aspect ratio, so shapes
      // can differ per item; that only works with a batch of one.
      CHECK_EQ(batch_size, 1)
        << "Only support batch size of 1 for FIT_SMALL_SIZE.";
    }
  }
  // Read a data point, and use it to initialize the top blob.
  // peek() inspects the first queued datum without consuming it.
  AnnotatedDatum& anno_datum = *(reader_.full().peek());
  // Use data_transformer to infer the expected blob shape from anno_datum.
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(anno_datum.datum());
  this->transformed_data_.Reshape(top_shape);
  // Reshape top[0] and prefetch_data according to the batch_size.
  top_shape[0] = batch_size;  // N = batch_size
  top[0]->Reshape(top_shape);
  // Pre-size every prefetch buffer so the background thread can fill them.
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].data_.Reshape(top_shape);
  }
  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label: shape top[1] according to the annotation type.
  if (this->output_labels_) {
    // "Rich" annotations are present if either the datum or the layer
    // parameter declares an annotation type.
    has_anno_type_ = anno_datum.has_type() || anno_data_param.has_anno_type();
    vector<int> label_shape(4, 1);
    if (has_anno_type_) {
      anno_type_ = anno_datum.type();
      if (anno_data_param.has_anno_type()) {
        // If anno_type is provided in AnnotatedDataParameter, replace
        // the type stored in each individual AnnotatedDatum.
        LOG(WARNING) << "type stored in AnnotatedDatum is shadowed.";
        anno_type_ = anno_data_param.anno_type();
      }
      // Infer the label shape from anno_datum.AnnotationGroup().
      int num_bboxes = 0;
      if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
        // Since the number of bboxes can be different for each image,
        // we store the bbox information in a specific format. In specific:
        // All bboxes are stored in one spatial plane (num and channels are 1)
        // And each row contains one and only one box in the following format:
        // [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]
        // Note: Refer to caffe.proto for details about group_label and
        // instance_id.
        for (int g = 0; g < anno_datum.annotation_group_size(); ++g) {
          num_bboxes += anno_datum.annotation_group(g).annotation_size();
        }
        label_shape[0] = 1;
        label_shape[1] = 1;
        // BasePrefetchingDataLayer<Dtype>::LayerSetUp() requires to call
        // cpu_data and gpu_data for consistent prefetch thread. Thus we make
        // sure there is at least one bbox.
        label_shape[2] = std::max(num_bboxes, 1);
        label_shape[3] = 8;
      } else {
        LOG(FATAL) << "Unknown annotation type.";
      }
    } else {
      // Plain classification labels: one scalar label per item.
      label_shape[0] = batch_size;
    }
    top[1]->Reshape(label_shape);
    for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
      this->prefetch_[i].label_.Reshape(label_shape);
    }
  }
}

// This function is called on prefetch thread.
// Pops batch_size datums from the reader, applies the augmentation pipeline
// (distort -> expand -> sampled crop -> transform), writes the image data into
// batch->data_ and, for BBOX annotations, packs all ground-truth boxes of the
// whole batch into batch->label_ as rows of
// [item_id, group_label, instance_id, xmin, ymin, xmax, ymax, difficult].
template<typename Dtype>
void AnnotatedDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Reshape according to the first anno_datum of each batch
  // on single input batches allows for inputs of varying dimension.
  const int batch_size = this->layer_param_.data_param().batch_size();
  const AnnotatedDataParameter& anno_data_param =
      this->layer_param_.annotated_data_param();
  const TransformationParameter& transform_param =
    this->layer_param_.transform_param();
  AnnotatedDatum& anno_datum = *(reader_.full().peek());
  // Use data_transformer to infer the expected blob shape from anno_datum.
  vector<int> top_shape =
      this->data_transformer_->InferBlobShape(anno_datum.datum());
  this->transformed_data_.Reshape(top_shape);
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);

  Dtype* top_data = batch->data_.mutable_cpu_data();
  Dtype* top_label = NULL;  // suppress warnings about uninitialized variables
  if (this->output_labels_ && !has_anno_type_) {
    top_label = batch->label_.mutable_cpu_data();
  }

  // Store transformed annotation, keyed by the item's index within the batch.
  map<int, vector<AnnotationGroup> > all_anno;
  int num_bboxes = 0;

  for (int item_id = 0; item_id < batch_size; ++item_id) {
    timer.Start();
    // get a anno_datum (blocking pop from the reader's full queue)
    AnnotatedDatum& anno_datum = *(reader_.full().pop("Waiting for data"));
    read_time += timer.MicroSeconds();
    timer.Start();
    AnnotatedDatum distort_datum;
    AnnotatedDatum* expand_datum = NULL;
    if (transform_param.has_distort_param()) {
      // Photometric distortion (color/brightness/etc.) into distort_datum.
      distort_datum.CopyFrom(anno_datum);
      this->data_transformer_->DistortImage(anno_datum.datum(),
                                            distort_datum.mutable_datum());
      if (transform_param.has_expand_param()) {
        // Expansion ("zoom out" with filler) allocates a new datum; it is
        // freed below only when has_expand_param() is set.
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(distort_datum, expand_datum);
      } else {
        expand_datum = &distort_datum;
      }
    } else {
      if (transform_param.has_expand_param()) {
        expand_datum = new AnnotatedDatum();
        this->data_transformer_->ExpandImage(anno_datum, expand_datum);
      } else {
        expand_datum = &anno_datum;
      }
    }
    AnnotatedDatum* sampled_datum = NULL;
    bool has_sampled = false;
    if (batch_samplers_.size() > 0) {
      // Generate sampled bboxes from expand_datum.
      vector<NormalizedBBox> sampled_bboxes;
      GenerateBatchSamples(*expand_datum, batch_samplers_, &sampled_bboxes);
      if (sampled_bboxes.size() > 0) {
        // Randomly pick a sampled bbox and crop the expand_datum.
        int rand_idx = caffe_rng_rand() % sampled_bboxes.size();
        sampled_datum = new AnnotatedDatum();
        this->data_transformer_->CropImage(*expand_datum,
                                           sampled_bboxes[rand_idx],
                                           sampled_datum);
        has_sampled = true;
      } else {
        sampled_datum = expand_datum;
      }
    } else {
      sampled_datum = expand_datum;
    }
    CHECK(sampled_datum != NULL);
    timer.Start();
    vector<int> shape =
        this->data_transformer_->InferBlobShape(sampled_datum->datum());
    if (transform_param.has_resize_param()) {
      if (transform_param.resize_param().resize_mode() ==
          ResizeParameter_Resize_mode_FIT_SMALL_SIZE) {
        // Per-item shapes are allowed here (batch size is 1; see setup).
        this->transformed_data_.Reshape(shape);
        batch->data_.Reshape(shape);
        top_data = batch->data_.mutable_cpu_data();
      } else {
        // Otherwise every item must match the (C,H,W) inferred at the top.
        CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
              shape.begin() + 1));
      }
    } else {
      CHECK(std::equal(top_shape.begin() + 1, top_shape.begin() + 4,
            shape.begin() + 1));
    }
    // Apply data transformations (mirror, scale, crop...), writing directly
    // into this item's slice of the batch blob.
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(top_data + offset);
    vector<AnnotationGroup> transformed_anno_vec;
    if (this->output_labels_) {
      if (has_anno_type_) {
        // Make sure all data have same annotation type.
        CHECK(sampled_datum->has_type()) << "Some datum misses AnnotationType.";
        if (anno_data_param.has_anno_type()) {
          sampled_datum->set_type(anno_type_);
        } else {
          CHECK_EQ(anno_type_, sampled_datum->type()) <<
              "Different AnnotationType.";
        }
        // Transform datum and annotation_group at the same time, so the boxes
        // stay consistent with the geometric transforms applied to the image.
        transformed_anno_vec.clear();
        this->data_transformer_->Transform(*sampled_datum,
                                           &(this->transformed_data_),
                                           &transformed_anno_vec);
        if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
          // Count the number of bboxes.
          for (int g = 0; g < transformed_anno_vec.size(); ++g) {
            num_bboxes += transformed_anno_vec[g].annotation_size();
          }
        } else {
          LOG(FATAL) << "Unknown annotation type.";
        }
        all_anno[item_id] = transformed_anno_vec;
      } else {
        this->data_transformer_->Transform(sampled_datum->datum(),
                                           &(this->transformed_data_));
        // Otherwise, store the label from datum.
        CHECK(sampled_datum->datum().has_label()) << "Cannot find any label.";
        top_label[item_id] = sampled_datum->datum().label();
      }
    } else {
      this->data_transformer_->Transform(sampled_datum->datum(),
                                         &(this->transformed_data_));
    }
    // clear memory: only the datums allocated with new above are deleted;
    // the stack/queue-owned aliases are left alone.
    if (has_sampled) {
      delete sampled_datum;
    }
    if (transform_param.has_expand_param()) {
      delete expand_datum;
    }
    trans_time += timer.MicroSeconds();

    // Return the consumed datum to the reader's free queue for reuse.
    reader_.free().push(const_cast<AnnotatedDatum*>(&anno_datum));
  }

  // Store "rich" annotation if needed.
  if (this->output_labels_ && has_anno_type_) {
    vector<int> label_shape(4);
    if (anno_type_ == AnnotatedDatum_AnnotationType_BBOX) {
      label_shape[0] = 1;
      label_shape[1] = 1;
      label_shape[3] = 8;
      if (num_bboxes == 0) {
        // Store all -1 in the label: a dummy row so the blob is never empty.
        label_shape[2] = 1;
        batch->label_.Reshape(label_shape);
        caffe_set<Dtype>(8, -1, batch->label_.mutable_cpu_data());
      } else {
        // Reshape the label and store the annotation: one 8-value row per box,
        // all items of the batch concatenated along the height dimension.
        label_shape[2] = num_bboxes;
        batch->label_.Reshape(label_shape);
        top_label = batch->label_.mutable_cpu_data();
        int idx = 0;
        for (int item_id = 0; item_id < batch_size; ++item_id) {
          const vector<AnnotationGroup>& anno_vec = all_anno[item_id];
          for (int g = 0; g < anno_vec.size(); ++g) {
            const AnnotationGroup& anno_group = anno_vec[g];
            for (int a = 0; a < anno_group.annotation_size(); ++a) {
              const Annotation& anno = anno_group.annotation(a);
              const NormalizedBBox& bbox = anno.bbox();
              top_label[idx++] = item_id;
              top_label[idx++] = anno_group.group_label();
              top_label[idx++] = anno.instance_id();
              top_label[idx++] = bbox.xmin();
              top_label[idx++] = bbox.ymin();
              top_label[idx++] = bbox.xmax();
              top_label[idx++] = bbox.ymax();
              top_label[idx++] = bbox.difficult();
            }
          }
        }
      }
    } else {
      LOG(FATAL) << "Unknown annotation type.";
    }
  }
  timer.Stop();
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(AnnotatedDataLayer);
REGISTER_LAYER_CLASS(AnnotatedData);

}  // namespace caffe

本文链接:http://blog.csdn.net/u011956147/article/details/73030116

阅读全文
0 1