Tensorflow读取数据1

来源:互联网 发布:淘宝类目销售比列 编辑:程序博客网 时间:2024/05/18 20:07

原文地址: http://blog.csdn.net/u010911921/article/details/70577697

这段一直在用Tensorflow来做深度学习上的相关工作,然后对Tensorflow读取数据的方式进行实现。特地总结一下。首先是读取二进制图片数据,这里采用的是CIFAR-10的二进制数据

## 1.CIFAR-10数据集CIFAR-10数据集合是包含60000张`32*32*3`的图片,其中每个类包含6000张图片,总共10类。在这60000张图片中50000张是训练集合,10000张是测试集合。

其中二进制的图片保存的格式如下所示:

2.Tensorflow读取数据

从Tensorflow的官网可以看到从文件中读取数据的流程主要是一下步骤:

  1. The list of filenames
  2. (Optional) filename shuffling
  3. (Optional) epoch limit
  4. Filename queue
  5. A Reader for the file format
  6. A decoder for a record read by the reader
  7. (Optional) preprocessing
  8. Example queue

按照这样一个流程,首选应该将CIFAR-10的训练集和测试集合,生成文件名列表,然后在讲这个文件名列表传递给tf.train.string_input_producer函数创建一个用于保存文件名称的FIFO的队列,最后用tensor flow产生的reader从队列中读取数据。当reader读到数据就需要用tf.decode_raw函数对读取到的二进制数进行解码。

结束了上述操作,下面就需要采用另一个queue去batch together examples来为训练和测试提供数据。采用tf.train.shuffle_batch将上面生成的imagelabel传入函数即可完成。

3.开始训练

tf.train.shuffle_batch生成batch以后就开始利用tf.train.start_queue_runners函数启动队列,然后开始整个计算图,官网给的建议是如下形式:

init_op = tf.global_variables_initializer()with tf.Session as sess:    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(sess= sess,coord = coord)    try:        while not coord.should_stop():            #run training steps or whatever            sess.run(train_op)    except tf.errors.OutOfRangeError:        print('Done training --epoch limit reached')    finally:        # when done,ask the threads to stop        coord.request_stop()    coord.join(threads)

4.代码实现

在神经网络的训练中由于每训练k步以后就会对网络进行一次测试,所以需要在上述步骤中,增加动态选择文件名称队列这样一个过程,可以由tf.QueueBase.from_list函数进行实现,然后reader从返回的文件名称队列中读取数据。

整个过程的实现如下所示:

#!/usr/bin/env python3# --*-- encoding:utf-8 --*--import tensorflow as tfimport numpy as npimport osdef read_cifar10(data_dir,is_traing,batch_size,shuffle):    """    :param data_dir:数据保存路径    :param is_traing:True从训练集获取数据,False从测试集获取数据    :param batch_size:  batch_size的大小    :param shuffle: bool,是否进行shuffle操作    :return:    """    img_width = 32    img_height = 32    img_depth = 3    label_bytes = 1    img_bytes = img_height * img_width *img_depth    with tf.name_scope("input") as scope:        #训练集合的文件列表        train_filenames = [os.path.join(data_dir,                                        'data_batch_%d.bin'%ii) for ii in np.arange(1,6)]        #测试集合的文件列表        val_filenames = [os.path.join(data_dir,'test_batch.bin')]        #训练集和测试集合的文件名称队列        train_queue = tf.train.string_input_producer(train_filenames)        val_queue = tf.train.string_input_producer(val_filenames)        #挑选文件队列,实现training的过程中测试        queue_select = tf.cond(is_traing,                               lambda :tf.constant(0),                               lambda :tf.constant(1) )        queue = tf.QueueBase.from_list(queue_select,[train_queue,val_queue])        #从队列中读取固定长度的数据        reader = tf.FixedLengthRecordReader(label_bytes+img_bytes)        key,value = reader.read(queue)        recode_bytes = tf.decode_raw(value,tf.uint8)        #获取label        label = tf.slice(recode_bytes,[0],[label_bytes])        label = tf.cast(label,tf.int32)        #获取image        image_raw = tf.slice(recode_bytes,[label_bytes],[img_bytes])        image_raw = tf.reshape(image_raw,[img_depth, img_height, img_width])        image = tf.transpose(image_raw,[1,2,0])        image = tf.cast(image,tf.float32)        #对每一张图片进行标准化操作,可选操作此处可以进行对图片的各种操作        image = tf.image.per_image_standardization(image)        if shuffle:            images, label_batch= tf.train.shuffle_batch([image,label],                                                   batch_size=batch_size,                                                   num_threads=16,                                                   capacity=512+3*batch_size,                                                   min_after_dequeue=512,                                                   allow_smaller_final_batch=True)        else:            images, label_batch = tf.train.batch([image, label],                                            batch_size=batch_size,                                            num_threads=16,                                            capacity=512 + 3*batch_size,                                            allow_smaller_final_batch=True)        label_batch = tf.cast(label_batch,tf.int32)        return images,label_batch

整个过程是采用VGG-16的网络模型进行训练的,在迭代16000次,tensorboard展示的结果如图所示:

code下载地址https://github.com/ZhichengHuang/LearnTensorflowCode

参考资料:

  • https://www.tensorflow.org/versions/r1.1/programmers_guide/reading_data

  • http://stackoverflow.com/questions/41162955/tensorflow-queues-switching-between-train-and-validation-data

0 0
原创粉丝点击