caffe(三):MNIST数据集可视化

来源:互联网 发布:郑州财经校园网络管理 编辑:程序博客网 时间:2024/05/31 19:47

前言

在手写字符识别任务中,需要将MNIST数据集打开,可视为png图片,然后重新组装新的测试集和验证集。

代码实现

//author: zhimazhimaheng//time: 20170719//E-mail:1439352516@qq.com#include<fstream>#include<iostream>#include"opencv2/core/core.hpp"#include"opencv2/highgui/highgui.hpp"#include"opencv2/imgproc/imgproc.hpp"using namespace std;using namespace cv;int ReverseInt(int i){    unsigned char ch1, ch2, ch3, ch4;    ch1=i & 255;    ch2=(i>>8)&255;    ch3=(i>>16)&255;    ch4=(i>>24)&255;    return ((int) ch1<<24)+((int)ch2<<16)+((int)ch3<<8)+ch4;}void read_Mnist(string filename, vector<Mat> &vec){    ifstream file(filename, ios::binary);    if(file.is_open())    {        int magic_number=0;        int number_of_images=0;        int n_rows=0;        int n_cols=0;        file.read((char*)&magic_number, sizeof(magic_number));        magic_number=ReverseInt(magic_number);        file.read((char*)&number_of_images,sizeof(number_of_images));        number_of_images=ReverseInt(number_of_images);        file.read((char*)&n_rows, sizeof(n_rows));        n_rows=ReverseInt(n_rows);        file.read((char*)&n_cols, sizeof(n_cols));        n_cols=ReverseInt(n_cols);        for(int i=0; i<number_of_images; i++)        {            Mat tp=Mat::zeros(n_rows, n_cols, CV_8UC1);            for(int r=0; r<n_rows; r++)            {                for(int c=0; c<n_cols; c++)                {                    unsigned char temp=0;                    file.read((char*) &temp, sizeof(temp));                    tp.at<uchar>(r,c)=(int)temp;                }            }            vec.push_back(tp);        }    }}//读取训练与测试标签  void read_Mnist_Label(string filename, vector<int> &vec)  {      ifstream file (filename, ios::binary);      if (file.is_open()) {          int magic_number = 0;          int number_of_images = 0;          int n_rows = 0;          int n_cols = 0;          file.read((char*) &magic_number, sizeof(magic_number));          magic_number = ReverseInt(magic_number);          file.read((char*) &number_of_images,sizeof(number_of_images));          number_of_images = ReverseInt(number_of_images);          for(int i = 0; i < number_of_images; ++i)          {              unsigned char temp = 0;              file.read((char*) &temp, sizeof(temp));              vec[i]= (int)temp;          }      }  }  string GetImageName(int number, int arr[]){    string str1, str2;    for(int i=0; i<10; i++)    {        if(number==i)        {            arr[i]++;            char ch1[10];            sprintf(ch1, "%d", arr[i]);            str1=std::string(ch1);            if(arr[i]<10)            {                str1="0000"+str1;            }            else if(arr[i]<100)            {                str1="000"+str1;            }            else if(arr[i]<1000)            {                str1="00"+str1;            }            else if(arr[i]<10000)            {                str1="0"+str1;            }            break;        }    }    char ch2[10];    sprintf(ch2, "%d", number);    str2=std::string(ch2);    str2=str2+"_"+str1;    return str2;}int main(){    //测试数据和测试标签      //读取测试数据 转换为Mat      string filename_test_images = "D:/Mycode/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte";      int number_of_test_images = 10000;   //测试数据10000个      vector<cv::Mat> vec_test_images;      read_Mnist(filename_test_images, vec_test_images);      //读取测试标签 转换为vector      string filename_test_labels = "D:/Mycode/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte";      vector<int> vec_test_labels(number_of_test_images);      read_Mnist_Label(filename_test_labels, vec_test_labels);      if (vec_test_images.size() != vec_test_labels.size()) {          std::cout<<"parse MNIST test file error"<<endl;        return -1;      }      //保存测试图像      int count_digits[10];      for (int i = 0; i < 10; i++)          count_digits[i] = 0;      string save_test_images_path = "D:/Mycode/MNIST/test_images/";   //保存路径      for (int i = 0; i < vec_test_images.size(); i++)       {          int number = vec_test_labels[i];          string image_name = GetImageName(number, count_digits);          image_name = save_test_images_path + image_name + ".png";        cv::imwrite(image_name, vec_test_images[i]);      }  //训练数据与训练标签      //read MNIST image into OpenCV Mat vector      string filename_train_images = "D:/Mycode/train-images-idx3-ubyte/train-images.idx3-ubyte";      int number_of_train_images = 60000;      vector<cv::Mat> vec_train_images;      read_Mnist(filename_train_images, vec_train_images);      //read MNIST label into int vector      string filename_train_labels = "D:/Mycode/train-labels-idx1-ubyte/train-labels.idx1-ubyte";      vector<int> vec_train_labels(number_of_train_images);      read_Mnist_Label(filename_train_labels, vec_train_labels);      if (vec_train_images.size() != vec_train_labels.size()) {          cout<<"parse MNIST train file error"<<endl;          return -1;      }      //save train images      for (int i = 0; i < 10; i++)          count_digits[i] = 0;      string save_train_images_path = "D:/Mycode/MNIST/train_images/"; //保存路径      for (int i = 0; i < vec_train_images.size(); i++) {          int number = vec_train_labels[i];          string image_name = GetImageName(number, count_digits);          image_name = save_train_images_path + image_name + ".png";        cv::imwrite(image_name, vec_train_images[i]);      }      return 1;      }

结果如下所示:
这里写图片描述

原创粉丝点击