Caffe源码解读(十二):自定义数据输入层
来源:互联网 发布:怎么知道ftp端口是多少 编辑:程序博客网 时间:2024/05/16 14:55
第 1、3、4、5 步与上一节“自定义神经层”中的对应步骤相同,这里只介绍不同的部分。
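数据输入层主要需要实现以下三个函数(其中 DataLayerSetUp 和 load_batch 重写了基类中的虚函数,ShuffleImages 是 ImageDataLayer 自己的辅助函数):
1. DataLayerSetUp:读取 prototxt 中定义的参数,并设定各个 Blob 容器的形状(N×K×H×W,即批大小×通道数×高×宽)
2. ShuffleImages:打乱图片列表的顺序
3. load_batch:在预取线程中把一个 batch 的图片和标签读入内存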
代码及解读如下:
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>

#include <fstream>  // NOLINT(readability/streams)
#include <iostream>  // NOLINT(readability/streams)
#include <string>
#include <utility>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/layers/base_data_layer.hpp"
#include "caffe/layers/image_data_layer.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

// Stops the internal (prefetch) thread, which runs load_batch(), before the
// layer is torn down.
// Declared as ~ImageDataLayer() rather than ~ImageDataLayer<Dtype>(): a
// template-id as destructor declarator is not valid in C++20 (CWG 2237).
template <typename Dtype>
ImageDataLayer<Dtype>::~ImageDataLayer() {
  this->StopInternalThread();
}

// DataLayerSetUp: reads image_data_param from the prototxt, loads the
// "<path> <label>" list file named by `source`, optionally shuffles it and
// skips a random prefix, then decodes one image to infer the blob shape
// (N x K x H x W) used to reshape transformed_data_, every prefetch buffer,
// and both tops. top[0] receives the image data and top[1] the labels
// (a 1-D blob with batch_size entries).
template <typename Dtype>
void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
  // layer_param_ is the LayerParameter parsed from the prototxt; its
  // image_data_param field (ImageDataParameter, defined in caffe.proto)
  // holds this layer's options.
  const int new_height = this->layer_param_.image_data_param().new_height();
  const int new_width  = this->layer_param_.image_data_param().new_width();
  const bool is_color  = this->layer_param_.image_data_param().is_color();
  string root_folder = this->layer_param_.image_data_param().root_folder();

  CHECK((new_height == 0 && new_width == 0) ||
      (new_height > 0 && new_width > 0)) << "Current implementation requires "
      "new_height and new_width to be set at the same time.";

  // Read the file with filenames and labels.
  // `source` is a plain-text list file (NOT an LMDB database). Each line is
  // "<image path relative to root_folder> <integer label>", split at the LAST
  // space so that image paths may themselves contain spaces. The pairs are
  // stored in lines_ (declared in image_data_layer.hpp).
  const string& source = this->layer_param_.image_data_param().source();
  LOG(INFO) << "Opening file " << source;
  std::ifstream infile(source.c_str());
  // Report an unreadable/missing file precisely instead of falling through to
  // the misleading "File is empty" check below.
  CHECK(infile.is_open()) << "Could not open file " << source;
  string line;
  size_t pos;
  int label;
  while (std::getline(infile, line)) {
    // Drop a trailing '\r' left by CRLF line endings so that blank CRLF lines
    // are recognised as blank (atoi already ignored it after the label).
    if (!line.empty() && line[line.size() - 1] == '\r') {
      line.erase(line.size() - 1);
    }
    // Skip blank lines; they used to be stored as an image with an empty
    // path and only failed later, inside load_batch on the prefetch thread.
    if (line.empty()) {
      continue;
    }
    pos = line.find_last_of(' ');
    // Without a separator, substr(pos + 1) wrapped around to the whole line
    // and the label silently became atoi(path).
    CHECK_NE(pos, string::npos) << "Missing label in line: " << line;
    label = atoi(line.substr(pos + 1).c_str());
    lines_.push_back(std::make_pair(line.substr(0, pos), label));
  }
  CHECK(!lines_.empty()) << "File is empty";

  if (this->layer_param_.image_data_param().shuffle()) {
    // randomly shuffle data
    LOG(INFO) << "Shuffling data";
    // Seed the layer's own RNG (prefetch_rng_) from caffe_rng_rand(); the
    // same generator is reused by the reshuffles done in load_batch().
    const unsigned int prefetch_rng_seed = caffe_rng_rand();
    prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
    ShuffleImages();
  }
  LOG(INFO) << "A total of " << lines_.size() << " images.";

  // lines_id_ is the index of the next entry of lines_ to read.
  lines_id_ = 0;
  // Check if we would need to randomly skip a few data points.
  // rand_skip is an upper bound: a random number of lines in [0, rand_skip)
  // is skipped at the start of the list.
  if (this->layer_param_.image_data_param().rand_skip()) {
    unsigned int skip = caffe_rng_rand() %
        this->layer_param_.image_data_param().rand_skip();
    LOG(INFO) << "Skipping first " << skip << " data points.";
    CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
    lines_id_ = skip;
  }

  // Read an image, and use it to initialize the top blob.
  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
                                    new_height, new_width, is_color);
  CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
  // Use data_transformer to infer the expected blob shape from a cv_image.
  // data_transformer_ (held by BaseDataLayer) applies the transform_param
  // preprocessing (e.g. mirroring, mean subtraction); InferBlobShape reports
  // the shape of one transformed image.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
  this->transformed_data_.Reshape(top_shape);

  // Reshape prefetch_data and top[0] according to the batch_size.
  const int batch_size = this->layer_param_.image_data_param().batch_size();
  CHECK_GT(batch_size, 0) << "Positive batch size required";
  top_shape[0] = batch_size;
  // PREFETCH_COUNT is the number of prefetch buffers held by the base class;
  // each Batch carries a data_ and a label_ blob.
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].data_.Reshape(top_shape);
  }
  top[0]->Reshape(top_shape);

  LOG(INFO) << "output data size: " << top[0]->num() << ","
      << top[0]->channels() << "," << top[0]->height() << ","
      << top[0]->width();
  // label: a 1-D blob with one entry per image in the batch.
  vector<int> label_shape(1, batch_size);
  top[1]->Reshape(label_shape);
  for (int i = 0; i < this->PREFETCH_COUNT; ++i) {
    this->prefetch_[i].label_.Reshape(label_shape);
  }
}

// ShuffleImages: shuffles lines_ in place using the layer's own generator
// prefetch_rng_ (created in DataLayerSetUp when shuffle is enabled).
template <typename Dtype>
void ImageDataLayer<Dtype>::ShuffleImages() {
  caffe::rng_t* prefetch_rng =
      static_cast<caffe::rng_t*>(prefetch_rng_->generator());
  shuffle(lines_.begin(), lines_.end(), prefetch_rng);
}

// This function is called on prefetch thread.
// load_batch: decodes the next batch_size images (and their labels) from
// lines_ into `batch`, wrapping around to the start of the list -- and
// reshuffling it if shuffle is enabled -- when the end is reached.
template <typename Dtype>
void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
  CPUTimer batch_timer;
  batch_timer.Start();
  double read_time = 0;
  double trans_time = 0;
  CPUTimer timer;
  CHECK(batch->data_.count());
  CHECK(this->transformed_data_.count());

  // Same parameters as in DataLayerSetUp. Bound by const reference: taking it
  // by value copied the whole protobuf message on every batch.
  const ImageDataParameter& image_data_param =
      this->layer_param_.image_data_param();
  const int batch_size = image_data_param.batch_size();
  const int new_height = image_data_param.new_height();
  const int new_width = image_data_param.new_width();
  const bool is_color = image_data_param.is_color();
  string root_folder = image_data_param.root_folder();

  // Reshape according to the first image of each batch
  // on single input batches allows for inputs of varying dimension.
  cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
      new_height, new_width, is_color);
  CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
  // Use data_transformer to infer the expected blob shape from a cv_img.
  vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);
  this->transformed_data_.Reshape(top_shape);
  // Reshape batch according to the batch_size.
  top_shape[0] = batch_size;
  batch->data_.Reshape(top_shape);

  Dtype* prefetch_data = batch->data_.mutable_cpu_data();
  Dtype* prefetch_label = batch->label_.mutable_cpu_data();

  // datum scales
  const int lines_size = lines_.size();
  for (int item_id = 0; item_id < batch_size; ++item_id) {
    // get a blob
    timer.Start();
    CHECK_GT(lines_size, lines_id_);
    cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,
        new_height, new_width, is_color);
    CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;
    read_time += timer.MicroSeconds();
    timer.Start();
    // Apply transformations (mirror, crop...) to the image.
    // transformed_data_ is re-pointed at this item's slot inside the batch
    // buffer, so Transform() writes straight into batch->data_.
    int offset = batch->data_.offset(item_id);
    this->transformed_data_.set_cpu_data(prefetch_data + offset);
    this->data_transformer_->Transform(cv_img, &(this->transformed_data_));
    trans_time += timer.MicroSeconds();

    prefetch_label[item_id] = lines_[lines_id_].second;
    // go to the next iter
    lines_id_++;
    if (lines_id_ >= lines_size) {
      // We have reached the end. Restart from the first.
      DLOG(INFO) << "Restarting data prefetching from start.";
      lines_id_ = 0;
      if (image_data_param.shuffle()) {
        ShuffleImages();
      }
    }
  }
  batch_timer.Stop();
  DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
  DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
  DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}

INSTANTIATE_CLASS(ImageDataLayer);
REGISTER_LAYER_CLASS(ImageData);

}  // namespace caffe
#endif  // USE_OPENCV
ImageDataLayer 层在 prototxt 中的使用示例(注意 prototxt 的注释符号是 #,不是 //):
layer {
  name: "data"
  type: "ImageData"  # 对应 image_data_layer.hpp 中 type() 返回的层类型名
  top: "data"        # 两个 top:top[0] 为图像数据
  top: "label"       # top[1] 为 label
  transform_param {  # 图像数据的预处理操作
    mirror: false
    crop_size: 227
    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
  }
  image_data_param {
    source: "examples/_temp/file_list.txt"  # 图片列表文件,每行格式为“图片路径 标签”
    batch_size: 50
    new_height: 256
    new_width: 256
  }
}
0 0
- Caffe源码解读(十二):自定义数据输入层
- Caffe 自定义数据输入层
- Roi_Pooling层caffe源码解读
- caffe源码学习(六) 自定义层
- (16)caffe总结之自定义数据输入层
- 【caffe】标准数据层输入
- Caffe源码解读:lrn_layer层原理
- 自定义数据输入层
- Caffe源码解读(十一):自定义一个layer
- caffe源码 之 数据层
- Caffe中mnist例子(一)自定义输入层
- 【深度学习框架Caffe学习与应用】第五课 自定义神经层和数据输入层
- SSD的caffe源码解读 -- 数据增强
- Caffe代码解读(五):数据层及参数
- Caffe源码解读(二):Blob类的源码解读
- Caffe源码解读(三):Layer类的源码解读
- Caffe源码解读(七):将图片数据转化为LMDB数据
- caffe 源码的解读(2)DataStructure
- Retrofit2.0使用详解&&封装
- 关于ggplot2画散点图、条形图的一些细节认识
- 如何利用 Chrome 开发者工具远程调试 Android 中的原生 WebView?
- 继承——java面向对象
- static静态
- Caffe源码解读(十二):自定义数据输入层
- js之 同一页面中的多表单提交
- Android-ABIFilter-Device supports x86,but APK only supports armeabi-v7a,armeabi,x86_64
- 异常处理
- 计算机解决问题的方法
- Java源码——对象序列化(对象的存储及读取)(Object Serialization)
- HDU 1753
- lucene、solr的介绍及区别
- Java Socket实战之一 单线程通信