DeepLearning(基于caffe)实战项目(1)--mnist_convert函数分析
来源:互联网 发布:环境污染测试软件 编辑:程序博客网 时间:2024/06/05 19:06
搞了这么长时间DeepLearning,打算用回忆的方式,进行一下总结。我们知道,caffe只能识别leveldb或者lmdb格式的文件,所以一切从数据转换开始。若想自己写转换函数程序(matlab/Python),自然而然需要读懂caffe中examples里转换的函数。
下面是mnist_convert.cpp的程序:
/***********************************************************************************************************************************
TIPS:caffe为什么采用lmdb或者leveldb,而不是直接读取原始数据呢?
一方面,数据类型五花八门,种类繁多,不可能用一套代码实现所有类型的输入数据要求,转换为统一格式可以简化数据读取层的实现;另一方面,使用leveldb或者lmdb可以提高磁盘IO利用率。
/************************************************************************************************************************************
引用相应的文件和命名空间:
////该程序将mnist数据集转换为caffe需要的格式(lmdb)//用法:mnist_convert_data input_folder output_db_file#include <gflags/gflags.h> //gflags命令行参数解析的头文件#include <glog/logging.h> //记录程序日志的glog头文件#include <google/protobuf/text_format.h> //解析proto类型文件中,解析prototxt类型的头文件#if defined(USE_LEVELDB) && defined(USE_LMDB) #include <leveldb/db.h> //引入leveldb类型数据头文件#include <leveldb/write_batch.h> //引入leveldb类型数据写入头文件#include <lmdb.h>#endif#if defined(_MSC_VER)#include <direct.h>#define mkdir(X, Y) _mkdir(X)#endif#include <stdint.h>#include <sys/stat.h>#include <fstream> //NOLINT(readability/streams)#include <string>#include "boost/scoped_ptr.hpp"#include "caffe/proto/caffe.pb.h" //解析caffe中proto类型文件的头文件#include "caffe/util/db.hpp"#include "caffe/util/format.hpp"#if defined(USE_LEVELDB) && defined(USE_LMDB)using namespace caffe; //NOLINT(build/namespaces)using boost::scoped_ptr;using std::string;
定义backend(程序变量):
大端字节存储的二进制文件与小端字节存储的二进制文件转换:
/****************************************************************************************************************************************************************************************
TIPS:为何需要两种二进制文件转换?
大小端字节的计算机存储的二进制文件格式不同,大端计算机无法读取小端计算机存储的二进制文件(小端一样)所以需要两种文件的转换。
/*****************************************************************************************************************************************************************************************
uint32_t swap_endian(uint32_t val) { val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); return (val << 16) | (val >> 16);}
convert_dataset函数(核心代码):
void convert_dataset(const char* image_filename, const char* label_filename,const char* db_path, const string& db_backend) {//打开(二进制)文件std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);//CHECK用于检测文件是否正常打开CHECK(image_file) << "Unable to open file " << image_filename;CHECK(label_file) << "Unable to open file " << label_filename;//根据mnist图像结构,定义长、宽、样本数、标签数//uint32_t是自定义数据类型,unsigned int 32是指每个int32整数占用4个字节uint32_t magic;uint32_t num_items;uint32_t num_labels;uint32_t rows;uint32_t cols;//读取图片数据结构//image的维度为4(magic,num_items,width,height)//label的维度为2(magic,num_labels)image_file.read(reinterpret_cast<char*>(&magic), 4);magic = swap_endian(magic);CHECK_EQ(magic, 2051) << "Incorrect image file magic.";label_file.read(reinterpret_cast<char*>(&magic), 4);magic = swap_endian(magic);CHECK_EQ(magic, 2049) << "Incorrect label file magic.";image_file.read(reinterpret_cast<char*>(&num_items), 4);num_items = swap_endian(num_items);label_file.read(reinterpret_cast<char*>(&num_labels), 4);num_labels = swap_endian(num_labels);CHECK_EQ(num_items, num_labels);image_file.read(reinterpret_cast<char*>(&rows), 4);rows = swap_endian(rows);image_file.read(reinterpret_cast<char*>(&cols), 4);cols = swap_endian(cols);//定义lmdb和leveldb类的变量MDB_env *mdb_env; MDB_dbi mdb_dbi; MDB_val mdb_key, mdb_data; MDB_txn *mdb_txn; leveldb::DB* db;leveldb::Options options; options.error_if_exists = true;options.create_if_missing = true;options.write_buffer_size = 268435456; leveldb::WriteBatch* batch = NULL;//open the filesif (db_backend == "leveldb") { // leveldb LOG(INFO) << "Opening leveldb " << db_path; leveldb::Status status = leveldb::DB::Open(options, db_path, &db); CHECK(status.ok()) << "Failed to open leveldb " << db_path<< ". Is it already existing?"; batch = new leveldb::WriteBatch();//Storing to db char label; char* pixels = new char[rows * cols]; int count = 0; string value;//define the detum Datum datum; datum.set_channels(1); datum.set_height(rows); datum.set_width(cols); LOG(INFO) << "A total of " << num_items << " items."; LOG(INFO) << "Rows: " << rows << " Cols: " << cols;//read the files and assign to "datum" for (int item_id = 0; item_id < num_items; ++item_id) { image_file.read(pixels, rows * cols); label_file.read(&label, 1); datum.set_data(pixels, rows*cols); datum.set_label(label); string key_str = caffe::format_int(item_id, 8); datum.SerializeToString(&value); txn->Put(key_str, value);//write to the batch if (++count % 1000 == 0) { txn->Commit();} }//write the last batch if (count % 1000 != 0) { txn->Commit(); } LOG(INFO) << "Processed " << count << " files."; delete[] pixels; db->Close();}
main函数(主函数代码):
int main(int argc, char** argv) {#ifndef GFLAGS_GFLAGS_H_ namespace gflags = google;#endif FLAGS_alsologtostderr = 1; //获取--backend=${BACKEND}参数 gflags::SetUsageMessage("This script converts the MNIST dataset to\n" "the lmdb/leveldb format used by Caffe to load data.\n" "Usage:\n" "convert_mnist_data [FLAGS] input_image_file input_label_file " "output_db_file\n" "The MNIST dataset could be downloaded at\n" "http://yann.lecun.com/exdb/mnist/\n" "You should gunzip them after downloading," "or directly use data/mnist/get_mnist.sh\n"); gflags::ParseCommandLineFlags(&argc, &argv, true); const string& db_backend = FLAGS_backend; if (argc != 4) { gflags::ShowUsageWithFlagsRestrict(argv[0],"examples/mnist/convert_mnist_data"); } else { google::InitGoogleLogging(argv[0]); convert_dataset(argv[1], argv[2], argv[3], db_backend); } return 0;}#elseint main(int argc, char** argv) { LOG(FATAL) << "This example requires LevelDB and LMDB; " << "compile with USE_LEVELDB and USE_LMDB.";}
- DeepLearning(基于caffe)实战项目(1)--mnist_convert函数分析
- DeepLearning(基于caffe)实战项目(7)--从caffe结构里函数总结一览caffe
- DeepLearning(基于caffe)实战项目(8)--修改caffe源代码从添加loss(层)函数开始
- DeepLearning(基于caffe)实战项目(3)--我们关心的caffe输出
- DeepLearning(基于caffe)实战项目(2)--mnist(image转lmdb)
- DeepLearning(基于caffe)实战项目(5)--Matlab画学习(Loss)曲线
- DeepLearning(基于caffe)实战项目(4)--Matlab测试训练好的model
- DeepLearning(基于caffe)实战项目(6)--探索leNet模型的真谛
- DeepLearning(基于caffe)实战项目(9)--Python测试训练好的model
- DeepLearning(基于caffe)实战项目(10)--Python编写网络配置文件
- DeepLearning(基于caffe)优化策略(3)--调参篇
- DeepLearning(基于caffe)优化策略(1)--Normalization篇:BN、WN、LN
- DeepLearning(基于caffe)优化策略(2)--防拟合篇:Dropout
- [转]DeepLearning(基于caffe)优化策略--Normalization:BN、WN、LN
- 实战项目分析(一)
- 实战项目分析(二)
- caffe中loss函数代码分析--caffe学习(16)
- caffe源码学习--blob基本用法(基于《21天实战caffe》)
- 详解spl_autoload_register()函数
- pthread多线程编程详细解析----条件变量 pthread_cond_t
- 解决ORA-02069: global_names parameter must be set to TRUE for this operation 问题
- 机器学习中特征降维和特征选择的区别
- 人工智能再攻下一城——“翻译领域赶超95%人类”【智库2861】
- DeepLearning(基于caffe)实战项目(1)--mnist_convert函数分析
- MySQL 导出数据select into outfile用法
- Git常用命令
- 逆矩阵介绍及C++/OpenCV/Eigen的三种实现
- REST架构风格
- 如何在tomcat下配置二级域名
- 在IntelliJ IDEA构建Kotlin项目
- 我的博客
- Openssl编程获取X509证书的DNS