Caffe convert_imageset 粗略解析

来源:互联网 发布:程序员接活的网站 编辑:程序博客网 时间:2024/05/17 04:12

需要的基础知识:OpenCV(建议去看官网的图文教程), LevelDB(http://dblab.cs.toronto.edu/courses/443/2014/tutorials/leveldb.html 这个是我学习的教程)


今天在看 Caffe 的代码,发现它用到的数据都是 LevelDB 格式(下文的代码也支持 LMDB)。如果我们手里是类似 ImageNet 那样的"图片 + 标签"数据,就需要先把它们转换成 LevelDB 格式。Caffe 的代码中给出了例子,即 create_imagenet.sh,而这个 shell 脚本主要就是调用了 build/tools/convert_imageset.bin。

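举个例子(下面是与本文代码版本参数顺序一致的示意调用):

    GLOG_logtostderr=1 build/tools/convert_imageset.bin \
        $TRAIN_DATA_ROOT \
        $DATA/train.txt \
        ilsvrc12_train_leveldb 1 leveldb 256 256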



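其中,TRAIN_DATA_ROOT 是图片所在的根目录(末尾要带 /,因为代码是把它和文件名直接拼接起来的);$DATA/train.txt 每行存放一张图片的相对路径和它的 label。注意,label 应当从 0 开始连续编号,例如有 20 个类就是 0 到 19,而不要用 0-10、12-20 这样的编号,否则会出错。ilsvrc12_train_leveldb 是输出数据库的名字;接下来的 1 表示打乱顺序读取(0 表示按顺序读取)。在本文的代码版本中,其后还可以依次给出数据库类型(leveldb 或 lmdb,默认 leveldb)以及图片缩放后的高和宽。高和宽应为整数,只有两者都大于 0 时才会缩放,为 0 表示不缩放。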

好了,我们进入代码来看。

// This program converts a set of images to a leveldb by storing them as Datum// proto buffers.// Usage://   convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME RANDOM_SHUFFLE[0 or 1]//                     [resize_height] [resize_width]// where ROOTFOLDER is the root folder that holds all the images, and LISTFILE// should be a list of files as well as their labels, in the format as//   subfolder1/file1.JPEG 7//   ....// if RANDOM_SHUFFLE is 1, a random shuffle will be carried out before we// process the file lines.// Optional flag -g indicates the images should be read as// single-channel grayscale. If omitted, grayscale images will be// converted to color.#include <glog/logging.h>#include <leveldb/db.h>#include <leveldb/write_batch.h>#include <lmdb.h>#include <sys/stat.h>#include <algorithm>#include <fstream>  // NOLINT(readability/streams)#include <string>#include <utility>#include <vector>#include "caffe/proto/caffe.pb.h"#include "caffe/util/io.hpp"#include "caffe/util/rng.hpp"using namespace caffe;  // NOLINT(build/namespaces)using std::pair;using std::string;int main(int argc, char** argv) {  ::google::InitGoogleLogging(argv[0]);  if (argc < 4 || argc > 9) {    printf("Convert a set of images to the leveldb format used\n"        "as input for Caffe.\n"        "Usage:\n"        "    convert_imageset [-g] ROOTFOLDER/ LISTFILE DB_NAME"        " RANDOM_SHUFFLE_DATA[0 or 1] DB_BACKEND[leveldb or lmdb]"        " [resize_height] [resize_width]\n"        "The ImageNet dataset for the training demo is at\n"        "    http://www.image-net.org/download-images\n");    return 1;  }  // Test whether argv[1] == "-g"  bool is_color= !(string("-g") == string(argv[1]));  int  arg_offset = (is_color ? 0 : 1);  std::ifstream infile(argv[arg_offset+2]);  std::vector<std::pair<string, int> > lines;  string filename;  int label;
</pre><pre name="code" class="cpp">  //读入图片名和标签  while (infile >> filename >> label) {    lines.push_back(std::make_pair(filename, label));  }  if (argc >= (arg_offset+5) && argv[arg_offset+4][0] == '1') {    // randomly shuffle data
    //乱序
    LOG(INFO) << "Shuffling data";    shuffle(lines.begin(), lines.end());  }  LOG(INFO) << "A total of " << lines.size() << " images.";  string db_backend = "leveldb";  if (argc >= (arg_offset+6)) {    db_backend = string(argv[arg_offset+5]);    if (!(db_backend == "leveldb") && !(db_backend == "lmdb")) {      LOG(FATAL) << "Unknown db backend " << db_backend;    }  }  int resize_height = 0;  int resize_width = 0;  if (argc >= (arg_offset+7)) {    resize_height = atoi(argv[arg_offset+6]);  }  if (argc >= (arg_offset+8)) {    resize_width = atoi(argv[arg_offset+7]);  }  // Open new db  // lmdb  MDB_env *mdb_env;  MDB_dbi mdb_dbi;  MDB_val mdb_key, mdb_data;  MDB_txn *mdb_txn;  // leveldb  leveldb::DB* db;  leveldb::Options options;  options.error_if_exists = true;  options.create_if_missing = true;  options.write_buffer_size = 268435456;  leveldb::WriteBatch* batch = NULL;  // Open db  if (db_backend == "leveldb") {  // leveldb    LOG(INFO) << "Opening leveldb " << argv[arg_offset+3];    leveldb::Status status = leveldb::DB::Open(        options, argv[arg_offset+3], &db);    CHECK(status.ok()) << "Failed to open leveldb " << argv[arg_offset+3];    batch = new leveldb::WriteBatch();  } else if (db_backend == "lmdb") {  // lmdb    LOG(INFO) << "Opening lmdb " << argv[arg_offset+3];    CHECK_EQ(mkdir(argv[arg_offset+3], 0744), 0)        << "mkdir " << argv[arg_offset+3] << "failed";    CHECK_EQ(mdb_env_create(&mdb_env), MDB_SUCCESS) << "mdb_env_create failed";    CHECK_EQ(mdb_env_set_mapsize(mdb_env, 1099511627776), MDB_SUCCESS)  // 1TB        << "mdb_env_set_mapsize failed";    CHECK_EQ(mdb_env_open(mdb_env, argv[3], 0, 0664), MDB_SUCCESS)        << "mdb_env_open failed";    CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)        << "mdb_txn_begin failed";    CHECK_EQ(mdb_open(mdb_txn, NULL, 0, &mdb_dbi), MDB_SUCCESS)        << "mdb_open failed";  } else {    LOG(FATAL) << "Unknown db backend " << db_backend;  }  // Storing to db  string 
root_folder(argv[arg_offset+1]);  Datum datum;  int count = 0;  const int kMaxKeyLength = 256;  char key_cstr[kMaxKeyLength];  int data_size;  bool data_size_initialized = false;  for (int line_id = 0; line_id < lines.size(); ++line_id) {
    //读图片进入datum
    if (!ReadImageToDatum(root_folder + lines[line_id].first,        lines[line_id].second, resize_height, resize_width, is_color, &datum)) {      continue;    }    if (!data_size_initialized) {      data_size = datum.channels() * datum.height() * datum.width();      data_size_initialized = true;    } else {      const string& data = datum.data();      CHECK_EQ(data.size(), data_size) << "Incorrect data field size "          << data.size();    }    // sequential    snprintf(key_cstr, kMaxKeyLength, "%08d_%s", line_id,        lines[line_id].first.c_str());    string value;    datum.SerializeToString(&value);    string keystr(key_cstr);    // Put in db    if (db_backend == "leveldb") {  // leveldb      batch->Put(keystr, value);    } else if (db_backend == "lmdb") {  // lmdb      mdb_data.mv_size = value.size();      mdb_data.mv_data = reinterpret_cast<void*>(&value[0]);      mdb_key.mv_size = keystr.size();      mdb_key.mv_data = reinterpret_cast<void*>(&keystr[0]);      CHECK_EQ(mdb_put(mdb_txn, mdb_dbi, &mdb_key, &mdb_data, 0), MDB_SUCCESS)          << "mdb_put failed";    } else {      LOG(FATAL) << "Unknown db backend " << db_backend;    }    if (++count % 1000 == 0) {      // Commit txn      if (db_backend == "leveldb") {  // leveldb        db->Write(leveldb::WriteOptions(), batch);        delete batch;        batch = new leveldb::WriteBatch();      } else if (db_backend == "lmdb") {  // lmdb        CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS)            << "mdb_txn_commit failed";        CHECK_EQ(mdb_txn_begin(mdb_env, NULL, 0, &mdb_txn), MDB_SUCCESS)            << "mdb_txn_begin failed";      } else {        LOG(FATAL) << "Unknown db backend " << db_backend;      }      LOG(ERROR) << "Processed " << count << " files.";    }  }  // write the last batch  if (count % 1000 != 0) {    if (db_backend == "leveldb") {  // leveldb      db->Write(leveldb::WriteOptions(), batch);      delete batch;      delete db;    } else if (db_backend == "lmdb") {  // lmdb      
CHECK_EQ(mdb_txn_commit(mdb_txn), MDB_SUCCESS) << "mdb_txn_commit failed";      mdb_close(mdb_env, mdb_dbi);      mdb_env_close(mdb_env);    } else {      LOG(FATAL) << "Unknown db backend " << db_backend;    }    LOG(ERROR) << "Processed " << count << " files.";  }  return 0;}





这里我们主要看一下 ReadImageToDatum 这个函数。如果大家熟悉 OpenCV,下面这部分代码相当好懂。

bool ReadImageToDatum(const string& filename, const int label,    const int height, const int width, const bool is_color, Datum* datum) {  cv::Mat cv_img;  int cv_read_flag = (is_color ? CV_LOAD_IMAGE_COLOR :    CV_LOAD_IMAGE_GRAYSCALE);
  //读图片  if (height > 0 && width > 0) {    cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);    cv::resize(cv_img_origin, cv_img, cv::Size(width, height));  } else {    cv_img = cv::imread(filename, cv_read_flag);  }  if (!cv_img.data) {    LOG(ERROR) << "Could not open or find file " << filename;    return false;  }
</pre><pre name="code" class="cpp">  //设置参数  int num_channels = (is_color ? 3 : 1);  datum->set_channels(num_channels);  datum->set_height(cv_img.rows);  datum->set_width(cv_img.cols);  datum->set_label(label);  datum->clear_data();  datum->clear_float_data();  string* datum_string = datum->mutable_data();
 <span style="font-family: Arial, Helvetica, sans-serif;">  //数据写入datum中</span>
  if (is_color) {    for (int c = 0; c < num_channels; ++c) {      for (int h = 0; h < cv_img.rows; ++h) {        for (int w = 0; w < cv_img.cols; ++w) {          datum_string->push_back(            static_cast<char>(cv_img.at<cv::Vec3b>(h, w)[c]));        }      }    }  } else {  // Faster than repeatedly testing is_color for each pixel w/i loop    for (int h = 0; h < cv_img.rows; ++h) {      for (int w = 0; w < cv_img.cols; ++w) {        datum_string->push_back(          static_cast<char>(cv_img.at<uchar>(h, w)));        }      }  }  return true;}



0 0