Tfrecords Guide

来源:互联网 发布:lol美服迅游加速器mac 编辑:程序博客网 时间:2024/06/13 05:23

摘要:TFRecords 相关操作:获取图像原始字节、写入/读取 .tfrecords 文件,以及通过队列随机批量读取。

转自 Tfrecords Guide

本文代码及素材下载链接:http://download.csdn.net/download/u011939755/10117399
# encoding:utf-8
# Created by gzp
# Created on 2017/11/14 8:30
# Function: TFRecords-related operations
#==============================================================================
# Overview:
# Storing a dataset as a binary file lets it be read more efficiently.
# * Use numpy to get the raw data bytes of any image.
# * Convert a dataset to a tfrecord file without defining a computation graph,
#   using only some built-in TensorFlow functions.
# * Define a model that reads data from the created binary file and batches
#   it in random order.
#==============================================================================
#==============================================================================
# Getting raw data bytes in numpy
#==============================================================================
# NOTE(review): this section and the next are kept commented out, as in the
# original tutorial. `ndarray.tostring()` / `np.fromstring()` are deprecated
# in newer NumPy releases (`tobytes()` / `np.frombuffer()` are the
# replacements) -- confirm against the NumPy version in use before enabling.
#
# import numpy as np
# import skimage.io as io
#
# baidu_img = io.imread('baidu.png')
# io.imshow(baidu_img)
# io.show()
# # Let's convert the picture into string representation
# # using the ndarray.tostring() function
# baidu_string = baidu_img.tostring()
# # Now let's convert the string back to the image
# # Important: the dtype should be specified
# # otherwise the reconstruction will be erroneous
# # Reconstruction is 1d, so we need sizes of image
# # to fully reconstruct it.
# reconstructed_cat_1d = np.fromstring(baidu_string, dtype=np.uint8)
#
# # Here we reshape the 1d representation
# # This is why we need to store the sizes of image
# # along with its serialized representation.
# reconstructed_cat_img = reconstructed_cat_1d.reshape(baidu_img.shape)
#
# # Let's check if we got everything right and compare
# # reconstructed array to the original one.
# compare = np.allclose(baidu_img, reconstructed_cat_img)
# # (prints whether the reconstructed image data matches the original)
# print('重构与原始图像数据是否一致:',compare)
#==============================================================================
# Creating a .tfrecord file and reading it without defining a graph
#==============================================================================
# # Get some image/annotation pairs for example
# filename_pairs = [
# ('./VOC2007/JPEGImages/000032.jpg',
# './VOC2007/SegmentationClass/000032.png'),
# ('./VOC2007/JPEGImages/000033.jpg',
# './VOC2007/SegmentationClass/000033.png'),
# ('./VOC2007/JPEGImages/000039.jpg',
# './VOC2007/SegmentationClass/000039.png')
#                  ]
# # Important: We are using PIL to read .png files later.
# # This was done on purpose to read indexed png files
# # in a special way -- only indexes and not map the indexes
# # to actual rgb values. This is specific to PASCAL VOC
# # dataset data. If you don't want this type of behaviour
# # consider using skimage.io.imread()
# from PIL import Image
# import numpy as np
# import skimage.io as io
# import tensorflow as tf
#
#
# def _bytes_feature(value):
#     return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
#
#
# def _int64_feature(value):
#     return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#
#
# tfrecords_filename = 'pascal_voc_segmentation.tfrecords'
#
# writer = tf.python_io.TFRecordWriter(tfrecords_filename)
#
# # Let's collect the real images to later on compare
# # to the reconstructed ones
# original_images = []
#
# for img_path, annotation_path in filename_pairs:
#     img = np.array(Image.open(img_path))
#     annotation = np.array(Image.open(annotation_path))
#
#     # The reason to store image sizes was demonstrated
#     # in the previous example -- we have to know sizes
#     # of images to later read raw serialized string,
#     # convert to 1d array and convert to respective
#     # shape that image used to have.
#     height = img.shape[0]
#     width = img.shape[1]
#
#     # Put in the original images into array
#     # Just for future check for correctness
#     original_images.append((img, annotation))
#
#     img_raw = img.tostring()
#     annotation_raw = annotation.tostring()
#
#     example = tf.train.Example(features=tf.train.Features(feature={
#         'height': _int64_feature(height),
#         'width': _int64_feature(width),
#         'image_raw': _bytes_feature(img_raw),
#         'mask_raw': _bytes_feature(annotation_raw)}))
#
#     writer.write(example.SerializeToString())
#
# writer.close()
#
# #############################################################################
# reconstructed_images = []
#
# record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)
#
# for string_record in record_iterator:
#     example = tf.train.Example()
#     example.ParseFromString(string_record)
#
#     height = int(example.features.feature['height']
#                  .int64_list
#                  .value[0])
#
#     width = int(example.features.feature['width']
#                 .int64_list
#                 .value[0])
#
#     img_string = (example.features.feature['image_raw']
#                   .bytes_list
#                   .value[0])
#
#     annotation_string = (example.features.feature['mask_raw']
#                          .bytes_list
#                          .value[0])
#
#     img_1d = np.fromstring(img_string, dtype=np.uint8)
#     reconstructed_img = img_1d.reshape((height, width, -1))
#
#     annotation_1d = np.fromstring(annotation_string, dtype=np.uint8)
#
#     # Annotations don't have depth (3rd dimension)
#     reconstructed_annotation = annotation_1d.reshape((height, width))
#
#     reconstructed_images.append((reconstructed_img, reconstructed_annotation))
# # Let's check if the reconstructed images match
# # the original images
#
# for original_pair, reconstructed_pair in zip(original_images, reconstructed_images):
#     img_pair_to_compare, annotation_pair_to_compare = zip(original_pair,
#                                                           reconstructed_pair)
#     print(np.allclose(*img_pair_to_compare))
#     print(np.allclose(*annotation_pair_to_compare))
#==============================================================================
# Defining the graph to read and batch images from .tfrecords
#==============================================================================
import tensorflow as tf
import skimage.io as io

# Every image/annotation pair is cropped or zero-padded to this fixed size so
# that examples of different sizes can be stacked into one batch.
IMAGE_HEIGHT = 384
IMAGE_WIDTH = 384

tfrecords_filename = 'pascal_voc_segmentation.tfrecords'


def read_and_decode(filename_queue, batch_size=2, capacity=30,
                    num_threads=2, min_after_dequeue=10):
    """Build the graph ops that read, decode and shuffle-batch TFRecord examples.

    Each record is expected to carry the features ``height``, ``width``
    (int64 scalars) and ``image_raw``, ``mask_raw`` (raw uint8 bytes), i.e.
    the layout written by the commented-out writer section above.

    Args:
        filename_queue: Queue of .tfrecords file names, e.g. the result of
            ``tf.train.string_input_producer``.
        batch_size: Number of image/annotation pairs per batch.
        capacity: Maximum number of elements held in the shuffle queue.
        num_threads: Number of threads enqueuing examples.
        min_after_dequeue: Minimum number of elements kept in the queue after
            a dequeue; controls how well the examples are mixed.

    Returns:
        A tuple ``(images, annotations)`` of batched uint8 tensors with shapes
        ``(batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3)`` and
        ``(batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 1)``.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since all keys are required.
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
            'mask_raw': tf.FixedLenFeature([], tf.string)
        })

    # Decode the scalar byte strings into flat (1-D) uint8 tensors.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    annotation = tf.decode_raw(features['mask_raw'], tf.uint8)

    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)

    # The stored sizes are needed to restore the 2-D layout of the flat data.
    # Images are treated as 3-channel; annotations get an explicit single
    # channel so that the image ops below accept them.
    image_shape = tf.stack([height, width, 3])
    annotation_shape = tf.stack([height, width, 1])

    image = tf.reshape(image, image_shape)
    annotation = tf.reshape(annotation, annotation_shape)

    # Random transformations (augmentation) can be put here: right before the
    # images are cropped/padded to the predefined size.
    # NOTE(review): the original tutorial pointed to a StackOverflow question
    # for details; that link is not present in this file.
    resized_image = tf.image.resize_image_with_crop_or_pad(
        image=image,
        target_height=IMAGE_HEIGHT,
        target_width=IMAGE_WIDTH)

    resized_annotation = tf.image.resize_image_with_crop_or_pad(
        image=annotation,
        target_height=IMAGE_HEIGHT,
        target_width=IMAGE_WIDTH)

    images, annotations = tf.train.shuffle_batch(
        [resized_image, resized_annotation],
        batch_size=batch_size,
        capacity=capacity,
        num_threads=num_threads,
        min_after_dequeue=min_after_dequeue)

    return images, annotations


filename_queue = tf.train.string_input_producer(
    [tfrecords_filename], num_epochs=10)

# Even when reading in multiple threads, share the filename queue.
image, annotation = read_and_decode(filename_queue)

# The op for initializing the variables. Local variables are initialized as
# well; per the TF1 docs, `num_epochs` makes string_input_producer create a
# local epoch counter.
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

with tf.Session() as sess:
    sess.run(init_op)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        # Let's read off 3 batches just for example
        for _ in range(3):
            img, anno = sess.run([image, annotation])
            print(img[0, :, :, :].shape)
            print('current batch')

            # Show every image/annotation pair of the batch (two pairs with
            # the default batch size) to make sure the order is random.
            for sample_img, sample_anno in zip(img, anno):
                io.imshow(sample_img)
                io.show()
                io.imshow(sample_anno[:, :, 0])
                io.show()
    finally:
        # Always stop and wait for the queue-runner threads, even when the
        # loop above raises, so that no background threads are left running.
        coord.request_stop()
        coord.join(threads)

原创粉丝点击