DeepLearning(基于caffe)实战项目(1)--mnist_convert函数分析

来源:互联网 发布:环境污染测试软件 编辑:程序博客网 时间:2024/06/05 19:06

        搞了这么长时间DeepLearning,打算用回忆的方式,进行一下总结。我们知道,caffe只能识别leveldb或者lmdb格式的文件,所以一切从数据转换开始。若想自己写转换函数程序(matlab/Python),自然而然需要读懂caffe中examples里转换的函数。

下面是mnist_convert.cpp的程序:

/***********************************************************************************************************************************

TIPS:caffe为什么采用lmdb或者leveldb,而不是直接读取原始数据呢?

一方面,数据类型五花八门,种类繁多,不可能用一套代码实现所有类型的输入数据要求,转换为统一格式可以简化数据读取层的实现;另一方面,使用leveldb或者lmdb可以提高磁盘IO利用率。

/************************************************************************************************************************************

引用相应的文件和命名空间:

////该程序将mnist数据集转换为caffe需要的格式(lmdb)//用法:mnist_convert_data input_folder output_db_file#include <gflags/gflags.h>     //gflags命令行参数解析的头文件#include <glog/logging.h>      //记录程序日志的glog头文件#include <google/protobuf/text_format.h>      //解析proto类型文件中,解析prototxt类型的头文件#if defined(USE_LEVELDB) && defined(USE_LMDB)      #include <leveldb/db.h>         //引入leveldb类型数据头文件#include <leveldb/write_batch.h>     //引入leveldb类型数据写入头文件#include <lmdb.h>#endif#if defined(_MSC_VER)#include <direct.h>#define mkdir(X, Y) _mkdir(X)#endif#include <stdint.h>#include <sys/stat.h>#include <fstream>  //NOLINT(readability/streams)#include <string>#include "boost/scoped_ptr.hpp"#include "caffe/proto/caffe.pb.h"      //解析caffe中proto类型文件的头文件#include "caffe/util/db.hpp"#include "caffe/util/format.hpp"#if defined(USE_LEVELDB) && defined(USE_LMDB)using namespace caffe;  //NOLINT(build/namespaces)using boost::scoped_ptr;using std::string;

定义backend(程序变)

//在程序调用时,铜鼓--backend=${BACKEND}来给变量命名
DEFINE_string(backend, "lmdb", "The backend for storing the result");    //GFLAGS工具定义明星行选项backend,默认是lmdb

大端字节存储的二进制文件与小端字节存储的二进制文件转换

/****************************************************************************************************************************************************************************************

TIPS:为何需要两种二进制文件转换?

大小端字节的计算机存储的二进制文件格式不同,大端计算机无法读取小端计算机存储的二进制文件(小端一样)所以需要两种文件的转换。

/*****************************************************************************************************************************************************************************************

uint32_t swap_endian(uint32_t val) {    val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF);    return (val << 16) | (val >> 16);}

convert_dataset函数(核心代码)

void convert_dataset(const char* image_filename, const char* label_filename,const char* db_path, const string& db_backend) {//打开(二进制)文件std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);//CHECK用于检测文件是否正常打开CHECK(image_file) << "Unable to open file " << image_filename;CHECK(label_file) << "Unable to open file " << label_filename;//根据mnist图像结构,定义长、宽、样本数、标签数//uint32_t是自定义数据类型,unsigned int 32是指每个int32整数占用4个字节uint32_t magic;uint32_t num_items;uint32_t num_labels;uint32_t rows;uint32_t cols;//读取图片数据结构//image的维度为4(magic,num_items,width,height)//label的维度为2(magic,num_labels)image_file.read(reinterpret_cast<char*>(&magic), 4);magic = swap_endian(magic);CHECK_EQ(magic, 2051) << "Incorrect image file magic.";label_file.read(reinterpret_cast<char*>(&magic), 4);magic = swap_endian(magic);CHECK_EQ(magic, 2049) << "Incorrect label file magic.";image_file.read(reinterpret_cast<char*>(&num_items), 4);num_items = swap_endian(num_items);label_file.read(reinterpret_cast<char*>(&num_labels), 4);num_labels = swap_endian(num_labels);CHECK_EQ(num_items, num_labels);image_file.read(reinterpret_cast<char*>(&rows), 4);rows = swap_endian(rows);image_file.read(reinterpret_cast<char*>(&cols), 4);cols = swap_endian(cols);//定义lmdb和leveldb类的变量MDB_env *mdb_env;  MDB_dbi mdb_dbi;  MDB_val mdb_key, mdb_data;  MDB_txn *mdb_txn;  leveldb::DB* db;leveldb::Options options;   options.error_if_exists = true;options.create_if_missing = true;options.write_buffer_size = 268435456; leveldb::WriteBatch* batch = NULL;//open the filesif (db_backend == "leveldb") {      // leveldb    LOG(INFO) << "Opening leveldb " << db_path;    leveldb::Status status = leveldb::DB::Open(options, db_path, &db);    CHECK(status.ok()) << "Failed to open leveldb " << db_path<< ". Is it already existing?";    batch = new leveldb::WriteBatch();//Storing to db     char label;    char* pixels = new char[rows * cols];    int count = 0;    string value;//define the detum    Datum datum;    datum.set_channels(1);    datum.set_height(rows);    datum.set_width(cols);    LOG(INFO) << "A total of " << num_items << " items.";    LOG(INFO) << "Rows: " << rows << " Cols: " << cols;//read the files and assign to "datum"    for (int item_id = 0; item_id < num_items; ++item_id)   {    image_file.read(pixels, rows * cols);    label_file.read(&label, 1);    datum.set_data(pixels, rows*cols);    datum.set_label(label);    string key_str = caffe::format_int(item_id, 8);    datum.SerializeToString(&value);    txn->Put(key_str, value);//write to the batch    if (++count % 1000 == 0) {      txn->Commit();}  }//write the last batch    if (count % 1000 != 0) {      txn->Commit();  }    LOG(INFO) << "Processed " << count << " files.";    delete[] pixels;    db->Close();}

main函数(主函数代码)

int main(int argc, char** argv) {#ifndef GFLAGS_GFLAGS_H_  namespace gflags = google;#endif  FLAGS_alsologtostderr = 1;      //获取--backend=${BACKEND}参数  gflags::SetUsageMessage("This script converts the MNIST dataset to\n"        "the lmdb/leveldb format used by Caffe to load data.\n"        "Usage:\n"        "convert_mnist_data [FLAGS] input_image_file input_label_file "        "output_db_file\n"        "The MNIST dataset could be downloaded at\n"        "http://yann.lecun.com/exdb/mnist/\n"        "You should gunzip them after downloading,"        "or directly use data/mnist/get_mnist.sh\n");  gflags::ParseCommandLineFlags(&argc, &argv, true);  const string& db_backend = FLAGS_backend;  if (argc != 4) {    gflags::ShowUsageWithFlagsRestrict(argv[0],"examples/mnist/convert_mnist_data");  } else {    google::InitGoogleLogging(argv[0]);    convert_dataset(argv[1], argv[2], argv[3], db_backend);  }  return 0;}#elseint main(int argc, char** argv) {  LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB.";}

阅读全文
1 0
原创粉丝点击