Walking through the Caffe code: data_reader (Part 11)
Source: Internet | Editor: 程序博客网 | Date: 2024/04/29 04:43
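So what exactly did the blocking_queue from the last post do? I'd half-forgotten it right after reading it, and give it a little while longer and it'll be gone completely...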
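As the name says, it's a blocking queue. Picture a line of people waiting at a canteen window: whoever reaches the window first gets served first. Why is that both efficient and safe? First, like traffic, it keeps things in order; second, once there's order, everything moves faster. Concretely, every push and pop happens under a mutex, so several threads can share the queue without corrupting it, and a pop on an empty queue sleeps on a condition variable until a push wakes it, so a consumer never busy-waits for data.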
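To make the canteen analogy concrete, here is a minimal sketch of a FIFO blocking queue. It is not Caffe's BlockingQueue (that class in caffe/util/blocking_queue.hpp is built on boost primitives); this version swaps in std::mutex and std::condition_variable, and the names SimpleBlockingQueue and run_fifo_demo are hypothetical. It shows the two properties the analogy is pointing at: items leave in arrival order, and pop() sleeps while the queue is empty instead of spinning.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical minimal FIFO blocking queue in the spirit of caffe::BlockingQueue.
// Assumption: std primitives stand in for the boost ones Caffe uses.
template <typename T>
class SimpleBlockingQueue {
 public:
  void push(const T& item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push(item);
    }
    cond_.notify_one();  // wake one consumer sleeping in pop()
  }

  // Blocks until an item is available, then removes and returns the oldest one.
  T pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return !queue_.empty(); });
    T item = queue_.front();
    queue_.pop();
    return item;
  }

  // Non-blocking variant: returns false immediately if the queue is empty.
  bool try_pop(T* item) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (queue_.empty()) return false;
    *item = queue_.front();
    queue_.pop();
    return true;
  }

  std::size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

 private:
  mutable std::mutex mutex_;  // guards queue_
  std::condition_variable cond_;
  std::queue<T> queue_;
};

// Hypothetical demo: a producer thread pushes 0..n-1 while the caller pops.
// The caller blocks whenever it gets ahead of the producer.
std::vector<int> run_fifo_demo(int n) {
  SimpleBlockingQueue<int> q;
  std::thread producer([&q, n] {
    for (int i = 0; i < n; ++i) q.push(i);
  });
  std::vector<int> out;
  for (int i = 0; i < n; ++i) out.push_back(q.pop());
  producer.join();
  assert(out.size() == static_cast<std::size_t>(n));
  return out;
}
```

Calling run_fifo_demo(5) returns 0, 1, 2, 3, 4 in order even though the values are produced on another thread. try_pop mirrors the non-blocking call that QueuePair's destructor uses to drain its queues.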
So let's see how the data actually gets queued.
Header file:
```cpp
#ifndef CAFFE_DATA_READER_HPP_
#define CAFFE_DATA_READER_HPP_

#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/util/blocking_queue.hpp"
#include "caffe/util/db.hpp"

namespace caffe {

/**
 * @brief Reads data from a source to queues available to data layers.
 * A single reading thread is created per source, even if multiple solvers
 * are running in parallel, e.g. for multi-GPU training. This makes sure
 * databases are read sequentially, and that each solver accesses a different
 * subset of the database. Data is distributed to solvers in a round-robin
 * way to keep parallel training deterministic.
 */
class DataReader {
 public:
  explicit DataReader(const LayerParameter& param);
  ~DataReader();

  inline BlockingQueue<Datum*>& free() const {
    return queue_pair_->free_;
  }
  inline BlockingQueue<Datum*>& full() const {
    return queue_pair_->full_;
  }

 protected:
  // Queue pairs are shared between a body and its readers
  class QueuePair {
   public:
    explicit QueuePair(int size);
    ~QueuePair();

    // Two blocking queues: free_ and full_
    BlockingQueue<Datum*> free_;
    BlockingQueue<Datum*> full_;

  DISABLE_COPY_AND_ASSIGN(QueuePair);
  };

  // A single body is created per source
  // Body inherits from InternalThread
  class Body : public InternalThread {
   public:
    explicit Body(const LayerParameter& param);
    virtual ~Body();

   protected:
    // Overrides InternalThread's InternalThreadEntry() and adds read_one()
    void InternalThreadEntry();
    void read_one(db::Cursor* cursor, QueuePair* qp);

    const LayerParameter param_;
    BlockingQueue<shared_ptr<QueuePair> > new_queue_pairs_;

    // DataReader is declared a friend of Body
    friend class DataReader;

  DISABLE_COPY_AND_ASSIGN(Body);
  };

  // A source is uniquely identified by its layer name + path, in case
  // the same database is read from two different locations in the net.
  static inline string source_key(const LayerParameter& param) {
    return param.name() + ":" + param.data_param().source();
  }

  const shared_ptr<QueuePair> queue_pair_;
  shared_ptr<Body> body_;

  static map<const string, boost::weak_ptr<DataReader::Body> > bodies_;

DISABLE_COPY_AND_ASSIGN(DataReader);
};

}  // namespace caffe

#endif  // CAFFE_DATA_READER_HPP_
```

Implementation:
```cpp
#include <boost/thread.hpp>
#include <map>
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/layers/data_layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

using boost::weak_ptr;

map<const string, weak_ptr<DataReader::Body> > DataReader::bodies_;
static boost::mutex bodies_mutex_;

DataReader::DataReader(const LayerParameter& param)
    : queue_pair_(new QueuePair(
        param.data_param().prefetch() * param.data_param().batch_size())) {
  // Get or create a body
  boost::mutex::scoped_lock lock(bodies_mutex_);
  string key = source_key(param);
  weak_ptr<Body>& weak = bodies_[key];
  body_ = weak.lock();
  if (!body_) {
    body_.reset(new Body(param));
    bodies_[key] = weak_ptr<Body>(body_);
  }
  body_->new_queue_pairs_.push(queue_pair_);
}

DataReader::~DataReader() {
  string key = source_key(body_->param_);
  body_.reset();
  boost::mutex::scoped_lock lock(bodies_mutex_);
  if (bodies_[key].expired()) {
    bodies_.erase(key);
  }
}

// Fill free_ with `size` newly allocated Datum instances
DataReader::QueuePair::QueuePair(int size) {
  // Initialize the free queue with requested number of datums
  for (int i = 0; i < size; ++i) {
    free_.push(new Datum());
  }
}

// Delete every Datum still held in free_ or full_
DataReader::QueuePair::~QueuePair() {
  Datum* datum;
  while (free_.try_pop(&datum)) {
    delete datum;
  }
  while (full_.try_pop(&datum)) {
    delete datum;
  }
}

// Body constructor: stores the layer parameters, then starts the internal thread
DataReader::Body::Body(const LayerParameter& param)
    : param_(param),
      new_queue_pairs_() {
  // InternalThread sets up the runtime context and spawns a new thread
  // that runs the virtual function InternalThreadEntry()
  StartInternalThread();
}

// Destructor: stop the thread
DataReader::Body::~Body() {
  StopInternalThread();
}

// What the reading thread runs: open the database, create a cursor,
// then collect the QueuePair pointers it has to serve
void DataReader::Body::InternalThreadEntry() {
  // Get a DB instance for the configured backend type
  shared_ptr<db::DB> db(db::GetDB(param_.data_param().backend()));
  // Open the DB at the source path given in the layer parameters
  db->Open(param_.data_param().source(), db::READ);
  // Create a cursor
  shared_ptr<db::Cursor> cursor(db->NewCursor());
  // QueuePair pointers; each QueuePair holds the free_ and full_ blocking queues
  vector<shared_ptr<QueuePair> > qps;
  try {
    // One queue pair per solver in the TRAIN phase, otherwise just one
    int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1;

    // To ensure deterministic runs, only start running once all solvers
    // are ready. But solvers need to peek on one item during initialization,
    // so read one item, then wait for the next solver.
    for (int i = 0; i < solver_count; ++i) {
      shared_ptr<QueuePair> qp(new_queue_pairs_.pop());
      read_one(cursor.get(), qp.get());  // read one item
      qps.push_back(qp);                 // and keep the queue pair
    }
    // Main loop
    while (!must_stop()) {
      for (int i = 0; i < solver_count; ++i) {
        read_one(cursor.get(), qps[i].get());
      }
      // Check no additional readers have been created. This can happen if
      // more than one net is trained at a time per process, whether single
      // or multi solver. It might also happen if two data layers have same
      // name and same source.
      CHECK_EQ(new_queue_pairs_.size(), 0);
    }
  } catch (boost::thread_interrupted&) {
    // Interrupted exception is expected on shutdown
  }
}

// Read one record from the database
void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) {
  // Pop an empty Datum from the free_ queue
  Datum* datum = qp->free_.pop();
  // TODO deserialize in-place instead of copy?
  // Parse the cursor's current value into it
  datum->ParseFromString(cursor->value());
  // Push it onto the full_ queue
  qp->full_.push(datum);

  // go to the next iter
  cursor->Next();
  if (!cursor->valid()) {
    DLOG(INFO) << "Restarting data prefetching from start.";
    cursor->SeekToFirst();  // past the end: wrap around to the first record
  }
}

}  // namespace caffe
```

So the data reader simply reads records through Caffe's DB wrapper layer (caffe/util/db.hpp), runs the reading in a thin wrapper around boost's thread library (InternalThread), and hands the records to the data layers through Caffe's own blocking queue.
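The reading scheme above (a free_/full_ pair per solver, a round-robin main loop, and a cursor that wraps around) can be mimicked in plain C++11 to see why the split between solvers is deterministic. This is a sketch under stated assumptions, not Caffe code: std::string buffers stand in for Datum, a std::vector<std::string> stands in for the database, and a size_t index stands in for db::Cursor. BQueue, QueuePairSketch, read_one (a free function modeled on Body::read_one) and simulate are hypothetical names. Startup is also simplified: all queue pairs exist before reading begins, and the reader stops after a fixed number of rounds instead of looping until must_stop().

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

// Minimal blocking queue (push/pop only), standing in for caffe::BlockingQueue.
template <typename T>
class BQueue {
 public:
  void push(const T& item) {
    { std::lock_guard<std::mutex> lock(m_); q_.push(item); }
    cv_.notify_one();
  }
  T pop() {  // blocks while the queue is empty
    std::unique_lock<std::mutex> lock(m_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    T item = q_.front();
    q_.pop();
    return item;
  }
 private:
  std::mutex m_;
  std::condition_variable cv_;
  std::queue<T> q_;
};

// Hypothetical analogue of DataReader::QueuePair: `size` buffers start in free_.
struct QueuePairSketch {
  explicit QueuePairSketch(std::size_t size) : storage(size) {
    for (auto& buf : storage) free_.push(&buf);
  }
  std::vector<std::string> storage;  // owns the buffers (the role of Datum)
  BQueue<std::string*> free_;        // empty buffers waiting to be filled
  BQueue<std::string*> full_;        // filled buffers waiting for a solver
};

// Hypothetical analogue of Body::read_one: take a free buffer, fill it from the
// current record, hand it to full_, then advance the cursor, wrapping at the end.
void read_one(const std::vector<std::string>& db, std::size_t* cursor,
              QueuePairSketch* qp) {
  std::string* buf = qp->free_.pop();
  *buf = db[*cursor];  // Caffe parses a serialized Datum here instead
  qp->full_.push(buf);
  if (++*cursor == db.size()) *cursor = 0;  // like cursor->SeekToFirst()
}

// Hypothetical driver: one reader thread serves `solvers` consumer threads
// round-robin for `rounds` rounds; returns what each consumer received.
std::vector<std::vector<std::string>> simulate(
    const std::vector<std::string>& db, int solvers, int rounds,
    std::size_t prefetch) {
  assert(!db.empty() && prefetch > 0);
  std::vector<std::unique_ptr<QueuePairSketch>> qps;
  for (int i = 0; i < solvers; ++i)
    qps.push_back(std::unique_ptr<QueuePairSketch>(new QueuePairSketch(prefetch)));

  std::thread reader([&] {
    std::size_t cursor = 0;
    for (int r = 0; r < rounds; ++r)
      for (int i = 0; i < solvers; ++i) read_one(db, &cursor, qps[i].get());
  });

  std::vector<std::vector<std::string>> seen(solvers);
  std::vector<std::thread> consumers;
  for (int i = 0; i < solvers; ++i) {
    consumers.emplace_back([&, i] {
      for (int r = 0; r < rounds; ++r) {
        std::string* buf = qps[i]->full_.pop();
        seen[i].push_back(*buf);
        qps[i]->free_.push(buf);  // recycle the buffer for the reader
      }
    });
  }
  reader.join();
  for (auto& t : consumers) t.join();
  return seen;
}
```

With db = {"a", "b", "c"}, two solvers and three rounds, the reader hands out a, b, c, a, b, c in that order, so solver 0 always receives a, c, b and solver 1 always receives b, a, c, whatever the thread timing: only the single reader thread decides which record goes where, and each full_ queue is FIFO. The prefetch argument plays the role of prefetch() * batch_size(): once a solver's free_ queue is empty, the reader blocks until that solver recycles a buffer.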