TensorFlow读取数据
来源:互联网 发布:java 一个对象的大小 编辑:程序博客网 时间:2024/06/05 07:09
本文介绍如何使用TensorFlow来读取图片数据,主要介绍写入TFRecord文件再读取和直接使用队列来读取两种方式。假设我们图片目录结构如下:
|---a| |---1.jpg| |---2.jpg| |---3.jpg||---b| |---1.jpg| |---2.jpg| |---3.jpg||---c| |---1.jpg| |---2.jpg| |---3.jpg
1 使用TFRecoder
思路:思路:使用
TFRecod
主要是把每张图片及其对应的label写入到一个tfrecode
文件中。tfrecode
以二进制形式保存,其中内部使用了protobuf
定义协议,即定义格式序列化为二进制。我们可以使用tf
提供的tf.train.Example
来指定序列化格式。将a目录中所有的文件的label
指定为a
,另外两个目录b
、c
同理。
代码如下:
def build_data(dir,file_str,map_str): ''' :param dir: 根目录,dir下所有子目录名称为label :param file_str: 导出的tfrecorde文件 :param map_str: 数字序号0~n与label映射关系保存路径 :return: ''' files=os.listdir(dir); writer = tf.python_io.TFRecordWriter(file_str) # 要生成的文件 # 由于tf.train.Feature只能取float、int和bytes,因此需要将label映射到int,保存到文件 map_file = open(map_str,'w') for index,label in enumerate(files): #遍历文件夹 data_dir = os.path.join(dir,label) map_file.write(str(index) + ":" + label + "\n") for img_name in os.listdir(data_dir): #遍历图片 img_path=os.path.join(data_dir,img_name) img = Image.open(img_path) #读取图片 img = img.resize((256, 256)) #将图片宽高转为256*256 img_raw=img.tobytes() #图片转为字节 example=tf.train.Example(features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])) })) writer.write(example.SerializeToString()) # 序列化为字符串并写入文件 writer.close() map_file.close();
接下来是读取tfrecord文件。注意读取时label、img名称及类型要一致:
def read_data(file_str): # 根据文件名生成一个队列 file_path_queue = tf.train.string_input_producer([file_str]) reader = tf.TFRecordReader() _, serialized_example = reader.read(file_path_queue) # 返回文件名和文件 features = tf.parse_single_example(serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img': tf.FixedLenFeature([], tf.string), }) label = tf.cast(features['label'], tf.int64) # 读取label img = tf.decode_raw(features['img'], tf.uint8) img = tf.reshape(img, [256, 256, 3]) #将维度转为256*256的3通道 img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #将图片中的数据转为[-0.5,0.5] return img, label
接下来看看如何使用:
build_data("D:/test","D:/data/tf.tfrecorde","D:/data/map.txt")img, label =read_data("D:/data/tf.tfrecorde")#使用shuffle_batch可以随机打乱输入img_batch, label_batch = tf.train.shuffle_batch([img, label], batch_size=30, capacity=2000, min_after_dequeue=1000)init = tf.initialize_all_variables()with tf.Session() as sess: sess.run(init) threads = tf.train.start_queue_runners(sess=sess) for i in range(3): imgs, labels= sess.run([img_batch, label_batch]) #我们也可以根据需要对val, l进行处理 print(imgs.shape, labels)
运行结果如下:
(30, 256, 256, 3) [1 2 2 1 1 2 2 1 0 1 0 1 0 0 2 0 0 0 2 1 1 1 1 0 0 1 2 1 2 0](30, 256, 256, 3) [2 1 1 0 0 1 1 0 2 2 2 0 0 0 0 2 1 0 0 2 0 0 2 2 2 1 0 1 0 2](30, 256, 256, 3) [2 0 2 0 1 2 1 2 2 1 0 2 0 0 2 2 2 1 1 1 1 1 0 0 2 0 2 2 0 0]
从结果可以看出,虽然我们提供的图片只有9张。每一类各3张,但是能读取30*30*30张出来,这主要是通过循环读取得到的。也就是说数量上虽然增加了,但实际上也就是那9张图片。
2 不使用TFRecord
TFRecord
适合将标签、图片数据等其他相关的数据一起封装到一个对象,然后逐个读取。有时候,我们并不需要标签,只需要对图片读取。那么可以考虑之间从路径队列中读取,而不需要转到TFRecord
文件。
直接上代码:
def read_data(dir ): ''' :param dir: 图片根目录 ''' input_paths = glob.glob(os.path.join(dir, "*.jpg")) decode = tf.image.decode_jpeg if len(input_paths) == 0: #如果不存在jpg图片,则遍历png图片 input_paths = glob.glob(os.path.join(dir, "*.png")) decode = tf.image.decode_png if len(input_paths) == 0: #如果png图片不存在,抛出异常 raise Exception("input_dir contains no image files") #产生文件路径队列,并且打乱顺序 path_queue = tf.train.string_input_producer(input_paths, shuffle=True) reader = tf.WholeFileReader() #创建读取文件对象 paths, contents = reader.read(path_queue) #从队列中读取 img_raw = decode(contents) # 将图片缩小到256*256,如果在此之前对图片预处理(放缩),那么这一步可省略 img_raw = tf.image.resize_images(img_raw, [256, 256]) img_raw = tf.image.convert_image_dtype(img_raw, dtype=tf.float32) img_raw.set_shape([256, 256, 3])#设置shape return img_raw
接下来看看如何使用:
img = read_data("D:/test/*" )img_batch = tf.train.batch([img], batch_size=30)init = tf.initialize_all_variables()with tf.Session() as sess: sess.run(init) threads = tf.train.start_queue_runners(sess=sess) for i in range(3): imgs = sess.run( img_batch ) print(imgs.shape )
看看运行结果:
(30, 256, 256, 3)(30, 256, 256, 3)(30, 256, 256, 3)
阅读全文
0 0
- tensorflow爬坑行:数据读取
- Tensorflow图片数据读取
- Tensorflow读取数据
- Tensorflow读取数据
- tensorflow读取文件数据
- Tensorflow读取数据1
- Tensorflow数据读取方式
- Tensorflow图片数据读取
- tensorflow图片数据读取
- Tensorflow数据读取方法
- TensorFlow读取tfrecords数据
- tensorflow读取数据
- TensorFlow数据读取
- tensorflow读取数据
- tensorflow 数据读取笔记
- TensorFlow数据读取
- TensorFlow读取数据
- tensorflow 数据读取
- apt yum
- java 入门
- 设计模式之职责链模式
- SQLite数据库操作
- 从mysql数据表中随机取出一条记录
- TensorFlow读取数据
- 第二个任务
- 读书笔记 | 为什么从世界500强CEO、政界要员,到著名演员都用这个方法来提高效率?
- JS —— 笔记,$("document").ready() 中ajax 与 $.ajax() 及同步异步优先级问题
- Java并发编程:volatile关键字解析
- java基础数据结构分析
- git fetch /rebase /merge 使用
- X-Frame-Options响应头缺失漏洞
- java BufferedImage简单图片写字一个小例子