TensorFlow官方教程学习笔记(四)——MNIST数据集的读取

来源:互联网 发布:navicat数据库链接不了 编辑:程序博客网 时间:2024/05/01 04:40

在TensorFlow的源码中,MNIST数据集的读取操作在contrib\learn\python\learn\datasets\data\mnist.py中。


主要看第189行的read_data_sets函数:

def read_data_sets(train_dir,                   fake_data=False,                   one_hot=False,                   dtype=dtypes.float32,                   reshape=True,                   validation_size=5000):

train_dir为数据集在文件夹的位置,在这里为tensorflow\examples\tutorials\mnist\MNIST_data;

在官方教程中提到fake_data标记是用于单元测试的,读者可以不必理会;

one_hot为one_hot编码,即独热码,作用是将状态值编码成状态向量,例如,数字状态共有0~9这10种,对于数字7,将它进行one_hot编码后为[0 0 0 0 0 0 0 1 0 0],这样使得状态对于计算机来说更加明确,对于矩阵操作也更加高效。

dtype的作用是将图像像素点的灰度值从[0, 255]转变为[0.0, 1.0]。

reshape的作用是将图像的形状从[num examples, rows, columns, depth]转变为[num examples, rows*columns] (对于二维图片,depth为1)。

validation_size即为从训练集中抽取这么多来作为验证集。


变量定义好之后,接下来提取数据集。

先是图片文件:

with open(local_file, 'rb') as f:    train_images = extract_images(f)

看extract_images函数,从第52行开始:

with gzip.GzipFile(fileobj=f) as bytestream:    magic = _read32(bytestream)    if magic != 2051:      raise ValueError('Invalid magic number %d in MNIST image file: %s' %                       (magic, f.name))    num_images = _read32(bytestream)    rows = _read32(bytestream)    cols = _read32(bytestream)    buf = bytestream.read(rows * cols * num_images)    data = numpy.frombuffer(buf, dtype=numpy.uint8)    data = data.reshape(num_images, rows, cols, 1)    return data

如果这么看代码可能很难理解,但是如果清楚MNIST数据集文件的结构之后就好理解得多,对于MNIST的images文件:

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):offsettypevaluedescription000032 bit integer0x00000803(2051)magic number000432 bit integer60000number of images000832 bit integer28number of rows001232 bit integer28number of columns0016unsigned byte??pixel0017unsigned byte??pixel0018unsigned byte??pixel......   xxxxunsigned byte??pixel


代码中_read32()的定义在第33行,作用是从文件流中动态读取4位数据并转换为uint32的数据。

image文件的前四位为魔术码(magic number),只有检测到这4位数据的值和2051相等时,才代表这是正确的image文件,才会继续往下读取。接下来继续读取之后的4位,代表着image文件中,所包含的图片的数量(num_images)。再接着读4位,为每一幅图片的行数(rows),再后4位,为每一幅图片的列数(cols)。最后再读接下来的rows * cols * num_images位,即为所有图片的像素值。最后再将读取到的所有像素值装换为[index, rows, cols, depth]的4D矩阵。这样就将全部的image数据读取了出来。


同理,对于MNIST的labels文件:TRAINING SET LABEL FILE (train-labels-idx1-ubyte):offsettypevaluedescription000032 bit integer0x00000801(2049)magic number000432 bit integer60000number of items0008unsigned byte??label0009unsigned byte??label......   xxxxunsigned byte??label

再看代码,从第90行开始:

with gzip.GzipFile(fileobj=f) as bytestream:    magic = _read32(bytestream)    if magic != 2049:      raise ValueError('Invalid magic number %d in MNIST label file: %s' %                       (magic, f.name))    num_items = _read32(bytestream)    buf = bytestream.read(num_items)    labels = numpy.frombuffer(buf, dtype=numpy.uint8)    if one_hot:      return dense_to_one_hot(labels, num_classes)    return labels

同样的也是依次读取文件的魔术码以及标签总数,最后把所有图片的标签读取出来,成一个长度为num_items的1D的向量。不过代码中还有一个one_hot的部分,dense_to_one_hot的代码为:

  num_labels = labels_dense.shape[0]  index_offset = numpy.arange(num_labels) * num_classes  labels_one_hot = numpy.zeros((num_labels, num_classes))  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1  return labels_one_hot

正如文章开头提到one_hot的作用,这里将1D向量中的每一个值,编码成一个长度为num_classes的向量,向量中对应于该值的位置为1,其余为0,所以one_hot将长度为num_labels的向量编码为一个[num_labels, num_classes]的2D矩阵。

以上就是如何将MNIST数据文件中的images和labels分别提取出来的过程,与TensorFlow和deeplearning无关,但是我觉得对于MNIST数据集的了解,以及后面的一些才做还是很有帮助的。

4 0
原创粉丝点击