TensorFlow源码阅读——tensorflow.contrib.learn.python.learn.datasets目录

来源:互联网 发布:域名证书生成器源码 编辑:程序博客网 时间:2024/05/17 05:01

minist.py

此模块下载并读取MNIST数据。导入的tensorflow模块为

from tensorflow.contrib.learn.python.learn.datasets import basefrom tensorflow.python.framework import dtypes

全局变量

SOURCE_URL = ‘http://yann.lecun.com/exdb/mnist/‘#在read_data_sets函数中下载数据集时会用到

函数

  • _read32(bytestream)
  • extract_images(f):从下载的文件的文件对象中提取图像并转换成4D uint8类型的nparray [index, y, x, depth]
    参数:
    f:可以传给gzip读取器的文件对象(压缩文件)
    返回:
    data:4D的uint8类型的图像nparray [index, y, x, depth]
    异常:
    如果文件不以2051开头的话,捕获异常

  • dense_to_one_hot(labels_dense, num_classes):把类别标签从标量转换成one-hot向量
    参数:
    labels_dense:一维向量,每个元素值表示类别标签
    num_classes:类别总数
    返回:
    labels_one_hot:nparray,shape=[labels_dense.shape[0], num_classes]。每一行只有一个位置上的元素值为1

  • extract_labels(f, one_hot=False, num_classes=10):从文件对象中提取标签并转换成1D uint8类型的nparray [index]
    参数
    f:可以传递给gzip读取器的文件对象
    ont_hot:标签是否需要表示成one-hot形式
    num_classes:类别总数,在one_hot为True时会用到
    返回:
    labels:1D uint8 nparray
    异常:
    如果文件不以2049开头的话,捕获异常。

  • read_data_set(train_dir, fake_data=False, one_hot=False, dtype=dtype.float32, reshape=True, validation_size=5000):根据fake_data的真假,分别创建train、validation、test的Dataset,并传给base.Datasets()函数
    参数:
    train_dir:下载文件的保存目录
    fake_data:为True时,定义一个嵌套函数fake(),返回一个Dataset对象,用于创建train、validation和test数据集,并返回base.Datasets(train, validation, test);否则指定训练和测试数据名,根据真实数据构建Dataset对象
    one_hot:创建Dataset对象时的参数
    dtype:创建Dataset对象时的参数
    reshape:创建Dataset对象时的参数
    validation_size:当读取真实数据时,没有validation数据,所以要从训练数据中划分出验证集,这里指定验证集的数量
    返回
    异常
    在指定的验证集数量小于0或者大于训练集的数量(因为要从训练集中划分出一部分作为验证集,自然指定的验证集数量不能超过训练集数量)时,捕获异常

  • load_mnist(train_dir='MNIST-data'):不用管

Dataset类

传递给base.Datasets()的参数类型。

成员变量:
- self._num_examples
- self.one_hot
- self._images
- self._labels
- self._epochs_completed:记录数据遍历次数
- self._index_in_epoch:取batch时用到,记录一次数据遍历过程中现在读到第几个数据了

成员函数:

  • __init__(self, images, labels, fake_data=False, one_hot=False, dtype=dtypes.float32, reshape=True):根据参数选择是否reshape以及像素点取值是否转换成[0,1]之间
    参数:
    images:
    labels:
    fake_data:当为True时,设置样本总数为10000;否则根据images、labels创建数据集
    one_hot:只有在fake_data为True时才使用此参数
    dtype:可以是uint8或者float32,前者表示像素点取值仍在[0, 255],后者表示将像素点resize到[0, 1]
    reshape:fake_data为False时,会判断reshape是否为真,是的话,在images的depth维度为1时,将images reshape成[num_examples, rows*columns]
    返回:
    异常:
    当dtype取值不是uint8或者float32时,捕获类型异常错误;

  • next_batch(self, batch_size, fake_data=False, shuffle=True):从数据集中返回下一个batch,注意在第一次遍历读到最后一个batch时,如果剩下的数据不够一个batch,那么打乱数据后从头补够。
    参数:
    batch_size:
    fake_data:为True时,返回的batch每个图像都是784维全为1的,每个标签都是0
    shuffle:在第一次取第一个batch时,会把self._image和self._labels打乱;在一次遍历结束后会再次进行打乱
    返回:
    batch_images:
    batch_labels:fake为True时,返回人造数据;fake为False时,返回从slef._images和self._labels中读取的下一个batch。

base.py

函数

  • maybe_download()

  • Datasets()

0 0
原创粉丝点击