MNIST数据读取分析

来源:互联网 发布:用java代码表白 编辑:程序博客网 时间:2024/05/29 14:30

从input_data.py中获取


Input_data.py只是个过渡,真正是mnist.pyread_data_sets方法


具体分析read_data_sets方法


这里调用了DataSet,imagets/labels都是空,看看


这里重点看fake_data,设置了2个变量


又设了4,应是为了后面的方法准备


有用的方法,就是next_batch.


切回read_data_sets


这里调用了base.maybe_download


from tensorflow.python.platform import gfile

Gfile明显是处理文件下载的.


获取目录


直接将远程数据下载到本地,又调用urlretrieve_with_retry(url, filename=None)方法



这里执行了URL的文件下载,先不管


把临时文件转成正式文件,下载完成.


根据收到的文件名,打开文件.从中提取出train_images,又调用了解开image数据的方法


With...as... 这个语法是用来代替传统的try...finally语法的,防止打开时出错无处理.

magic=_read32(bytestream): 从byte流中获取.动用了


frombuffer: 读取图片数据的方法



如果第一个4位数,不是2051,抛出异常.


接下来,num_images,rows,cols的具体值.这里的bytestream应是指针顺序读取.

buf = bytestream.read(rows * cols * num_images)

data = numpy.frombuffer(buf, dtype=numpy.uint8)

读出真正的数据(先定义buf,再读取,npbuffer处理方式)

data = data.reshape(num_images, rows, cols, 1)

转置成arrays 4

这里,是把文件的内容全都读进内存,没有分batch.


获取LABEL



前面一样,2049判断文件的正确性

num_items: 接下来的值,为数据量

bytestream.read: 读取到buf

numpy.frombuffer: 存到labels

one_hot处理


num_labels = labels_dense.shape[0]

labels_dense1维矩阵,这里获取了矩阵的长度,相当于标签的数量了.

index_offset = numpy.arange(num_labels) * num_classes

index_offset: [0,10,20,30...]

labels_one_hot = numpy.zeros((num_labels, num_classes))

开出全0矩阵

labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1

labels_one_hot1维长度(flat)进行赋1,这样变成相应位置为1,其余为0

abels_dense.ravel():多维转1,用于取得标签的值

 

: one hot 的数据,是用偏移来算的,故而原LABEL是要0,1,2...这样的数据

经过以上处理,LABEL的所有数据,也都到内存中了.


数据切分


判断验证集数据量(validation_size,默认5000)是否小于样本总量train_images


把样本数据的前validation_size,当成验证数据集,之后的当学习集.


生成各自的数据集,又调用DataSet


assert: 下断言,不满足时跳出.

self._num_examples = images.shape[0]

获取样本数量


只处理黑白图片(depth=1)

3维转2(样本量,样本长*),相当于把图片的2维数据变成1维长条型.


把image的数据,int转变成float32

100* 1/255=100/255,类似归一化.


next_batch

获取批次数据


start = self._index_in_epoch

开始序号

self._index_in_epoch += batch_size

增加一个batch

if self._index_in_epoch > self._num_examples:

      # Finished epoch

      self._epochs_completed += 1

完成一轮

# Shuffle the data

perm = numpy.arange(self._num_examples)

numpy.random.shuffle(perm)

perm随机排序

self._images = self._images[perm]

self._labels = self._labels[perm]

获取乱序之后的数据

 

# Start next epoch

start = 0

self._index_in_epoch = batch_size

assert batch_size <= self._num_examples

设置开始与结束的位置

end = self._index_in_epoch

正式设置结束位置

return self._images[start:end], self._labels[start:end]

返回

 

总体思路:

1.如果刚开始(或新的一轮),就把顺序打乱,从头开始,batch获取量.

2.每次都获取一个batch的数据,直到结束,增加一个迭代.

 

这里的前提,是数据全装到内存了.

故而可知,数据的处理,完全是由编程人员自控的.









0 0