caffe特征提取/C++数据格式转换

来源:互联网 发布:最短域名邮箱 编辑:程序博客网 时间:2024/06/07 06:23
Caffe生成的数据分为2种格式:Lmdb 和 Leveldb
  • 它们都是键/值对(Key/Value Pair)嵌入式数据库管理系统编程库。
  • 虽然lmdb的内存消耗是leveldb的1.1倍,但是lmdb的速度比leveldb快10%至15%,更重要的是lmdb允许多种训练模型同时读取同一组数据集。
  • 因此lmdb取代了leveldb成为Caffe默认的数据集生成格式。

create_babyface.sh调用的convertData的源代码如下:


[cpp] view plain copy
 在CODE上查看代码片派生到我的代码片
  1. #include<sys/types.h>  
  2. #include<sys/stat.h>  
  3. #include<dirent.h>  
  4. #include <stdio.h>  
  5. #include<string.h>  
  6.   
  7. #include <fstream>  // NOLINT(readability/streams)  
  8. #include <string>  
  9. #include <vector>  
  10.   
  11. #include "boost/scoped_ptr.hpp"  
  12. #include "glog/logging.h"  
  13. #include "google/protobuf/text_format.h"  
  14. #include "stdint.h"  
  15.   
  16. #include "caffe/proto/caffe.pb.h"  
  17. #include "caffe/util/db.hpp"  
  18.   
  19. #include <opencv/cv.h>  
  20. #include <opencv/highgui.h>  
  21.   
  22. using caffe::Datum;  
  23. using boost::scoped_ptr;  
  24. using std::string;  
  25. namespace db = caffe::db;  
  26. using namespace std;  
  27.   
  28. const int kCIFARSize = 32;  
  29. const int kCIFARChannelBytes = 1024;  
  30. const int kCIFARImageNBytes = 3072;  
  31. const int kCIFARBatchSize = 1000;//1000 for a batch!  
  32. const int kCIFARTrainBatches = 5;  
  33.   
  34. void read_image(std::ifstream* file, int* label, char* buffer) {  
  35.     char label_char;  
  36.     file->read(&label_char, 1);  
  37.     *label = label_char;  
  38.     file->read(buffer, kCIFARImageNBytes);  
  39.     return;  
  40. }  
  41.   
  42. //Read IPLimage to the buffer  
  43. void read_image(  
  44.         IplImage* out, char* buffer,  
  45.         char* RC, char* GC, char* BC)  
  46. {  
  47.     int x,y;  
  48.     int idx =0;  
  49.     for(y = 0; y<out->height; y++){  
  50.         char *ptr= out->imageData + y * out->widthStep;  
  51.         for( x = 0;x< out->width;x++){  
  52.             idx =y*out->height + x;  
  53.             BC[idx]= ptr[3*x];  
  54.             GC[idx]= ptr[3*x+1];  
  55.             RC[idx]=  ptr[3*x+2]; //这样就可以添加自己的操作,这里我使三通道颜色一样,就彩色图转黑白图了  
  56.         }  
  57.     }  
  58.     memcpy( buffer ,RC, kCIFARChannelBytes*sizeof(char) );  
  59.     memcpy( buffer+ kCIFARChannelBytes*sizeof(char) ,      GC,kCIFARChannelBytes*sizeof(char) );  
  60.     memcpy( buffer+ kCIFARChannelBytes*sizeof(char)  *2, BC,kCIFARChannelBytes*sizeof(char) );  
  61.     return;  
  62. }  
  63.   
  64. //Travel the folder and load the filelist!  
  65. //使用linux dirent遍历目录  
  66.  int traveldir(char* path ,int depth, vector<string > &FileList)  
  67. {  
  68.     DIR* d;// a  
  69.     struct dirent *file; struct stat  sb;  
  70.   
  71.     if( !(d=opendir(path ) ) ){  
  72.         printf("Read path %s error,wishchin! ", path);  
  73.         return -1;  
  74.     }  
  75.   
  76.     while( (file= readdir(d ) ) != NULL ) {  
  77.         if(0== strncmp(file->d_name,  ".", 1 ) ) continue;  
  78.         char filename[256];  
  79.         strcpy( filename , file->d_name );  
  80.   
  81.         string  Sfilename(filename);string  Spath(path);  
  82.         Spath.append(Sfilename);  
  83.         FileList.push_back(Spath);  
  84.     }  
  85.   
  86.     if( stat(file->d_name,  &sb)>=0 && S_ISDIR(sb.st_mode) && depth <=4 )  
  87.         traveldir(file->d_name,depth+1,FileList);  
  88.   
  89.     closedir(d);  
  90.     return 1;  
  91. }  
  92.   
  93. // convert the data to the lmdb format !  
  94. void convert_dataset(  
  95.         const string& input_folder,  
  96.         const string& output_folder,  
  97.         const string& db_type) {  
  98.   
  99.     scoped_ptr<db::DB> train_db(db::GetDB(db_type));  
  100.     train_db->Open(output_folder + "/babyface_train_" + db_type, db::NEW);  
  101.     scoped_ptr<db::Transaction> txn(train_db->NewTransaction());  
  102.   
  103.     char* path=new char[256];  
  104.     int depth=2;  
  105.     vector<string > FileList(0);  
  106.   
  107.     // Data buffer  
  108.     int label;  
  109.     IplImage* ImageS;  
  110.     char str_buffer[kCIFARImageNBytes];  
  111.     char* RC=new char[kCIFARChannelBytes];  
  112.     char* GC=new char[kCIFARChannelBytes];  
  113.     char* BC=new char[kCIFARChannelBytes];  
  114.     Datum datum;  
  115.     datum.set_channels(3);  
  116.     datum.set_height(kCIFARSize);  
  117.     datum.set_width(kCIFARSize);  
  118.   
  119.     //"Writing Training data"//载入训练数据  
  120.     LOG(INFO) << "Writing Training data";  
  121.   
  122.     strcpy(path,( input_folder+(string)("train1") ).c_str() );  
  123.     traveldir( path , depth, FileList);  
  124.   
  125.     for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {  
  126.         // Open files  
  127.         LOG(INFO) << "Training Batch " << fileid + 1;  
  128.         snprintf(str_buffer, kCIFARImageNBytes, "/data_batch_%d.bin", fileid + 1);  
  129.         //CHECK(data_file) << "Unable to open train file #" << fileid + 1;  
  130.   
  131.         label=1;//The Batch has 10000 pics!  
  132.         for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {  
  133.             ImageS =cvLoadImage( (FileList[ fileid*kCIFARTrainBatches + itemid] ).c_str() );  
  134.             read_image( ImageS, str_buffer, RC,  GC, BC);  
  135.   
  136.             datum.set_label(label);//datum.set_label(label);  
  137.             datum.set_data(str_buffer, kCIFARImageNBytes);  
  138.   
  139.             int length = snprintf(str_buffer,  kCIFARImageNBytes,  
  140.                                   "%05d", fileid * kCIFARBatchSize + itemid);  
  141.             string out;  
  142.             CHECK(datum.SerializeToString( &out)  )  ;  
  143.             txn->Put(string(str_buffer, length),  out);//The main sentence ,put data to the txn!  
  144.         }  
  145.     }  
  146.   
  147.     strcpy(path,( input_folder+(string)("train0") ).c_str() );  
  148.     traveldir( path , depth, FileList);  
  149.     for (int fileid = 0; fileid < kCIFARTrainBatches; ++fileid) {  
  150.         LOG(INFO) << "Training Batch " << fileid + 1;  
  151.         snprintf(str_buffer, kCIFARImageNBytes, "/data_batch_%d.bin", fileid + 1);  
  152.         //CHECK(data_file) << "Unable to open train file #" << fileid + 1;  
  153.   
  154.         label=0;//The Batch has 10000 pics!  
  155.         for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {  
  156.             ImageS =cvLoadImage( (FileList[ fileid*kCIFARTrainBatches + itemid] ).c_str() );  
  157.             read_image( ImageS, str_buffer, RC,  GC, BC);  
  158.   
  159.             datum.set_label(label);//datum.set_label(label);  
  160.             datum.set_data(str_buffer, kCIFARImageNBytes);  
  161.   
  162.             int length = snprintf(str_buffer,  kCIFARImageNBytes,  
  163.                                   "%05d", fileid * kCIFARBatchSize + itemid);  
  164.             string out;  
  165.             CHECK(datum.SerializeToString( &out)  )  ;  
  166.             txn->Put(string(str_buffer, length),  out);//The main sentence ,put data to the txn!  
  167.         }  
  168.     }  
  169.   
  170.     txn->Commit();  
  171.     train_db->Close();  
  172.   
  173.     //写入测试数据!  
  174.     LOG(INFO) << "Writing Testing data";  
  175.     scoped_ptr<db::DB> test_db(db::GetDB(db_type));  
  176.     test_db->Open(output_folder + "/babyface_test_" + db_type, db::NEW);  
  177.     txn.reset(test_db->NewTransaction());  
  178.   
  179.     strcpy(path,( input_folder+(string)("test1") ).c_str() );  
  180.     traveldir( path , depth, FileList);  
  181.     for (int fileid = 0; fileid < 2; ++fileid) {  
  182.         LOG(INFO) << "Training Batch " << fileid + 1;  
  183.         snprintf(str_buffer, kCIFARImageNBytes, "/data_batch_%d.bin", fileid + 1);  
  184.   
  185.         label=1;//The Batch has 10000 pics!  
  186.         for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {  
  187.             ImageS =cvLoadImage( (FileList[ fileid*2 + itemid] ).c_str() );  
  188.             read_image( ImageS, str_buffer, RC,  GC, BC);  
  189.   
  190.             datum.set_label(label);//datum.set_label(label);  
  191.             datum.set_data(str_buffer, kCIFARImageNBytes);  
  192.   
  193.             int length = snprintf(str_buffer,  kCIFARImageNBytes,  
  194.                                   "%05d", fileid * kCIFARBatchSize + itemid);  
  195.             string out;  
  196.             CHECK(datum.SerializeToString( &out)  )  ;  
  197.             txn->Put(string(str_buffer, length),  out);//The main sentence ,put data to the txn!  
  198.         }  
  199.     }  
  200.   
  201.     strcpy(path,( input_folder+(string)("test0") ).c_str() );  
  202.     traveldir( path , depth, FileList);  
  203.     for (int fileid = 0; fileid < 2; ++fileid) {  
  204.   
  205.         LOG(INFO) << "Training Batch " << fileid + 1;  
  206.         snprintf(str_buffer, kCIFARImageNBytes, "/data_batch_%d.bin", fileid + 1);  
  207.   
  208.         label=0;//The Batch has 10000 pics!  
  209.         for (int itemid = 0; itemid < kCIFARBatchSize; ++itemid) {  
  210.             ImageS =cvLoadImage( (FileList[ fileid*2 + itemid] ).c_str() );  
  211.             read_image( ImageS, str_buffer, RC,  GC, BC);  
  212.   
  213.             datum.set_label(label);//datum.set_label(label);  
  214.             datum.set_data(str_buffer, kCIFARImageNBytes);  
  215.   
  216.             int length = snprintf(str_buffer,  kCIFARImageNBytes,  
  217.                                   "%05d", fileid * kCIFARBatchSize + itemid);  
  218.             string out;  
  219.             CHECK(datum.SerializeToString( &out)  )  ;  
  220.             txn->Put(string(str_buffer, length),  out);//The main sentence ,put data to the txn!  
  221.         }  
  222.     }  
  223.   
  224.     txn->Commit();  
  225.     test_db->Close();  
  226.   
  227.     cvReleaseImage(&ImageS);  
  228.     delete [] RC;delete [] GC;delete [] BC;  
  229. }  
  230.   
  231. int main(int argc, char** argv) {  
  232.     if (argc != 4) {  
  233.         printf("This script converts the CIFAR dataset to the leveldb format used\n"  
  234.                "by caffe to perform classification.\n"  
  235.                "Usage:\n"  
  236.                "    convert_cifar_data input_folder output_folder db_type\n"  
  237.                "Where the input folder should contain the binary batch files.\n"  
  238.                "The CIFAR dataset could be downloaded at\n"  
  239.                "    http://www.cs.toronto.edu/~kriz/cifar.html\n"  
  240.                "You should gunzip them after downloading.\n");  
  241.     } else {  
  242.         google::InitGoogleLogging(argv[0]);  
  243.         convert_dataset(string(argv[1]), string(argv[2]), string(argv[3]));  
  244.     }  
  245.     return 0;  
  246. }  

        后记:目的是载入32×32的三通道图像,直接输入3072维的char向量进行训练,至于怎样训练网络,还得仔细查看一下。

      后记:代码出现 coredump 问题,利用 gcc path/...bin  -o coredemo -g ,出现caffe.pb.h 包含丢失现象,why??? 

0 0