梳理caffe代码data_reader(十一)

来源:互联网 发布:适马调焦器 知乎 编辑:程序博客网 时间:2024/04/29 04:43

上一篇的blocking_queue到底干了一件什么事情呢?刚刚看完就有点忘记了,再过一会估计忘光了。。。

顾名思义,阻塞队列,就是一个正在排队的打饭队列,先到窗口的先打饭,为什么会高效安全呢?一是像交通有秩序,二是有了秩序是不是交通运行起来就快了。

我们就看看数据是怎么进行排队的?

头文件:

#ifndef CAFFE_DATA_READER_HPP_#define CAFFE_DATA_READER_HPP_#include <map>#include <string>#include <vector>#include "caffe/common.hpp"#include "caffe/internal_thread.hpp"#include "caffe/util/blocking_queue.hpp"#include "caffe/util/db.hpp"namespace caffe {/** * @brief Reads data from a source to queues available to data layers. * A single reading thread is created per source, even if multiple solvers * are running in parallel, e.g. for multi-GPU training. This makes sure * databases are read sequentially, and that each solver accesses a different * subset of the database. Data is distributed to solvers in a round-robin * way to keep parallel training deterministic. *//*从共享的资源读取数据然后排队输入到数据层,每个资源创建单个线程,即便是使用多个GPU在并行任务中求解。这就保证对于频繁读取数据库,并且每个求解的线程使用的子数据是不同的。数据成功设计就是这样使在求解时数据保持一种循环地并行训练。*/class DataReader { public:  explicit DataReader(const LayerParameter& param);  ~DataReader();//  inline BlockingQueue<Datum*>& free() const {    return queue_pair_->free_;  }  inline BlockingQueue<Datum*>& full() const {    return queue_pair_->full_;  } protected:  // Queue pairs are shared between a body and its readers  class QueuePair {   public:    explicit QueuePair(int size);    ~QueuePair();//定义了两个阻塞队列free_和full_    BlockingQueue<Datum*> free_;    BlockingQueue<Datum*> full_;  DISABLE_COPY_AND_ASSIGN(QueuePair);  };  // A single body is created per source//继承InternalThread 这个类的  class Body : public InternalThread {   public:    explicit Body(const LayerParameter& param);    virtual ~Body();   protected://重写了InternalThread内部的InternalThreadEntry函数,此外还添加了read_one函数    void InternalThreadEntry();    void read_one(db::Cursor* cursor, QueuePair* qp);    const LayerParameter param_;    BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;//内部有DataReader的友元    friend class DataReader;  DISABLE_COPY_AND_ASSIGN(Body);  };  // A source is uniquely identified by its layer name + path, in case  // the same database is read from two different locations in the net.  static inline string source_key(const LayerParameter& param) {    return param.name() + ":" + param.data_param().source();  }  const shared_ptr<QueuePair> queue_pair_;  shared_ptr<Body> body_;  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;DISABLE_COPY_AND_ASSIGN(DataReader);};}  // namespace caffe#endif  // CAFFE_DATA_READER_HPP_
实现部分:
#include <boost/thread.hpp>#include <map>#include <string>#include <vector>#include "caffe/common.hpp"#include "caffe/data_reader.hpp"#include "caffe/layers/data_layer.hpp"#include "caffe/proto/caffe.pb.h"namespace caffe {using boost::weak_ptr;map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;static boost::mutex bodies_mutex_;DataReader::DataReader(const LayerParameter& param)    : queue_pair_(new QueuePair(  //        param.data_param().prefetch() * param.data_param().batch_size())) {  // Get or create a body  boost::mutex::scoped_lock lock(bodies_mutex_);  string key = source_key(param);  weak_ptr<Body>& weak = bodies_[key];  body_ = weak.lock();  if (!body_) {    body_.reset(new Body(param));    bodies_[key] = weak_ptr<Body>(body_);  }  body_->new_queue_pairs_.push(queue_pair_);}DataReader::~DataReader() {  string key = source_key(body_->param_);  body_.reset();  boost::mutex::scoped_lock lock(bodies_mutex_);  if (bodies_[key].expired()) {    bodies_.erase(key);  }}//根据给定的size初始化的若干个Datum的实例到free里面DataReader::QueuePair::QueuePair(int size) {  // Initialize the free queue with requested number of datums  for (int i = 0; i < size; ++i) {    free_.push(new Datum());  }}//将full_和free_这两个队列里面的Datum对象全部delete。DataReader::QueuePair::~QueuePair() {  Datum* datum;  while (free_.try_pop(&datum)) {    delete datum;  }  while (full_.try_pop(&datum)) {    delete datum;  }}//Body类的构造函数,实际上是给定网络的参数,然后开始启动内部线程DataReader::Body::Body(const LayerParameter& param)    : param_(param),      new_queue_pairs_() {  StartInternalThread();// 调用InternalThread内部的函数来初始化运行环境以及新建线程去执行虚函数InternalThreadEntry的内容}// 析构,停止线程DataReader::Body::~Body() {  StopInternalThread();}// 自己实现的需要执行的函数// 首先打开数据库,然后设置游标,然后设置QueuePair指针容器void DataReader::Body::InternalThreadEntry() {  // 获取所给定的数据源的类型来得到DB的指针  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));  // 从网络参数中给定的DB的位置打开DB  db->Open(param_.data_param().source(), db::READ);  // 新建游标指针  shared_ptr<db::Cursor> cursor(db->NewCursor());  // 新建QueuePair指针容器,QueuePair里面包含了free_和full_这两个阻塞队列  vector<shared_ptr<QueuePair> > qps;  try {    // 根据网络参数的阶段来设置solver_count    int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;    // To ensure deterministic runs, only start running once all solvers    // are ready. But solvers need to peek on one item during initialization,    // so read one item, then wait for the next solver.    for (int i = 0; i < solver_count; ++i) {      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());      read_one(cursor.get(), qp.get());// 读取一个数据      qps.push_back(qp);压入    }    // Main loop    while (!must_stop()) {      for (int i = 0; i < solver_count; ++i) {        read_one(cursor.get(), qps[i].get());      }      // Check no additional readers have been created. This can happen if      // more than one net is trained at a time per process, whether single      // or multi solver. It might also happen if two data layers have same      // name and same source.      CHECK_EQ(new_queue_pairs_.size(), 0);    }  } catch (boost::thread_interrupted&) {    // Interrupted exception is expected on shutdown  }}// 从数据库中获取一个数据void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {  // 从QueuePair中的free_队列pop出一个  Datum* datum = qp->free_.pop();  // TODO deserialize in-place instead of copy?  // 然后解析cursor中的值  datum->ParseFromString(cursor->value());  // 然后压入QueuePair中的full_队列  qp->full_.push(datum);  // go to the next iter  // 游标指向下一个  cursor->Next();  if (!cursor->valid()) {    DLOG(INFO) << "Restarting data prefetching from start.";    cursor->SeekToFirst();// 如果游标指向的位置已经无效了则指向第一个位置  }}}  // namespace caffe
数据层就是调用了封装层的DB来读取数据,此外还简单封装了boost的线程库,然后自己封装了个阻塞队列。

0 0
原创粉丝点击