TensorFlow源码阅读——tensorflow.contrib.learn.python.learn.datasets目录
来源:互联网 发布:域名证书生成器源码 编辑:程序博客网 时间:2024/05/17 05:01
minist.py
此模块下载并读取MNIST数据。导入的tensorflow模块为
from tensorflow.contrib.learn.python.learn.datasets import basefrom tensorflow.python.framework import dtypes
全局变量
SOURCE_URL = ‘http://yann.lecun.com/exdb/mnist/‘#在read_data_sets函数中下载数据集时会用到
函数
_read32(bytestream)
:extract_images(f)
:从下载的文件的文件对象中提取图像并转换成4D uint8类型的nparray [index, y, x, depth]
参数:
f:可以传给gzip读取器的文件对象(压缩文件)
返回:
data:4D的uint8类型的图像nparray [index, y, x, depth]
异常:
如果文件不以2051开头的话,捕获异常dense_to_one_hot(labels_dense, num_classes)
:把类别标签从标量转换成one-hot向量
参数:
labels_dense:一维向量,每个元素值表示类别标签
num_classes:类别总数
返回:
labels_one_hot:nparray,shape=[labels_dense.shape[0], num_classes]。每一行只有一个位置上的元素值为1extract_labels(f, one_hot=False, num_classes=10)
:从文件对象中提取标签并转换成1D uint8类型的nparray [index]
参数:
f:可以传递给gzip读取器的文件对象
ont_hot:标签是否需要表示成one-hot形式
num_classes:类别总数,在one_hot为True时会用到
返回:
labels:1D uint8 nparray
异常:
如果文件不以2049开头的话,捕获异常。read_data_set(train_dir, fake_data=False, one_hot=False, dtype=dtype.float32, reshape=True, validation_size=5000)
:根据fake_data的真假,分别创建train、validation、test的Dataset,并传给base.Datasets()函数
参数:
train_dir:下载文件的保存目录
fake_data:为True时,定义一个嵌套函数fake(),返回一个Dataset对象,用于创建train、validation和test数据集,并返回base.Datasets(train, validation, test);否则指定训练和测试数据名,根据真实数据构建Dataset对象
one_hot:创建Dataset对象时的参数
dtype:创建Dataset对象时的参数
reshape:创建Dataset对象时的参数
validation_size:当读取真实数据时,没有validation数据,所以要从训练数据中划分出验证集,这里指定验证集的数量
返回:
异常:
在指定的验证集数量小于0或者大于训练集的数量(因为要从训练集中划分出一部分作为验证集,自然指定的验证集数量不能超过训练集数量)时,捕获异常load_mnist(train_dir='MNIST-data')
:不用管
Dataset类
传递给base.Datasets()的参数类型。
成员变量:
- self._num_examples
- self.one_hot
- self._images
- self._labels
- self._epochs_completed
:记录数据遍历次数
- self._index_in_epoch
:取batch时用到,记录一次数据遍历过程中现在读到第几个数据了
成员函数:
__init__(self, images, labels, fake_data=False, one_hot=False, dtype=dtypes.float32, reshape=True)
:根据参数选择是否reshape以及像素点取值是否转换成[0,1]之间
参数:
images:
labels:
fake_data:当为True时,设置样本总数为10000;否则根据images、labels创建数据集
one_hot:只有在fake_data为True时才使用此参数
dtype:可以是uint8或者float32,前者表示像素点取值仍在[0, 255],后者表示将像素点resize到[0, 1]
reshape:fake_data为False时,会判断reshape是否为真,是的话,在images的depth维度为1时,将images reshape成[num_examples, rows*columns]
返回:
异常:
当dtype取值不是uint8或者float32时,捕获类型异常错误;next_batch(self, batch_size, fake_data=False, shuffle=True)
:从数据集中返回下一个batch,注意在第一次遍历读到最后一个batch时,如果剩下的数据不够一个batch,那么打乱数据后从头补够。
参数:
batch_size:
fake_data:为True时,返回的batch每个图像都是784维全为1的,每个标签都是0
shuffle:在第一次取第一个batch时,会把self._image和self._labels打乱;在一次遍历结束后会再次进行打乱
返回:
batch_images:
batch_labels:fake为True时,返回人造数据;fake为False时,返回从slef._images和self._labels中读取的下一个batch。
base.py
函数
maybe_download()
Datasets()
- TensorFlow源码阅读——tensorflow.contrib.learn.python.learn.datasets目录
- 运行tensorflow程序是提示‘ImportError: No module named contrib.learn.python.learn.datasets’
- tensorflow之tf.contrib.learn Quickstart
- TensorFlow-4: tf.contrib.learn 快速入门
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- TensorFlow源码阅读——tensorflow/examples/tutorials/minist目录
- TensorFlow学习笔记6----tf.contrib.learn Quickstart
- [TensorFlow实战练习]3-高层API-tf.contrib.learn练习
- tensorflow学习笔记(六):TF.contrib.learn大杂烩
- TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn
- tensorflow中tf.contrib.learn.preprocessing.VocabularyProcessor理解
- 深度学习笔记——深度学习框架TensorFlow(八)[Logging and Monitoring Basics with tf.contrib.learn]
- 深度学习笔记——深度学习框架TensorFlow(九)[Building Input Functions with tf.contrib.learn]
- tensorflow学习笔记十五:tensorflow官方文档学习 Logging and Monitoring Basics with tf.contrib.learn
- tensorflow学习笔记十四:TF官方教程学习 tf.contrib.learn Quickstart
- 05:Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数
- TensorFlow-5: 用 tf.contrib.learn 来构建输入函数
- NSLog自定义
- Unity3D
- 得到综合(二)
- bootstap后台管理增删改
- 无题
- TensorFlow源码阅读——tensorflow.contrib.learn.python.learn.datasets目录
- Android之获取外部存储空间解释
- setNeedsDisplay和setNeedsLayout
- 开始用博客记录技术知识!
- linux sudoers理解
- Spark性能优化之道——解决Spark数据倾斜(Data Skew)的N种姿势
- AndroidStudio编译错误:Error: null value in entry: blameLogFolder=null
- 正则表达式手册
- centOS7忘记密码