使用Python将图像数据写入到MNIST格式以方便各个DL框架之间的训练

来源:互联网 发布:英雄联盟录制软件 编辑:程序博客网 时间:2024/05/19 13:26
最近使用各种不同的深度学习框架测试他们之间的性能和好坏,由于每个框架上手时间都比较短,没有足够的时间了解他们载入自定义格式数据的方法。由于深度学习界的Hello World 几乎就是使用CNN 训练 Mnist 手写数字。所以各个框架都几乎支持可以直接读取Mnist文件。所以把训练直接写入到MNIST格式,是一种很好的偷懒方法。不过如果不修改数据接口只能读取灰度图像。

我们先来看一下MNIST的数据格式。Lecun爸爸的主页就有。

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):[offset] [type]          [value]          [description] 0000     32 bit integer  0x00000801(2049) magic number (MSB first) 0004     32 bit integer  60000            number of items 0008     unsigned byte   ??               label 0009     unsigned byte   ??               label ........ xxxx     unsigned byte   ??               labelThe labels values are 0 to 9.TRAINING SET IMAGE FILE (train-images-idx3-ubyte):[offset] [type]          [value]          [description] 0000     32 bit integer  0x00000803(2051) magic number 0004     32 bit integer  60000            number of images 0008     32 bit integer  28               number of rows 0012     32 bit integer  28               number of columns 0016     unsigned byte   ??               pixel 0017     unsigned byte   ??               pixel ........ xxxx     unsigned byte   ??               pixel

Label File

  • 先是一个32位的整形 表示的是Magic Number,这是用来标示文件格式的用的。一般默认不变。2049
  • 第二是图片的的数量
  • 接下去就是一次排列图片的标示Label。
  • -

Image File

  • 也是Magic Number。同上。保持不变2051.
  • 图片的数量
  • 图片的高
  • 图片的宽
  • 图片的像素点[灰度 256位]。

实现过程

这种在C++/C 上实现似乎很简单。在Python下可以用Struct来处理二进制。可以看一下这篇博文

def writeMnist(data,rows,cols, path_images = "imt_mnist_training_set.data",path_labels="imt_mnist_training_labels.data"):"""@输入格式@data [imgs(浮点/字节),labels]@rows 高度@cols 宽度@path_images 训练图像集文件路径@path_labels label文件路径"""    _set,_labels = data;    model = 0;    print _set[0].dtype;    if(len(_set[0])>0 and type(_set[0]) == float):        model = 1;    magic_nums_trainning = 2051;    magic_nums_labels = 2049;    num_training = len(_set);    header_images = [magic_nums_trainning,num_training,rows,cols];    header_labels = [magic_nums_labels,num_training];    len_img = rows*cols;    header_images_format = '>IIII';    header_labels_format = '>II';    len_img_format = '>'+str(len_img)+'B'    buffer_training_set = create_string_buffer(4*4 + len_img*num_training);    buffer_training_labels = create_string_buffer(2*4  +  num_training);    offset = 0 ;    struct.pack_into(header_images_format,buffer_training_set,offset,*header_images);    offset += struct.calcsize(header_images_format);    print(len_img_format)    for i in range(num_training):        if(model == 1):            byte_type = np.array(_set[i])*255            byte_type = byte_type.astype(np.uint8).ravel();        else:            byte_type = _set[i].ravel();        struct.pack_into(len_img_format,buffer_training_set,offset,*byte_type);        offset += struct.calcsize(len_img_format);    file_training_set = open(path_images,'wb');    file_training_set.write(buffer_training_set);    offset =  0 ;    struct.pack_into(header_labels_format,buffer_training_labels,offset,*header_labels);    offset += struct.calcsize(header_labels_format);    for i in range(num_training):        struct.pack_into(">B",buffer_training_labels,offset,_labels[i]);        offset += struct.calcsize(">B")    file_training_label = open(path_labels,'wb');    file_training_label.write(buffer_training_labels);
0 0
原创粉丝点击