Caffe2源码理解系列之IO
来源:互联网 发布:道路照明计算软件 编辑:程序博客网 时间:2024/06/17 00:56
Caffe2 IO
本文主要记录下我对Caffe2的输入输出部分源代码的理解。数据是以什么样的形式输入进网络的,训练过程中如何保存网络模型。与数据输入相关的Operator是DBReader, ImageInputOp, 与存储训练过程中保存模型相关信息的是SaveOp, LoadOp,以及一系列与序列化相关的工具类,比如BlobSerializer。下面分别介绍一下,如有理解错误,欢迎指出。PS,Caffe2的代码写得真心赞啊。
- DBReader
- ImageInputOp
- SaveOp
- LoadOp
- 总结
DBReader
如同Caffe1一样,一般情况下,在进行模型训练的时候,Caffe2也需要事先将数据转成特定格式的数据库,比如lmdb, leveldb。只不过Caffe2支持的数据库格式更加丰富,除了上述两种格式的db外,还有minidb, zmqdb, protodb, rocksdb等等。Caffe2中对lmdb的实现跟Caffe1有所不同,但功能是一样的。PS,个人以为Caffe1中的实现要优雅些,因为我直接在windows上用Caffe2自带的lmdb.cc来生成数据库时运行不通过,直接改成Caffe1中的就OK了。另外由于Caffe2在默认保存模型时候使用的是minidb, 所以简单地介绍下minidb。
DBReader封装了如何读取数据库的操作。注意在单机多GPU情况下DBReader只有一个实例,为各个GPU共享。在多机的情况下,每台机器有一个DBReader实例,通过DBReader中的成员变量shard_id_来标识该节点负责读取哪一部分的数据库。通常,每一台机器都会有一份完整的相同的数据库,当然也可以通过nfs将数据库从一台机器映射给其他机器。读取同一个数据库的时候。DBReader自动会对数据进行切片,保证每个节点的每个GPU读取数据库的不同部分,以此达到数据并行。DBReader的摘要如下:
class DBReader {...private: string db_type_; //数据库的类型,包括minidb,leveldb,lmdb等等 string source_; //数据库的路径 unique_ptr<DB> db_; //数据库对象 unique_ptr<Cursor> cursor_; //数据库游标 mutable std::mutex reader_mutex_;//单机多GPU环境下,应该是多线程进行训练,多线程共享同一个DBReader实例,因此需要用这个reader_mutex来控制对共享变量的访问。 uint32_t num_shards_; //单机环境下,该值为0,分布式环境下,该值为节点数目。 uint32_t shard_id_; //节点id,从0开始,单机情况下为0,依次递增, DISABLE_COPY_AND_ASSIGN(DBReader); public: void Open(const string& db_type, const string& source, const int32_t num_shards = 1, const int32_t shard_id = 0) { //打开数据库,该函数会在构造函数里被调用 cursor_.reset(); db_.reset(); db_type_ = db_type; source_ = source; db_ = CreateDB(db_type_, source_, READ); CAFFE_ENFORCE(db_, "Cannot open db: ", source_, " of type ", db_type_); InitializeCursor(num_shards, shard_id); }// for i = 0: batch_size, call Read void Read(string* key, string* value) const { CAFFE_ENFORCE(cursor_ != nullptr, "Reader not initialized."); std::unique_lock<std::mutex> mutex_lock(reader_mutex_);//这里注意,只对单机多GPU会阻塞,不同机器之间不会阻塞,因为是不同的DBReader实例,多机通信会通过rendezvous进行同步,比如redis _store_handler等。 *key = cursor_->key(); *value = cursor_->value(); // 在分布式环境下,由于一次有num_shards台机器参与读取数据,因此一次计算读取的数据量有num_shards * 每台机器读取的数据量,所以对于每一台机器而言,这里要跳过num_shards个记录,才是它下一次迭代应该读取的数据库位置 for (int s = 0; s < num_shards_; s++) { cursor_->Next(); if (!cursor_->Valid()) { MoveToBeginning(); break; } } } ...};
DB, Transaction, Cursor三个接口类定义了如何操作数据库。对于不同类型的数据库,会有相应的实现,比如针对lmdb,就有LMDB, LMDBTransaction, LMDBCursor,针对minidb,就有MiniDB, MiniDBTransaction, MiniDBCursor。从Caffe2中实现的lmdb,minidb, leveldb来看,读数据库只支持顺序读取,即cursor从头到尾顺序访问数据库,当访问到数据库末尾时候,cursor又从头开始,因此并不支持对数据库的随机访问。DB的摘要如下:
class DB { public: DB(const string& /*source*/, Mode mode) : mode_(mode) {} virtual ~DB() { } /** * Closes the database. */ virtual void Close() = 0; /** * Returns a cursor to read the database. The caller takes the ownership of * the pointer. */ virtual std::unique_ptr<Cursor> NewCursor() = 0; /** * Returns a transaction to write data to the database. The caller takes the * ownership of the pointer. */ virtual std::unique_ptr<Transaction> NewTransaction() = 0; protected: Mode mode_; //这个mode定义为enum Mode { READ, WRITE, NEW }; DISABLE_COPY_AND_ASSIGN(DB);};
minidb相关操作
minidb其实就是简单地封装了C语言中的文件IO调用, 没啥特别之处,直接把caffe2/core/db.cc中的代码贴出来。因为有这个minidb的存在,因此Caffe2就不像Caffe1中有辣么多依赖软件了。lmdb和leveldb对Caffe2来说就是可选的了。不过,minidb的功能肯定不如lmdb了(个人猜测,minidb的读写效率啊,估计也没有lmdb高)。
class MiniDBCursor : public Cursor { public: explicit MiniDBCursor(FILE* f, std::mutex* mutex) : file_(f), lock_(*mutex), valid_(true) { // We call Next() to read in the first entry. Next(); } ~MiniDBCursor() {} void Seek(const string& /*key*/) override { LOG(FATAL) << "MiniDB does not support seeking to a specific key."; } void SeekToFirst() override { fseek(file_, 0, SEEK_SET); CAFFE_ENFORCE(!feof(file_), "Hmm, empty file?"); // Read the first item. valid_ = true; Next(); } void Next() override { // First, read in the key and value length. if (fread(&key_len_, sizeof(int), 1, file_) == 0) { // Reaching EOF. VLOG(1) << "EOF reached, setting valid to false"; valid_ = false; return; } CAFFE_ENFORCE_EQ(fread(&value_len_, sizeof(int), 1, file_), 1); CAFFE_ENFORCE_GT(key_len_, 0); CAFFE_ENFORCE_GT(value_len_, 0); // Resize if the key and value len is larger than the current one. if (key_len_ > key_.size()) { key_.resize(key_len_); } if (value_len_ > value_.size()) { value_.resize(value_len_); } // Actually read in the contents. CAFFE_ENFORCE_EQ( fread(key_.data(), sizeof(char), key_len_, file_), key_len_); CAFFE_ENFORCE_EQ( fread(value_.data(), sizeof(char), value_len_, file_), value_len_); // Note(Yangqing): as we read the file, the cursor naturally moves to the // beginning of the next entry. } string key() override { CAFFE_ENFORCE(valid_, "Cursor is at invalid location!"); return string(key_.data(), key_len_); } string value() override { CAFFE_ENFORCE(valid_, "Cursor is at invalid location!"); return string(value_.data(), value_len_); } bool Valid() override { return valid_; } private: FILE* file_; std::lock_guard<std::mutex> lock_; bool valid_; int key_len_; vector<char> key_; int value_len_; vector<char> value_;};class MiniDBTransaction : public Transaction { public: explicit MiniDBTransaction(FILE* f, std::mutex* mutex) : file_(f), lock_(*mutex) {} ~MiniDBTransaction() { Commit(); } void Put(const string& key, const string& value) override { int key_len = key.size(); int value_len = value.size(); CAFFE_ENFORCE_EQ(fwrite(&key_len, sizeof(int), 1, file_), 1); CAFFE_ENFORCE_EQ(fwrite(&value_len, sizeof(int), 1, file_), 1); CAFFE_ENFORCE_EQ( fwrite(key.c_str(), sizeof(char), key_len, file_), key_len); CAFFE_ENFORCE_EQ( fwrite(value.c_str(), sizeof(char), value_len, file_), value_len); } void Commit() override { if (file_ != nullptr) { CAFFE_ENFORCE_EQ(fflush(file_), 0); file_ = nullptr; } } private: FILE* file_; std::lock_guard<std::mutex> lock_; DISABLE_COPY_AND_ASSIGN(MiniDBTransaction);};class MiniDB : public DB { public: MiniDB(const string& source, Mode mode) : DB(source, mode), file_(nullptr) { switch (mode) { case NEW: file_ = fopen(source.c_str(), "wb"); break; case WRITE: file_ = fopen(source.c_str(), "ab"); fseek(file_, 0, SEEK_END); break; case READ: file_ = fopen(source.c_str(), "rb"); break; } CAFFE_ENFORCE(file_, "Cannot open file: " + source); VLOG(1) << "Opened MiniDB " << source; } ~MiniDB() { Close(); } void Close() override { if (file_) { fclose(file_); } file_ = nullptr; } unique_ptr<Cursor> NewCursor() override { CAFFE_ENFORCE_EQ(this->mode_, READ); return make_unique<MiniDBCursor>(file_, &file_access_mutex_); } unique_ptr<Transaction> NewTransaction() override { CAFFE_ENFORCE(this->mode_ == NEW || this->mode_ == WRITE); return make_unique<MiniDBTransaction>(file_, &file_access_mutex_); } private: FILE* file_; // access mutex makes sure we don't have multiple cursors/transactions // reading the same file. std::mutex file_access_mutex_;};
ImageInputOp
编译这个ImageInputOp需要opencv的支持。这个operator就是真正把数据库中存储的数据转换成CNN训练用的图片了。它就类似于Caffe1中的BasePrefetchingDataLayer,但ImageInputOp的功能比BasePrefetchDataLayer强大得多。除了支持像BasePrefetchDataLayer那样的随机裁剪,镜像,resize图片大小之外,还支持更加丰富的Data Augmentation, 比如颜色扰动,对比度,饱和度等,googlenet和resnet中做的数据增广都已经实现了。另一个显著的地方是,ImageInputOp除了支持单标签外,也支持多标签。ImageInputOp的输出数据格式是NHWC的形式,虽然Caffe2支持NHWC,NCHW两种数据格式,它默认支持的数据格式Caffe1的数据格式,即NCHW。默认情况下,当用python来训练时,调用ImageInput时,arg_scope的order缺省情况是NCHW, Caffe2的python接口会自动添加NHWC2NCHWOp进行数据排布转换。
ImageInputOp是一个典型的单生产者,单消费者,只有一个缓冲区容量的异步OP。对于batchsize个样本的解码进行数据增广操作又是多线程并行的。下面分别介绍一下:
生产者消费者模式体现在ImageInputOp的父类PrefetchOp中。
消费者
//每次前传时候,会调用这个Run方法,通知生产者进行生产数据。但这里为啥没有一个如同Caffe1一样设置一个大小为PREFETCH_COUNT容量的缓冲区,让生产者不停生产,缓冲区满了后再停止,而是每次都现消费现来生产了?想不通哈,请大神帮助解释一下。 PS,不过我自己在训练的时候,并没法先卡在IO,估计这个IO过程很快吧。 bool Run(int /* unused */ /*stream_id*/) override { if (!prefetch_thread_) { prefetch_thread_.reset(new std::thread([this] { this->PrefetchWorker(); })); } context_.SwitchToDevice(0); std::unique_lock<std::mutex> lock(prefetch_access_mutex_); while (!prefetched_) consumer_.wait(lock); if (!prefetch_success_) { LOG(ERROR) << "Prefetching failed."; return false; } if (!CopyPrefetched()) { //CopyPrefetched表示消费,正确返回就表示消费完了,然后就通知生产者继续生产。 LOG(ERROR) << "Error when copying prefetched data."; return false; } prefetched_ = false; context_.FinishDeviceComputation(); producer_.notify_one(); return true; }
生产者
void PrefetchWorker() { context_.SwitchToDevice(); std::unique_lock<std::mutex> lock(prefetch_access_mutex_); while (prefetched_) producer_.wait(lock); while (!finalize_) { // We will need to run a FinishDeviceComputation() call because the // prefetcher thread and the main thread are potentially using different // streams (like on GPU). try { prefetch_success_ = Prefetch();//Prefetch就代表生产数据了,它是个虚函数,ImageInputOp会实现之。 context_.FinishDeviceComputation(); } catch (const std::exception& e) { // TODO: propagate exception_ptr to the caller side LOG(ERROR) << "Prefetching error " << e.what(); prefetch_success_ = false; } prefetched_ = true; consumer_.notify_one(); while (prefetched_) //我理解的单生产者单消费者就在此,Pretch完,就等待消费者消费,直到消费完了,难道说,可以在Pretch中扩大缓冲区么? producer_.wait(lock); } }
在ImageInputOp中的多线程解码转换数据部分,就体现在成员变量thread_pool_了,它是个线程池TaskThreadPool的智能指针。下面是对解码部分的一个简单说明
for (int item_id = 0; item_id < batch_size_; ++item_id) { ..... //先做一些必要的准备操作 thread_pool_->runTaskWithID(std::bind(&ImageInputOp<Context>::DecodeAndTransform, this, std::string(value), image_data, item_id, channels, std::placeholders::_1));//往线程池里面添加任务,启动线程来计算。 } thread_pool_->waitWorkComplete();//等待解码完成 ...
来看看Caffe2中这个线程池是咋实现的吧,直接复制之。以前听说过线程池,但从未实现过,看看这代码,也学到不少东西。
class TaskThreadPool{ private: struct task_element_t { bool run_with_id; const std::function< void() > no_id; const std::function< void(std::size_t) > with_id; explicit task_element_t(const std::function< void() >& f) : run_with_id(false), no_id(f), with_id(nullptr) { } explicit task_element_t(const std::function< void(std::size_t) >& f) : run_with_id(true), no_id(nullptr), with_id(f) { } }; std::queue<task_element_t> tasks_; std::vector<std::thread> threads_; std::mutex mutex_; std::condition_variable condition_; std::condition_variable completed_; bool running_; bool complete_; std::size_t available_; std::size_t total_; public: /// @brief Constructor. explicit TaskThreadPool(std::size_t pool_size) : threads_(pool_size), running_(true), complete_(true), available_(pool_size), total_(pool_size) { for ( std::size_t i = 0; i < pool_size; ++i ) {//线程池里面共有pool_size个工作线程在等待tasks_中的任务 threads_[i] = std::thread( std::bind(&TaskThreadPool::main_loop, this, i)); } } /// @brief Destructor. ~TaskThreadPool() { // Set running flag to false then notify all threads. { std::unique_lock< std::mutex > lock(mutex_); running_ = false; condition_.notify_all(); } try { for (auto& t : threads_) { t.join(); } } // Suppress all exceptions. catch (const std::exception&) {} } /// @brief Add task to the thread pool if a thread is currently available. template <typename Task> void runTask(Task task) { std::unique_lock<std::mutex> lock(mutex_); // Set task and signal condition variable so that a worker thread will // wake up and use the task. tasks_.push(task_element_t(static_cast<std::function< void() >>(task))); complete_ = false; condition_.notify_one(); } template <typename Task> void runTaskWithID(Task task) { std::unique_lock<std::mutex> lock(mutex_); // Set task and signal condition variable so that a worker thread will // wake up and use the task. tasks_.push(task_element_t(static_cast<std::function< void(std::size_t) >>( task))); complete_ = false; condition_.notify_one(); } /// @brief Wait for queue to be empty void waitWorkComplete() { std::unique_lock<std::mutex> lock(mutex_); while (!complete_) completed_.wait(lock); } private: /// @brief Entry point for pool threads. void main_loop(std::size_t index) { while (running_) { // Wait on condition variable while the task is empty and // the pool is still running. std::unique_lock<std::mutex> lock(mutex_); while (tasks_.empty() && running_) { condition_.wait(lock); } // If pool is no longer running, break out of loop. if (!running_) break; // Copy task locally and remove from the queue. This is // done within its own scope so that the task object is // destructed immediately after running the task. This is // useful in the event that the function contains // shared_ptr arguments bound via bind. { auto tasks = tasks_.front(); tasks_.pop(); // Decrement count, indicating thread is no longer available. --available_; lock.unlock();//由于mutex已经被锁住了,需要释放之,以让其他线程能够获得任务,不然线程就串行了,无法并行。 // Run the task. try { if (tasks.run_with_id) { tasks.with_id(index); } else { tasks.no_id(); } } // Suppress all exceptions. catch ( const std::exception& ) {} // Update status of empty, maybe // Need to recover the lock first lock.lock(); // Increment count, indicating thread is available. ++available_; if (tasks_.empty() && available_ == total_) { complete_ = true; completed_.notify_one(); } } } // while running_ }};
SaveOp
在训练过程中,一般会每隔一定的迭代次数保存将当前模型保存到硬盘上。在Caffe2中与保存模型有关的save_to_db函数,它是一个Python函数,封装了应该保存的信息,以方便加载用。save_to_db调用的是C++端的SaveOp。
模型中需要保存的信息有:
模型参数 ——通过ModelHelper的params属性可以获得模型参数,比如卷积的卷积核bias,FC的weight,bias, BN的estimated mean和estimated var等等。
模型定义 ——网络的Op集合。如果是depoly的话,就不需要保存gradient operators, 否则需要保存graients operators。
当创建一个Operators,比如创建ConvOp, 该ConvOp需要卷积核conv_w以及bias conv_b会自动创建。这些参数名字会自动添加进ModelHelper的param_init_net中,而这个param_init_net就可以被视为包含网络参数的定义以及如何初始化这些参数的prototxt,比如调用各种具体的初始化算法如Xavier, Gaussian, MSRA等等来填充参数tensor。这就是为啥在训练真正开始之前,需要先调用workspace.RunNetOnce(model_helper_obj.param_init_net)的原因。
SaveOp Run方法被调用时,其实就是对输入的vecor<const Blob*>
依次调用Serialize进行序列化,保存到硬盘上。Serialize的函数原型为:
//每次保存的都是带名字的键值对,这也符合Caffe2的存储设计思想,即所有的内存区域都要有名字,比如workspace的map<string, unique_ptr<Blob> > blob_map_。这个acceptor就是负责和具体的DB打交道的函数,它将转化后的字符串输出到真正的DB中,完成保存到硬盘的操作,完全类似Caffe1中那个convert_imageset。void Blob::Serialize(const string& name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size) const
在Caffe2存储部分说过,Blob是一个容器,它可以容纳任意类型,比如string, tensor,网络的具体定义比如prototxt就是Blob<string>,
它的序列化就是调用StringSerialzer,没啥特别之处。比较有意思部分是当Blob存储的是Tensor时的序列化,毕竟网络参数都是Tensor。这个序列化的过程其实就是把Tensor中的数据转换成google protobuf。上面那个acceptor的输入就是这个google protobuf的序列化字符串。下面是Tensor的序列化的代码,直接复制过来。
template <class Context>void TensorSerializer<Context>::SerializeWithChunkSize(const Blob& blob, const string& name, BlobSerializerBase::SerializationAcceptor acceptor, int chunk_size) { CAFFE_ENFORCE(blob.IsType<Tensor<Context>>()); const auto& tensor = blob.template Get<Tensor<Context>>(); if (chunk_size == kNoChunking) { chunk_size = tensor.size() + 1; // to account for empty tensors } else if (chunk_size == kDefaultChunkSize) { chunk_size = FLAGS_caffe2_tensor_chunk_size; } auto processChunk = [&](int64_t chunkStart) { BlobProto blob_proto; blob_proto.set_name(name); blob_proto.set_type(kTensorBlobType); TensorProto& proto = *blob_proto.mutable_tensor(); proto.set_name(name); this->Serialize( tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size); acceptor( MakeString(name, kChunkIdSeparator, chunkStart / chunk_size), blob_proto.SerializeAsString()); };#ifndef __ANDROID__ std::vector<std::future<void>> futures; // Poorman's IOBound ThreadPool //对于超大的Tensor保存,又是多线程并行序列化啊,真是追求性能到极致了。再一次膜拜fb的工程师,贡献了一份如此漂亮的工业级代码。学习了。 SimpleQueue<size_t> chunkQueue;//线程安全队列 auto task = [&]() { size_t chunkStart; while (chunkQueue.Pop(&chunkStart)) { processChunk(chunkStart); } }; if (tensor.size() > chunk_size) {//这里就是多线程开始的地方 for (int i = 0; i < FLAGS_caffe2_max_tensor_serializer_threads; ++i) { futures.emplace_back(std::async(std::launch::async, task)); } }#endif VLOG(1) << "Serializing blob " << name; // Serialize whole vector. If vector is empty, it's shape still needs to be // serialized in empty proto for (size_t chunkBegin = 0; chunkBegin < std::max(tensor.size(), static_cast<TIndex>(1)); chunkBegin += chunk_size) { VLOG(2) << "Starting a chunk at " << chunkBegin;#ifndef __ANDROID__ if (tensor.size() > chunk_size) { chunkQueue.Push(chunkBegin);//Tensor太大了,分块,对于每一块都扔给线程池来去序列化。 } else { // Sync mode for small tensors processChunk(chunkBegin); }#else // Since Android does not have std::future, we will always do sync mode processChunk(chunkBegin);#endif }#ifndef __ANDROID__ chunkQueue.NoMoreJobs(); for (auto& fut : futures) { fut.get(); }#endif}
LoadOp
LoadOp对应于反序列化,主要。这里主要说明下python端的prepare_prediction_net
函数。保存到数据库中的模型有3个net,即global_init_net, predict_init_net, predict_net。
global_init_net ——模型参数加载进workspace就是通过这个global_init_net的,它存储了网络参数的名字。
predict_init_net——输入输出,指定网络输入blob,输出blob的名字,以及输入输出形状的定义和置0,同样,它存储在workspace中。
predict_net——作用跟Caffe1中的deploy.prototxt差不多。
进行预测时候,加载顺序是global_init_net,先将参数加载进workspace,然后加载predict_init_net,初始化输入输出,并置0。最后是根据predict_net来构件网络结构,创建一个一个Op。
总结
Caffe2中的ImageInputOp中的decode多线程部分,存储时的TensorSerializer中多线程序列化,是值得学习的地方。
- Caffe2源码理解系列之IO
- Caffe2源码理解系列之存储
- mybatis 源码系列 组件之 io
- JAVA IO源码学习系列之InputStream
- JAVA IO源码学习系列之ByteArrayInputStream
- MySQL系列:innodb源码分析之文件IO
- MySQL系列:innodb源码分析之文件IO
- caffe2之caffe_translator
- 深入理解Tomcat系列之二:源码调试环境搭建
- 深入理解Spring系列之九:DispatcherServlet初始化源码分析
- 拆轮子系列之理解GreenDao框架源码
- IO系列之File
- Caffe2
- Caffe2
- Android 源码系列之<八>从源码的角度深入理解缓存策略之LruCache
- Java IO系列之初始IO
- Android 源码系列之<一>从源码的角度深入理解ImageView的ScaleType属性
- Android 源码系列之<七>从源码的角度深入理解IntentService及HandlerThread
- 一对多广播,Fragment延迟加载
- 还原bak后缀的数据库文件如何操作
- HC9S12X 定义及访问直接寻址区
- 160. Intersection of Two Linked Lists
- module + 异常
- Caffe2源码理解系列之IO
- 64. Minimum Path Sum
- Spring报错java.lang.IllegalStateException: BeanFactory not initialized or already closed -call 'refres
- 关于JavaScript闭包中this对象(colsure)
- 用GAN来做图像生成,这是最好的方法
- Android:JNI 与 NDK到底是什么?(含实例教学)
- 彻底解决Webpack打包性能问题
- VTK学习(十六)三角剖分
- 圆、长方形的面积和周长