Caffe Code Reading 4: Implementation Details of DataTransformer and io (2016.3.16)
Source: Internet. Editor: 程序博客网. Time: 2024/05/18 06:42
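1. A Brief Overview of DataTransformer
This class is mainly responsible for preprocessing data: it transforms data of type Datum, const vector<Datum>, cv::Mat&, vector<cv::Mat>, and Blob<Dtype>* into a blob of the target size.
It also infers the shape of the processed data from the preprocessing parameters given in the configuration.
Before the formal introduction, here is an example: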
layer {
  name: "jointimagedata"
  type: "JointImage"
  top: "jointimagedata"
  top: "label"
  include {
    phase: TEST
  }
  transform_param {
    mirror: true
    crop_size: 227
    mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
  }
  slidewindow_param {
    root_folder: "D:/数据集/FLIC/FLIC-full"
    filelistpath: "/imglist.txt"
    batch_size: 300
  }
}
The configuration above contains the transform_param parameter, which enables cropping, mirroring, mean subtraction, and scaling.
The class uses TransformationParameter.
Its definition in caffe.proto is:
// Message that stores parameters used to apply transformation
// to the data layer's data
message TransformationParameter {
  // For data pre-processing, we can do simple scaling and subtracting the
  // data mean, if provided. Note that the mean subtraction is always carried
  // out before scaling.
  optional float scale = 1 [default = 1];
  // Specify if we want to randomly mirror data.
  optional bool mirror = 2 [default = false];
  // Specify if we would like to randomly crop an image.
  optional uint32 crop_size = 3 [default = 0];
  // mean_file and mean_value cannot be specified at the same time
  optional string mean_file = 4;
  // if specified can be repeated once (would subtract it from all the channels)
  // or can be repeated the same number of times as channels
  // (would subtract them from the corresponding channel)
  repeated float mean_value = 5;
  // Force the decoded image to have 3 color channels.
  optional bool force_color = 6 [default = false];
  // Force the decoded image to have 1 color channels.
  optional bool force_gray = 7 [default = false];
}
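The proto comment above says that mean subtraction is always carried out before scaling. A minimal sketch of that order of operations follows, assuming a flat float buffer and a per-element mean of the same size (as with mean_file). The function name SubtractMeanAndScale is hypothetical and is not part of Caffe.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a Caffe API): computes (x - mean) * scale for
// every element, i.e. the mean is subtracted first and the result is scaled.
std::vector<float> SubtractMeanAndScale(const std::vector<float>& pixels,
                                        const std::vector<float>& mean,
                                        float scale) {
  // Mirrors the dimension checks Caffe performs against the mean blob.
  assert(mean.size() == pixels.size());
  std::vector<float> out(pixels.size());
  for (std::size_t i = 0; i < pixels.size(); ++i) {
    out[i] = (pixels[i] - mean[i]) * scale;
  }
  return out;
}
```

With pixels {0, 128, 255}, a mean of 128 for every element, and scale 0.5, this yields {-64, 0, 63.5}. Scaling first and subtracting afterwards would give different values, which is why the order matters.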
2. The DataTransformer Class in Detail
1) Constructor
// Constructor
explicit DataTransformer(const TransformationParameter& param, Phase phase);
virtual ~DataTransformer() {}
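2) Member variables
// Parameters that control the transformation
TransformationParameter param_;
// Random number generator (created in InitRand())
shared_ptr<Caffe::RNG> rng_;
// Phase: training or testing
Phase phase_;
// Mean blob, loaded from mean_file
Blob<Dtype> data_mean_;
// Per-channel mean values, taken from mean_value
vector<Dtype> mean_values_;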
3) Member functions
// Initialize the random number generator; it is needed for random mirroring
// and for random cropping during training
void InitRand();
// Transform a Datum and store the result in transformed_blob
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
// Transform a vector of Datum and store the result in transformed_blob
void Transform(const vector<Datum> & datum_vector, Blob<Dtype>* transformed_blob);
// When built with OpenCV: transform a vector of cv::Mat
void Transform(const vector<cv::Mat> & mat_vector, Blob<Dtype>* transformed_blob);
// Transform a single image read by OpenCV into a blob
void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
// Transform an input blob; the output may be a cropped part of the input
void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
// Infer the blob shape from a Datum
vector<int> InferBlobShape(const Datum& datum);
// Infer the blob shape from a vector of Datum
vector<int> InferBlobShape(const vector<Datum> & datum_vector);
// Infer the blob shape from a vector of cv::Mat
vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
// Infer the blob shape from a cv::Mat
vector<int> InferBlobShape(const cv::Mat& cv_img);
// Generate a uniformly distributed random integer in [0, n-1]; the function
// is virtual, so a subclass may override how the number is generated
virtual int Rand(int n);
// Transform the given Datum into a raw output buffer
void Transform(const Datum& datum, Dtype* transformed_data);
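The InferBlobShape overloads declared above all follow the same rule for a single item: the batch dimension is 1, the channel count is kept, and height and width become crop_size when it is non-zero. A minimal sketch of that rule, assuming plain integers instead of a Datum or cv::Mat (the free function InferShape is hypothetical, not a Caffe API):

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the single-item InferBlobShape logic.
// Returns {num, channels, height, width}.
std::vector<int> InferShape(int channels, int height, int width,
                            int crop_size) {
  // Same sanity checks as the CHECK_* macros in the Caffe source.
  assert(channels > 0);
  assert(height >= crop_size);
  assert(width >= crop_size);
  std::vector<int> shape(4);
  shape[0] = 1;                                // one item
  shape[1] = channels;
  shape[2] = crop_size ? crop_size : height;   // cropped or original height
  shape[3] = crop_size ? crop_size : width;    // cropped or original width
  return shape;
}
```

For a 3x256x256 image with crop_size 227 this gives {1, 3, 227, 227}; with crop_size 0 the original height and width are kept. The vector overloads then only overwrite shape[0] with the number of items.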
4) Implementation of the individual functions:
First, the constructor
Before looking at the constructor, here are the definitions of the BlobShape and BlobProto messages in caffe.proto.
message BlobShape {
  repeated int64 dim = 1 [packed = true];  // shape of the blob
}

message BlobProto {
  optional BlobShape shape = 7;
  repeated float data = 5 [packed = true];  // data used in the forward pass
  repeated float diff = 6 [packed = true];  // gradients used in the backward pass
  repeated double double_data = 8 [packed = true];  // forward data, double precision
  repeated double double_diff = 9 [packed = true];  // gradients, double precision

  // 4D dimensions -- deprecated. Use "shape" instead.
  // The fields below are kept for backward compatibility.
  optional int32 num = 1 [default = 0];
  optional int32 channels = 2 [default = 0];
  optional int32 height = 3 [default = 0];
  optional int32 width = 4 [default = 0];
}

template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
    Phase phase)
    : param_(param), phase_(phase) {
  // check if we want to use mean_file
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) <<
      "Cannot specify mean_file and mean_value at the same time";
    const string& mean_file = param.mean_file();
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_file;
    }
    BlobProto blob_proto;
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
    data_mean_.FromProto(blob_proto);
  }
  // check if we want to use mean_value
  if (param_.mean_value_size() > 0) {
    CHECK(param_.has_mean_file() == false) <<
      "Cannot specify mean_file and mean_value at the same time";
    for (int c = 0; c < param_.mean_value_size(); ++c) {
      mean_values_.push_back(param_.mean_value(c));
    }
  }
}
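The constructor enforces that mean_file and mean_value are mutually exclusive. The sketch below isolates that decision. It is only an illustration under assumptions: an empty string stands in for "mean_file not set" (Caffe checks has_mean_file() instead), and the names MeanSource and SelectMeanSource are hypothetical.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical enum describing where the mean comes from.
enum class MeanSource { kNone, kMeanFile, kMeanValues };

// Hypothetical helper: picks the mean source and rejects the case where
// both a mean file and mean values are given at the same time.
MeanSource SelectMeanSource(const std::string& mean_file,
                            const std::vector<float>& mean_values) {
  const bool has_file = !mean_file.empty();      // assumption: "" == not set
  const bool has_values = !mean_values.empty();
  // Corresponds to "Cannot specify mean_file and mean_value at the same time".
  assert(!(has_file && has_values));
  if (has_file) return MeanSource::kMeanFile;
  if (has_values) return MeanSource::kMeanValues;
  return MeanSource::kNone;
}
```

Passing only a file name selects kMeanFile, passing only values selects kMeanValues, and passing neither selects kNone, which matches the three branches the Transform functions later take.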
The full implementation is as follows:
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif  // USE_OPENCV

#include <string>
#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

// Constructor
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
    Phase phase)
    : param_(param), phase_(phase) {
  // check if we want to use mean_file
  // (i.e. whether a mean file was specified)
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) <<
      "Cannot specify mean_file and mean_value at the same time";
    // Path of the mean file
    const string& mean_file = param.mean_file();
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_file;
    }
    BlobProto blob_proto;
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
    data_mean_.FromProto(blob_proto);
  }
  // check if we want to use mean_value
  if (param_.mean_value_size() > 0) {
    CHECK(param_.has_mean_file() == false) <<
      "Cannot specify mean_file and mean_value at the same time";
    for (int c = 0; c < param_.mean_value_size(); ++c) {
      mean_values_.push_back(param_.mean_value(c));
    }
  }
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  // See the definition of TransformationParameter
  const string& data = datum.data();
  const int datum_channels = datum.channels();  // number of channels
  const int datum_height = datum.height();  // number of rows
  const int datum_width = datum.width();  // number of columns

  const int crop_size = param_.crop_size();  // crop size
  const Dtype scale = param_.scale();  // scale factor
  // Mirror horizontally with probability 1/2 when mirror is enabled
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();  // is there a mean file?
  const bool has_uint8 = data.size() > 0;  // uint8 data (otherwise float_data)
  const bool has_mean_values = mean_values_.size() > 0;  // per-channel means?

  // Sanity checks
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean blob must have the same dimensions as the datum
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < datum_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int height = datum_height;
  int width = datum_width;

  // Compute h_off and w_off depending on whether cropping is requested
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {  // crop_size is non-zero
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training.
    // The offsets come from Rand(), which a subclass may override
    if (phase_ == TRAIN) {
      // Random integer in [0, datum_height - crop_size]
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {  // no randomness at test time: take the center of the image
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  // Transform the data: subtract the mean from each pixel, then multiply
  // by scale.
  // With cropping, the output has size crop_size * crop_size;
  // without cropping, it has size datum_height * datum_width.
  Dtype datum_element;
  int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Index into the source data
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {
          // Mirror by reversing the position along the width axis
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        if (has_uint8) {  // uint8 data: convert to Dtype
          datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {  // otherwise the data is float
          datum_element = datum.float_data(data_index);
        }
        if (has_mean_file) {
          // Subtract the per-element mean, then multiply by scale
          transformed_data[top_index] =
            (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            // Subtract the mean of this channel, then multiply by scale
            transformed_data[top_index] =
              (datum_element - mean_values_[c]) * scale;
          } else {
            // No mean at all: just multiply by scale
            transformed_data[top_index] = datum_element * scale;
          }
        }
      }
    }
  }
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Blob<Dtype>* transformed_blob) {
  // If datum is encoded, decoded and transform the cv::image.
  if (datum.encoded()) {  // encoded datum: decode it first
#ifdef USE_OPENCV
    // Setting both options is a configuration error
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
      // If either option is set, decode with DecodeDatumToCVMat
      // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
    } else {  // otherwise decode with DecodeDatumToCVMatNative
      cv_img = DecodeDatumToCVMatNative(datum);
    }
    // Transform the cv::image into blob.
    return Transform(cv_img, transformed_blob);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else {
    // For a datum that is not encoded, force_color and force_gray have no
    // effect; an error is logged because they only apply to encoded data
    if (param_.force_color() || param_.force_gray()) {
      LOG(ERROR) << "force_color and force_gray only for encoded datum";
    }
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Check dimensions.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, datum_channels);
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);

  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
  } else {
    CHECK_EQ(datum_height, height);
    CHECK_EQ(datum_width, width);
  }

  // Delegate to the raw-buffer overload
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  Transform(datum, transformed_data);
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();
  // Shape of the target blob
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) <<
    "The size of datum_vector must be no greater than transformed_blob->num()";
  // uni_blob holds a single item and is pointed at each slot of the batch
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int item_id = 0; item_id < datum_num; ++item_id) {
    int offset = transformed_blob->offset(item_id);
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);
    Transform(datum_vector[item_id], &uni_blob);
  }
}

#ifdef USE_OPENCV
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  // Number of images and shape of the target blob
  const int mat_num = mat_vector.size();
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) <<
    "The size of mat_vector must be equals to transformed_blob->num()";
  // Same approach as above
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int item_id = 0; item_id < mat_num; ++item_id) {
    int offset = transformed_blob->offset(item_id);
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);
    Transform(mat_vector[item_id], &uni_blob);
  }
}

// For an image: subtract the mean, multiply by scale, and mirror if requested.
// The logic is the same as for a Datum.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Check dimensions.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, img_channels);
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
     "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < img_channels; ++c) {
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int h_off = 0;
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;
  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(img_height - crop_size + 1);
      w_off = Rand(img_width - crop_size + 1);
    } else {
      h_off = (img_height - crop_size) / 2;
      w_off = (img_width - crop_size) / 2;
    }
    cv::Rect roi(w_off, h_off, crop_size, crop_size);
    cv_cropped_img = cv_img(roi);
  } else {
    CHECK_EQ(img_height, height);
    CHECK_EQ(img_width, width);
  }

  CHECK(cv_cropped_img.data);

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  int top_index;
  for (int h = 0; h < height; ++h) {
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < img_channels; ++c) {
        if (do_mirror) {
          top_index = (c * height + h) * width + (width - 1 - w);
        } else {
          top_index = (c * height + h) * width + w;
        }
        // int top_index = (c * height + h) * width + w;
        Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
        if (has_mean_file) {
          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] =
            (pixel - mean[mean_index]) * scale;
        } else {
          if (has_mean_values) {
            transformed_data[top_index] =
              (pixel - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = pixel * scale;
          }
        }
      }
    }
  }
}
#endif  // USE_OPENCV

template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();
  const int input_num = input_blob->num();
  const int input_channels = input_blob->channels();
  const int input_height = input_blob->height();
  const int input_width = input_blob->width();

  if (transformed_blob->count() == 0) {
    // Initialize transformed_blob with the right shape.
    if (crop_size) {
      transformed_blob->Reshape(input_num, input_channels,
                                crop_size, crop_size);
    } else {
      transformed_blob->Reshape(input_num, input_channels,
                                input_height, input_width);
    }
  }

  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int size = transformed_blob->count();

  CHECK_LE(input_num, num);
  CHECK_EQ(input_channels, channels);
  CHECK_GE(input_height, height);
  CHECK_GE(input_width, width);

  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {
      h_off = Rand(input_height - crop_size + 1);
      w_off = Rand(input_width - crop_size + 1);
    } else {
      h_off = (input_height - crop_size) / 2;
      w_off = (input_width - crop_size) / 2;
    }
  } else {
    CHECK_EQ(input_height, height);
    CHECK_EQ(input_width, width);
  }

  // If a mean file is given, subtract it from input_data in place
  Dtype* input_data = input_blob->mutable_cpu_data();
  if (has_mean_file) {
    CHECK_EQ(input_channels, data_mean_.channels());
    CHECK_EQ(input_height, data_mean_.height());
    CHECK_EQ(input_width, data_mean_.width());
    for (int n = 0; n < input_num; ++n) {
      int offset = input_blob->offset(n);
      /*
         template <typename Dtype>
         void caffe_sub(const int N, const Dtype* a, const Dtype* b, Dtype* y);
         caffe_sub, declared in math_functions, computes the element-wise
         difference y = a - b. Here:
         input_data (starting at offset) =
             input_data (starting at offset) - data_mean_
      */
      caffe_sub(data_mean_.count(), input_data + offset,
            data_mean_.cpu_data(), input_data + offset);
    }
  }

  // If mean values are given, subtract them from input_data in place
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<
     "Specify either 1 mean_value or as many as channels: " << input_channels;
    if (mean_values_.size() == 1) {
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);
    } else {
      for (int n = 0; n < input_num; ++n) {
        for (int c = 0; c < input_channels; ++c) {
          int offset = input_blob->offset(n, c);
          // Add -mean_values_[c] to every element starting at
          // input_data + offset
          caffe_add_scalar(input_height * input_width, -(mean_values_[c]),
            input_data + offset);
        }
      }
    }
  }

  // Copy the (possibly cropped and mirrored) data into transformed_blob;
  // this runs whether or not a mean was subtracted above
  Dtype* transformed_data = transformed_blob->mutable_cpu_data();

  for (int n = 0; n < input_num; ++n) {
    int top_index_n = n * channels;
    int data_index_n = n * channels;
    for (int c = 0; c < channels; ++c) {
      int top_index_c = (top_index_n + c) * height;
      int data_index_c = (data_index_n + c) * input_height + h_off;
      for (int h = 0; h < height; ++h) {
        int top_index_h = (top_index_c + h) * width;
        int data_index_h = (data_index_c + h) * input_width + w_off;
        if (do_mirror) {
          int top_index_w = top_index_h + width - 1;
          for (int w = 0; w < width; ++w) {
            transformed_data[top_index_w-w] = input_data[data_index_h + w];
          }
        } else {
          for (int w = 0; w < width; ++w) {
            transformed_data[top_index_h + w] = input_data[data_index_h + w];
          }
        }
      }
    }
  }
  if (scale != Dtype(1)) {
    DLOG(INFO) << "Scale: " << scale;
    caffe_scal(size, scale, transformed_data);
  }
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    // With OpenCV, decode to a cv::Mat first and infer the shape from it
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
      // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
    } else {
      cv_img = DecodeDatumToCVMatNative(datum);
    }
    // InferBlobShape using the cv::image.
    return InferBlobShape(cv_img);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }

  // Otherwise read the shape directly from the datum
  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();
  // Check dimensions.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);
  // Build BlobShape.
  vector<int> shape(4);
  shape[0] = 1;
  shape[1] = datum_channels;
  shape[2] = (crop_size)? crop_size: datum_height;
  shape[3] = (crop_size)? crop_size: datum_width;
  return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<Datum> & datum_vector) {
  const int num = datum_vector.size();
  CHECK_GT(num, 0) << "There is no datum to in the vector";
  // Use first datum in the vector to InferBlobShape.
  vector<int> shape = InferBlobShape(datum_vector[0]);
  // Adjust num to the size of the vector.
  shape[0] = num;
  return shape;
}

#ifdef USE_OPENCV
// When built with OpenCV:
// infer the shape from the information stored in the cv::Mat
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;
  // Check dimensions.
  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);
  // Build BlobShape.
  vector<int> shape(4);
  shape[0] = 1;
  shape[1] = img_channels;
  shape[2] = (crop_size)? crop_size: img_height;
  shape[3] = (crop_size)? crop_size: img_width;
  return shape;
}

template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<cv::Mat> & mat_vector) {
  const int num = mat_vector.size();
  CHECK_GT(num, 0) << "There is no cv_img to in the vector";
  // Use first cv_img in the vector to InferBlobShape.
  vector<int> shape = InferBlobShape(mat_vector[0]);
  // Adjust num to the size of the vector.
  shape[0] = num;
  return shape;
}
#endif  // USE_OPENCV

// Initialize the random number generator
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  // A generator is only needed if mirroring is enabled, or if we are in the
  // training phase and cropping is enabled
  const bool needs_rand = param_.mirror() ||
      (phase_ == TRAIN && param_.crop_size());
  if (needs_rand) {
    // Obtain a seed (from the entropy pool or from the time)
    const unsigned int rng_seed = caffe_rng_rand();
    // Create the random number generator with that seed
    rng_.reset(new Caffe::RNG(rng_seed));
  } else {
    rng_.reset();  // otherwise leave the generator empty
  }
}

// Return a random integer in [0, n-1]
template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {
  CHECK(rng_);
  CHECK_GT(n, 0);
  caffe::rng_t* rng =
      static_cast<caffe::rng_t*>(rng_->generator());
  return ((*rng)() % n);
}

INSTANTIATE_CLASS(DataTransformer);
/*
The class instantiation macro was covered earlier; it is repeated here:
#define INSTANTIATE_CLASS(classname) \
  char gInstantiationGuard##classname; \
  template class classname<float>; \
  template class classname<double>
*/

}  // namespace caffe
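The crop offsets and the mirrored output index are the core of every Transform overload above. The following is a minimal sketch of the same index arithmetic on a single-channel, row-major image, assuming plain std::vector<int> storage. CenterOffset and CropAndMirror are hypothetical names, not Caffe functions, and the random training-time offset is left to the caller.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper: the center-crop offset used outside the TRAIN phase.
int CenterOffset(int full, int crop) {
  assert(full >= crop);
  return (full - crop) / 2;
}

// Hypothetical helper: crops a crop x crop window starting at (h_off, w_off)
// from a single-channel row-major image and optionally mirrors it.
std::vector<int> CropAndMirror(const std::vector<int>& src, int src_h,
                               int src_w, int crop, int h_off, int w_off,
                               bool mirror) {
  assert(static_cast<int>(src.size()) == src_h * src_w);
  assert(h_off >= 0 && h_off + crop <= src_h);
  assert(w_off >= 0 && w_off + crop <= src_w);
  std::vector<int> dst(crop * crop);
  for (int h = 0; h < crop; ++h) {
    for (int w = 0; w < crop; ++w) {
      // Source index: shifted by the crop offsets.
      const int data_index = (h_off + h) * src_w + w_off + w;
      // Destination index: reversed along the width axis when mirroring.
      const int top_index =
          mirror ? h * crop + (crop - 1 - w) : h * crop + w;
      dst[top_index] = src[data_index];
    }
  }
  return dst;
}
```

For a 4x4 image holding the values 0 to 15 and crop 2, the center offset is 1, the plain crop is {5, 6, 9, 10}, and the mirrored crop is {6, 5, 10, 9}. Only the write position changes when mirroring; the read position is the same in both cases, which is also why the mean index in the Caffe code is computed from the unmirrored source position.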
3. Classes Related to DataTransformer
(1) Introduction to io
First, the meaning of each function declared in io:
#ifndef CAFFE_UTIL_IO_H_
#define CAFFE_UTIL_IO_H_

#include <unistd.h>
#include <string>

#include "google/protobuf/message.h"

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

using ::google::protobuf::Message;

// Inline function: create a temporary file and return its name
inline void MakeTempFilename(string* temp_filename) {
  temp_filename->clear();
  *temp_filename = "/tmp/caffe_test.XXXXXX";
  char* temp_filename_cstr = new char[temp_filename->size() + 1];
  // NOLINT_NEXT_LINE(runtime/printf)
  strcpy(temp_filename_cstr, temp_filename->c_str());
  int fd = mkstemp(temp_filename_cstr);
  CHECK_GE(fd, 0) << "Failed to open a temporary file at: " << *temp_filename;
  close(fd);
  *temp_filename = temp_filename_cstr;
  delete[] temp_filename_cstr;
}

// Inline function: create a temporary directory and return its name
inline void MakeTempDir(string* temp_dirname) {
  temp_dirname->clear();
  *temp_dirname = "/tmp/caffe_test.XXXXXX";
  char* temp_dirname_cstr = new char[temp_dirname->size() + 1];
  // NOLINT_NEXT_LINE(runtime/printf)
  strcpy(temp_dirname_cstr, temp_dirname->c_str());
  char* mkdtemp_result = mkdtemp(temp_dirname_cstr);
  CHECK(mkdtemp_result != NULL)
      << "Failed to create a temporary directory at: " << *temp_dirname;
  *temp_dirname = temp_dirname_cstr;
  delete[] temp_dirname_cstr;
}

// Read a proto message from a text-format file
bool ReadProtoFromTextFile(const char* filename, Message* proto);

// Same, taking the file name as a string
inline bool ReadProtoFromTextFile(const string& filename, Message* proto) {
  return ReadProtoFromTextFile(filename.c_str(), proto);
}

// Read a proto message from a text file and abort on failure
inline void ReadProtoFromTextFileOrDie(const char* filename, Message* proto) {
  CHECK(ReadProtoFromTextFile(filename, proto));
}

// Same, taking the file name as a string
inline void ReadProtoFromTextFileOrDie(const string& filename, Message* proto) {
  ReadProtoFromTextFileOrDie(filename.c_str(), proto);
}

// Write a proto message to a text file
void WriteProtoToTextFile(const Message& proto, const char* filename);
inline void WriteProtoToTextFile(const Message& proto, const string& filename) {
  WriteProtoToTextFile(proto, filename.c_str());
}

// Read a proto message from a binary file
bool ReadProtoFromBinaryFile(const char* filename, Message* proto);

// Same, taking the file name as a string
inline bool ReadProtoFromBinaryFile(const string& filename, Message* proto) {
  return ReadProtoFromBinaryFile(filename.c_str(), proto);
}

// Read a proto message from a binary file and abort on failure
inline void ReadProtoFromBinaryFileOrDie(const char* filename, Message* proto) {
  CHECK(ReadProtoFromBinaryFile(filename, proto));
}

// Same, taking the file name as a string
inline void ReadProtoFromBinaryFileOrDie(const string& filename,
                                         Message* proto) {
  ReadProtoFromBinaryFileOrDie(filename.c_str(), proto);
}

// Write a proto message to a binary file
void WriteProtoToBinaryFile(const Message& proto, const char* filename);
// Inline function: same, taking the file name as a string
inline void WriteProtoToBinaryFile(
    const Message& proto, const string& filename) {
  WriteProtoToBinaryFile(proto, filename.c_str());
}

// Read the raw bytes of a file into a Datum
bool ReadFileToDatum(const string& filename, const int label, Datum* datum);
// Inline function: same, with the label set to -1
inline bool ReadFileToDatum(const string& filename, Datum* datum) {
  return ReadFileToDatum(filename, -1, datum);
}

// Read an image file into a Datum
bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum);
// Inline function: read an image (color or grayscale) into a Datum,
// resized to the given size
inline bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color, Datum* datum) {
  return ReadImageToDatum(filename, label, height, width, is_color,
                          "", datum);
}
// Inline function: read a color image into a Datum, resized to the given size
inline bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, Datum* datum) {
  return ReadImageToDatum(filename, label, height, width, true, datum);
}
// Inline function: read an image (color or grayscale) into a Datum,
// keeping the original image size
inline bool ReadImageToDatum(const string& filename, const int label,
    const bool is_color, Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, is_color, datum);
}
// Inline function: read a color image into a Datum, keeping the original size
inline bool ReadImageToDatum(const string& filename, const int label,
    Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, true, datum);
}
// Inline function: read a color image into a Datum, keeping the original
// size and using the given encoding
inline bool ReadImageToDatum(const string& filename, const int label,
    const std::string & encoding, Datum* datum) {
  return ReadImageToDatum(filename, label, 0, 0, true, encoding, datum);
}

// Decode an encoded Datum, keeping the image's native channel layout
bool DecodeDatumNative(Datum* datum);
// Decode an encoded Datum as color (is_color == true) or grayscale
bool DecodeDatum(Datum* datum, bool is_color);

#ifdef USE_OPENCV
// Read an image into a cv::Mat with the given size and color mode
cv::Mat ReadImageToCVMat(const string& filename,
    const int height, const int width, const bool is_color);
// Read an image into a cv::Mat with the given size
cv::Mat ReadImageToCVMat(const string& filename,
    const int height, const int width);
// Read an image into a cv::Mat with the given color mode
cv::Mat ReadImageToCVMat(const string& filename,
    const bool is_color);
// Read an image into a cv::Mat
cv::Mat ReadImageToCVMat(const string& filename);
// Decode a Datum into a cv::Mat, keeping the native channel layout
cv::Mat DecodeDatumToCVMatNative(const Datum& datum);
// Decode a Datum into a cv::Mat as color or grayscale
cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);
// Convert a cv::Mat into a Datum
void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);
#endif  // USE_OPENCV

}  // namespace caffe

#endif   // CAFFE_UTIL_IO_H_
Next, the annotated implementation of io:
#include <fcntl.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>
#include <opencv2/core/core.hpp>
#ifdef USE_OPENCV
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/highgui/highgui_c.h>
#include <opencv2/imgproc/imgproc.hpp>
#endif  // USE_OPENCV
#include <stdint.h>

#include <algorithm>
#include <fstream>  // NOLINT(readability/streams)
#include <string>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/io.hpp"

const int kProtoReadBytesLimit = INT_MAX;  // Max size of 2 GB minus 1 byte.

namespace caffe {

using google::protobuf::io::FileInputStream;
using google::protobuf::io::FileOutputStream;
using google::protobuf::io::ZeroCopyInputStream;
using google::protobuf::io::CodedInputStream;
using google::protobuf::io::ZeroCopyOutputStream;
using google::protobuf::io::CodedOutputStream;
using google::protobuf::Message;

// Read a proto message from a text-format file
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  FileInputStream* input = new FileInputStream(fd);
  // Note how protobuf's TextFormat is used for parsing
  bool success = google::protobuf::TextFormat::Parse(input, proto);
  delete input;
  close(fd);
  return success;
}

// Write a proto message to a text file
void WriteProtoToTextFile(const Message& proto, const char* filename) {
  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
  FileOutputStream* output = new FileOutputStream(fd);
  // Note how protobuf's TextFormat is used for printing
  CHECK(google::protobuf::TextFormat::Print(proto, output));
  delete output;
  close(fd);
}

// Read a proto message from a binary file
bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
  int fd = open(filename, O_RDONLY);
  CHECK_NE(fd, -1) << "File not found: " << filename;
  ZeroCopyInputStream* raw_input = new FileInputStream(fd);
  // Decoding stream: google::protobuf::io::CodedInputStream
  CodedInputStream* coded_input = new CodedInputStream(raw_input);
  coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);

  bool success = proto->ParseFromCodedStream(coded_input);

  delete coded_input;
  delete raw_input;
  close(fd);
  return success;
}

// Write a proto message to a binary file
void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
  fstream output(filename, ios::out | ios::trunc | ios::binary);
  CHECK(proto.SerializeToOstream(&output));
}

#ifdef USE_OPENCV
// Read an image into a cv::Mat with the given size and color mode
cv::Mat ReadImageToCVMat(const string& filename,
    const int height, const int width, const bool is_color) {
  cv::Mat cv_img;
  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
    CV_LOAD_IMAGE_GRAYSCALE);
  cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);
  if (!cv_img_origin.data) {
    LOG(ERROR) << "Could not open or find file " << filename;
    return cv_img_origin;
  }
  if (height > 0 && width > 0) {
    cv::resize(cv_img_origin, cv_img, cv::Size(width, height));
  } else {
    cv_img = cv_img_origin;
  }
  return cv_img;
}

cv::Mat ReadImageToCVMat(const string& filename,
    const int height, const int width) {
  return ReadImageToCVMat(filename, height, width, true);
}

cv::Mat ReadImageToCVMat(const string& filename,
    const bool is_color) {
  return ReadImageToCVMat(filename, 0, 0, is_color);
}

cv::Mat ReadImageToCVMat(const string& filename) {
  return ReadImageToCVMat(filename, 0, 0, true);
}

// Do the file extension and encoding match?
// (jpg and jpeg are meant to be treated as the same format)
// Note: as written, fn.substr(p) keeps the leading '.', so ext is for
// example ".jpg" and only equals en if en also carries the dot.
static bool matchExt(const std::string & fn,
                     std::string en) {
  size_t p = fn.rfind('.');
  std::string ext = p != fn.npos ? fn.substr(p) : fn;
  std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
  std::transform(en.begin(), en.end(), en.begin(), ::tolower);
  if ( ext == en )
    return true;
  if ( en == "jpg" && ext == "jpeg" )
    return true;
  return false;
}

// Read an image file into a Datum
bool ReadImageToDatum(const string& filename, const int label,
    const int height, const int width, const bool is_color,
    const std::string & encoding, Datum* datum) {
  cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);
  if (cv_img.data) {
    if (encoding.size()) {
      if ( (cv_img.channels() == 3) == is_color && !height && !width &&
          matchExt(filename, encoding) )
        return ReadFileToDatum(filename, label, datum);
      std::vector<uchar> buf;
      // Encode the image into the requested format
      cv::imencode("."+encoding, cv_img, buf);
      datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
                      buf.size()));
      // Label of the data
      datum->set_label(label);
      // Mark the datum as encoded
      datum->set_encoded(true);
      return true;
    }
    CVMatToDatum(cv_img, datum);
    datum->set_label(label);
    return true;
  } else {
    return false;
  }
}
#endif  // USE_OPENCV

// Read the raw bytes of a file into a Datum
bool ReadFileToDatum(const string& filename, const int label,
    Datum* datum) {
  std::streampos size;

  fstream file(filename.c_str(), ios::in|ios::binary|ios::ate);
  if (file.is_open()) {
    size = file.tellg();
    std::string buffer(size, ' ');
    file.seekg(0, ios::beg);
    file.read(&buffer[0], size);
    file.close();
    datum->set_data(buffer);
    datum->set_label(label);
    datum->set_encoded(true);
    return true;
  } else {
    return false;
  }
}

#ifdef USE_OPENCV
// Decode an encoded Datum into a cv::Mat, keeping the native channel layout
cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
  cv::Mat cv_img;
  CHECK(datum.encoded()) << "Datum not encoded";
  const string& data = datum.data();
  std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
  cv_img = cv::imdecode(vec_data, -1);  // flag = -1
  if (!cv_img.data) {
    LOG(ERROR) << "Could not decode datum ";
  }
  return cv_img;
}

// Decode an encoded Datum into a cv::Mat as color or grayscale
cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
  cv::Mat cv_img;
  CHECK(datum.encoded()) << "Datum not encoded";
  const string& data = datum.data();
  std::vector<char> vec_data(data.c_str(), data.c_str() + data.size());
  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :
    CV_LOAD_IMAGE_GRAYSCALE);
  cv_img = cv::imdecode(vec_data, cv_read_flag);  // flag chosen by the caller
  if (!cv_img.data) {
    LOG(ERROR) << "Could not decode datum ";
  }
  return cv_img;
}

// If Datum is encoded will decoded using DecodeDatumToCVMat and CVMatToDatum
// If Datum is not encoded will do nothing
bool DecodeDatumNative(Datum* datum) {
  if (datum->encoded()) {
    cv::Mat cv_img = DecodeDatumToCVMatNative((*datum));
    CVMatToDatum(cv_img, datum);
    return true;
  } else {
    return false;
  }
}

// Decode a Datum as color or grayscale
bool DecodeDatum(Datum* datum, bool is_color) {
  if (datum->encoded()) {
    cv::Mat cv_img = DecodeDatumToCVMat((*datum), is_color);
    CVMatToDatum(cv_img, datum);
    return true;
  } else {
    return false;
  }
}

// Convert a cv::Mat into a Datum
void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
  datum->set_channels(cv_img.channels());
  datum->set_height(cv_img.rows);
  datum->set_width(cv_img.cols);
  datum->clear_data();
  datum->clear_float_data();
  datum->set_encoded(false);
  int datum_channels = datum->channels();
  int datum_height = datum->height();
  int datum_width = datum->width();
  int datum_size = datum_channels * datum_height * datum_width;
  std::string buffer(datum_size, ' ');
  for (int h = 0; h < datum_height; ++h) {
    const uchar* ptr = cv_img.ptr<uchar>(h);
    int img_index = 0;
    for (int w = 0; w < datum_width; ++w) {
      for (int c = 0; c < datum_channels; ++c) {
        int datum_index = (c * datum_height + h) * datum_width + w;
        buffer[datum_index] = static_cast<char>(ptr[img_index++]);
      }
    }
  }
  datum->set_data(buffer);
}
#endif  // USE_OPENCV

}  // namespace caffe
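CVMatToDatum above converts OpenCV's interleaved pixel layout (height, width, channel) into the planar layout (channel, height, width) that a Datum stores. The sketch below shows the same index mapping on a plain byte vector. It assumes the rows are stored contiguously without padding, whereas the Caffe code uses cv_img.ptr<uchar>(h) to fetch each row; InterleavedToPlanar is a hypothetical name, not a Caffe function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: converts interleaved HWC bytes into planar CHW bytes
// using the same datum_index formula as CVMatToDatum.
std::vector<unsigned char> InterleavedToPlanar(
    const std::vector<unsigned char>& hwc, int height, int width,
    int channels) {
  assert(hwc.size() ==
         static_cast<std::size_t>(height) * width * channels);
  std::vector<unsigned char> chw(hwc.size());
  std::size_t img_index = 0;  // walks the interleaved input sequentially
  for (int h = 0; h < height; ++h) {
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < channels; ++c) {
        const int datum_index = (c * height + h) * width + w;
        chw[datum_index] = hwc[img_index++];
      }
    }
  }
  return chw;
}
```

For a 1x2 image with three channels whose pixels are (1, 2, 3) and (4, 5, 6), the interleaved input {1, 2, 3, 4, 5, 6} becomes {1, 4, 2, 5, 3, 6}: all values of the first channel, then the second, then the third.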
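4. Summary
In summary, the work DataTransformer does is essentially cropping the data, optionally mirroring it, subtracting the mean, and scaling.
It also infers the blob shape from the data. The io part was covered as well: it contains helpers for creating temporary files and temporary directories, and for reading proto data from text or binary files and writing proto data to text or binary files.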
References:
[1] For the meaning of the flag argument of cv::imencode and cv::imdecode, see:
http://docs.opencv.org/2.4/modules/highgui/doc/reading_and_writing_images_and_video.html