tensorflow数据集制作/文件队列读取方式

来源:互联网 发布:windows桌面增强小工具 编辑:程序博客网 时间:2024/06/10 20:09

3种数据读取方式

TensorFlow程序读取数据一共有3种方法:
供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

以上3种方式官网中均有介绍
1. input = tf.placeholder(tf.float32) feed方式,先定义一个占位符,需要sess.run()的时候把数据传进去。
3.预加载,如下,讲数据保存在常量中,这个方法对于大数据不合适,内存资源不够。

training_data = ...training_labels = ...with tf.Session():  input_data = tf.constant(training_data)  input_labels = tf.constant(training_labels)

文件读取方式

使用原因:例如图片分类场景中,我们要使用自己的数据集,数据集比较大,需要动态的去添加数据,这样可以使用占位符,也可以利用文件读取的方式,这种方式更方便。下面我们介绍文件读取方式。(这里特别补充一个问题,队列是用来解决gpu空闲和内存问题的。所有的图片同时读到内存中是非常大的,内存可能承受不了,有了队列之后,每一次都从队列加载到内存队列中,这样就可以添加比较大的图片了。http://geek.csdn.net/news/detail/201552 详解TENSORFLOW读取机制)

官网中给出的一般步骤一共典型的文件读取管线会包含下面这些步骤:`文件名列表`可配置的 文件名乱序(shuffling)`可配置的 最大训练迭代数(epoch limit)`文件名队列`针对输入文件格式的阅读器`纪录解析器`可配置的预处理器`样本队列

这里介绍使用tensorflow的二进制格式来处理,一是因为这个二进制文件操作方便,而且网上有比较通用的处理流程,2是利用图片生成方便,200M的二进制文件可以很快速的生成,但是如果是csv文件的话,生成速度特别慢,而且200M的excel也打不开的。。(亲测)所以还是用二进制文件吧!

一下假设有了二进制文件,先别急,后面我会告诉大家如何去制作一个属于自己的数据集!

具体的队列描述,参考官网即可,我也没太看懂。
http://www.tensorfly.cn/tfdoc/how_tos/reading_data.html#AUTOGENERATED-preloaded-data
TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。QueueRunner类会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。

basePath = '/home/user/xxxxx'classes = {'c1','c2'}#生成数据集def create_record():    writer = tf.python_io.TFRecordWriter("train.tfrecords")    for index, name in enumerate(classes):        class_path = basePath +"/"+ name+"/"        for img_name in os.listdir(class_path):            img_path = class_path + img_name            img = Image.open(img_path)            img = img.resize((320, 240))            img_raw = img.tobytes() #将图片转化为原生bytes            #print index,img_raw            example = tf.train.Example(                features=tf.train.Features(                    feature={                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),                        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))                    }                )            )            writer.write(example.SerializeToString())    writer.close()#读取二进制数据img, label = read_and_decode("../train.tfrecords")#分块处理,这里img_batch,就可以当做输入看待了,以后每次sess.run()相关操作都会取出一部分,此处相当于你自己写一个队列操作去feed x的数据。(个人理解)img_batch, label_batch = tf.train.shuffle_batch([img, label],                                            batch_size=4, capacity=2000,                                            min_after_dequeue=1000)#官网推荐处理模板# Create the graph, etc.init_op = tf.initialize_all_variables()# Create a session for running operations in the Graph.sess = tf.Session()# Initialize the variables (like the epoch counter).sess.run(init_op)# Start input enqueue threads.coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:    while not coord.should_stop():        # Run training steps or whatever        sess.run(train_op)except tf.errors.OutOfRangeError:    print 'Done training -- epoch limit reached'finally:    # When done, ask the threads to stop.    coord.request_stop()# Wait for threads to finish.coord.join(threads)sess.close()

制作自己的数据集

思路:TFRecords文件。
获取这种格式的文件方式为,首先将一般的数据格式填入Example protocol buffer中,再将 protocol buffer序列化为一个字符串,然后使用tf.python_io.TFRecordWriter类的相关方法将字符串写入一个TFRecords文件中。

#制作二进制数据def create_record():    writer = tf.python_io.TFRecordWriter("train.tfrecords")    for index, name in enumerate(classes):        class_path = basePath +"/"+ name+"/"        for img_name in os.listdir(class_path):            img_path = class_path + img_name            img = Image.open(img_path)            img = img.resize((320, 240))            img_raw = img.tobytes() #将图片转化为原生bytes            #print index,img_raw            example = tf.train.Example(                features=tf.train.Features(                    feature={                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),                        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))                    }                )            )            writer.write(example.SerializeToString())    writer.close()

小问题

利用CNN分类的时候,出现了资源耗尽问题,网络铺的太大,显卡内存就用光了。mnist中使用28*28,我的图片320*240。主要原因来自于最后全连接层,参数过多。需要单独处理了!

原创粉丝点击