Tensorflow-tfrecord数据

来源:互联网 发布:小米网络电视安装 编辑:程序博客网 时间:2024/05/16 04:35

使用的数据:https://download.pytorch.org/tutorial/hymenoptera_data.zip

1、图像 —> TFRecord

#!/usr/bin/python3
# -*- coding: UTF-8 -*-
"""Convert a directory tree of JPEG images into sharded TFRecord files.

Written against the TensorFlow 1.x graph/session API.  The name of the
directory that contains an image is used as that image's label.
"""
import glob
import os
from collections import defaultdict
from itertools import groupby

import numpy as np
import tensorflow as tf
from PIL import Image

# Paths of every .jpg file located one directory level below the data root.
image_filenames = glob.glob("./hymenoptera_data/*/*.jpg")  # ==> <class 'list'>

sess = tf.InteractiveSession()

training_dataset = defaultdict(list)
testing_dataset = defaultdict(list)


def _label_of(filename):
    """Return the name of the directory containing `filename` (the label)."""
    # os.path handles both "/" and "\\" separators, unlike split("/").
    return os.path.basename(os.path.dirname(filename))


# Pair every path with its label.  The pairs are sorted because glob() returns
# files in arbitrary order: sorting makes the split below reproducible and
# guarantees that groupby() sees each label as a single contiguous run.
image_filename_with_breed = sorted(
    (_label_of(filename), filename) for filename in image_filenames
)

# Group the images by label (element 0 of each pair).
for dog_breed, breed_images in groupby(image_filename_with_breed,
                                       key=lambda pair: pair[0]):
    # Send every fifth image (~20%) of each label to the testing set.
    for i, (_, image_path) in enumerate(breed_images):
        if i % 5 == 0:
            testing_dataset[dog_breed].append(image_path)
        else:
            training_dataset[dog_breed].append(image_path)

    # Check that each label includes at least 18% of the images for testing.
    breed_training_count = len(training_dataset[dog_breed])
    breed_testing_count = len(testing_dataset[dog_breed])
    assert round(breed_testing_count / (breed_training_count + breed_testing_count),
                 2) > 0.18, "Not enough testing images."


def _build_image_pipeline():
    """Build the decode/grayscale/resize graph once; return (input, output).

    The returned placeholder takes the path of a JPEG file; the returned
    tensor is the image as uint8 with shape (250, 151, 1).
    """
    filename_input = tf.placeholder(tf.string, shape=[], name="image_filename")
    image_file = tf.read_file(filename_input)
    image = tf.image.decode_jpeg(image_file)
    # Grayscale reduces computation and memory use; it is not required.
    grayscale_image = tf.image.rgb_to_grayscale(image)
    # Fix the image size to 250x151.
    # Alternative that crops/pads instead of rescaling:
    #   tf.image.resize_image_with_crop_or_pad(grayscale_image, 250, 151)
    resized_image = tf.image.resize_images(grayscale_image, (250, 151))
    # resize_images() returns floats that are still in the 0..255 range (they
    # have not been rescaled to [0, 1)), so cast back to uint8 for storage.
    return filename_input, tf.cast(resized_image, tf.uint8)


# Image ---> TFRecord
def write_records_file(dataset, record_location, images_per_file=10):
    """
    Fill TFRecords files with the images found in `dataset` and include their category.

    Each record holds two bytes features: 'label' (the UTF-8 encoded dataset
    key) and 'image' (raw uint8 pixels of the 250x151 grayscale image).

    Parameters
    ----------
    dataset : dict(list)
      Dictionary with each key being a label for the list of image filenames of its value.
    record_location : str
      Path prefix of the output; files are named
      "<record_location>-<index>.tfrecords".  The parent directory is
      created when it does not exist.
    images_per_file : int, optional
      A new TFRecord file is started every `images_per_file` images so that
      no single file grows large enough to slow down writing.  Images that
      fail to decode still count towards this index.

    Raises
    ------
    ValueError
      If `images_per_file` is smaller than 1.
    """
    if images_per_file < 1:
        raise ValueError("images_per_file must be at least 1")

    # TFRecordWriter does not create missing directories.
    output_dir = os.path.dirname(record_location)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Build the graph once and feed filenames into it.  Creating new ops for
    # every image would grow the graph without bound and slow each run down.
    filename_input, image_tensor = _build_image_pipeline()

    # NOTE: images could also be decoded with PIL (Image.open / convert('L') /
    # resize) and converted through np.array instead of the TensorFlow ops.
    # NOTE: labels are stored as strings here; converting them to an integer
    # index or a one-hot rank-1 tensor (an Int64List feature) is the usually
    # recommended alternative.  https://en.wikipedia.org/wiki/One-hot

    writer = None
    current_index = 0
    try:
        for breed, images_filenames in dataset.items():
            image_label = breed.encode("utf-8")
            for image_filename in images_filenames:
                if current_index % images_per_file == 0:
                    if writer:
                        writer.close()
                    record_filename = "{record_location}-{current_index}.tfrecords".format(
                        record_location=record_location,
                        current_index=current_index)
                    writer = tf.python_io.TFRecordWriter(record_filename)
                current_index += 1

                # Some JPEG files cannot be decoded by TensorFlow; report the
                # file and skip it instead of aborting the whole conversion.
                try:
                    image_bytes = sess.run(
                        image_tensor,
                        feed_dict={filename_input: image_filename}).tobytes()
                except tf.errors.OpError:
                    print(image_filename)
                    continue

                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_label])),
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
                }))
                writer.write(example.SerializeToString())
    finally:
        # Close the last shard even when an unexpected error propagates.
        if writer:
            writer.close()


if __name__ == "__main__":
    write_records_file(training_dataset, "./output/training-images/training-images")
    write_records_file(testing_dataset, "./output/testing-images/testing-images")

2、TFRecord —> numpy

"""Read TFRecord shards back as shuffled batches of numpy arrays.

Written against the TensorFlow 1.x queue-runner input API.
"""
import glob

import tensorflow as tf


# Load images
def load_images_from_tfrecord(
        tfrecord_file,
        class_names=("n02085620-Chihuahua", "n02096051-Airedale")):
    """Build the input pipeline that yields shuffled (image, label) batches.

    Parameters
    ----------
    tfrecord_file : str
      Glob pattern matching the TFRecord files to read (several files may
      match).  Each record must hold the bytes features 'label' and 'image',
      where 'image' is 250 * 151 raw uint8 values.
    class_names : sequence of str, optional
      Label strings in class-index order: a record whose label equals
      `class_names[i]` gets the integer label `i`, any other label gets -1.
      NOTE(review): the default names are ImageNet dog breeds; pass the
      label strings actually stored in the records, otherwise every label
      comes out as -1.

    Returns
    -------
    (float_image_batch, label_batch)
      Tensors for batches of 3 examples: float32 images of shape
      (3, 250, 151, 1) and their integer class indices.
    """
    filename_queue = tf.train.string_input_producer(
        tf.train.match_filenames_once(tfrecord_file))  # may match many files
    reader = tf.TFRecordReader()
    _, serialized = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'image': tf.FixedLenFeature([], tf.string),
        })

    record_image = tf.decode_raw(features['image'], tf.uint8)

    # Changing the image into this shape helps train and visualize the output
    # by converting it to be organized like an image.
    image = tf.reshape(record_image, [250, 151, 1])

    # Label string --> integer class index (-1 when the label is unknown).
    # `index=index` binds the current value; a plain closure would make every
    # branch return the last index.
    label_string = features['label']
    label_cases = [
        (tf.equal(label_string, tf.constant(name)),
         lambda index=index: tf.constant(index))
        for index, name in enumerate(class_names)
    ]
    label = tf.case(label_cases, default=lambda: tf.constant(-1), exclusive=True)

    min_after_dequeue = 10
    batch_size = 3
    capacity = min_after_dequeue + 3 * batch_size
    image_batch, label_batch = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    # Convert the images to floats in [0, 1] to match the input expected by
    # the convolution layers.
    float_image_batch = tf.image.convert_image_dtype(image_batch, tf.float32)

    return float_image_batch, label_batch


if __name__ == "__main__":
    img_batch, label_batch = load_images_from_tfrecord("output/training-images/*.tfrecords")

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # match_filenames_once() stores its result in a local variable.
        tf.local_variables_initializer().run()

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # Pull 100 batches and print every fifth one.
            for step in range(100):
                if coord.should_stop():
                    break
                images, labels = sess.run([img_batch, label_batch])
                if step % 5 == 0:
                    print(images.shape, labels.shape, labels)
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)
原创粉丝点击