AI Challenger 场景分类(1):生成 TFRecord 文件

来源:互联网 发布:php部署到apache 编辑:程序博客网 时间:2024/05/16 07:01

用时:30 min
原图大小:3.5 G
TFRecord 文件大小:65.3 G(amazing! 注意原图是 JPEG 压缩的,而这里写入的是解码后的原始像素,没有任何压缩,所以文件会大得多)

# -*- coding: utf-8 -*-
"""Write the AI Challenger scene-classification training set to a TFRecord file.

Created on Thu Sep  7 19:25:38 2017
@author: wayne
Reference: http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html

Each example stores the integer label, the image height/width/channel count
and the decoded uint8 pixel buffer. Pixels are stored uncompressed, which is
why the resulting file is much larger than the JPEG source images.
"""
import json
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import datetime

record_PATH = 'ai_challenger_scene_train_20170904/'   # target directory
tfrecord_file = record_PATH + 'train.tfrecord'


def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes string in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def get_image_binary(filename):
    """Decode an image file and return ``(shape, raw_bytes)``.

    You can read in the image using TensorFlow too, but it's a drag since
    you have to create graphs. It's much easier using Pillow and NumPy.

    Returns:
        shape: int32 NumPy array holding the pixel array's shape.
        raw_bytes: the uint8 pixel buffer as ``bytes``.
    """
    # Close the file handle promptly; this runs once per training image.
    with Image.open(filename) as img:
        pixels = np.asarray(img, np.uint8)
    shape = np.array(pixels.shape, np.int32)
    return shape, pixels.tobytes()  # convert image to raw data bytes in the array.


def write_to_tfrecord(label, shape, binary_image, tfrecord_file):
    """Serialize one sample and append it through the module-level ``writer``.

    Args:
        label: integer class id.
        shape: image shape as returned by ``get_image_binary``.
        binary_image: raw uint8 pixel bytes.
        tfrecord_file: unused; kept for backward compatibility. The record
            goes to whatever file the module-level ``writer`` was opened on.
    """
    # A single-channel image decodes to a 2-D array; record c=1 for it
    # instead of failing with IndexError on shape[2].
    channels = shape[2] if len(shape) > 2 else 1
    # write label, shape, and image content to the TFRecord file
    example = tf.train.Example(features=tf.train.Features(feature={
        'label': _int64_feature(label),
        'h': _int64_feature(shape[0]),
        'w': _int64_feature(shape[1]),
        'c': _int64_feature(channels),
        'image': _bytes_feature(binary_image),
    }))
    writer.write(example.SerializeToString())


def write_tfrecord(label, image_file, tfrecord_file):
    """Decode *image_file* and append it, tagged with *label*, to the output."""
    shape, binary_image = get_image_binary(image_file)
    write_to_tfrecord(label, shape, binary_image, tfrecord_file)


with open('ai_challenger_scene_train_20170904/scene_train_annotations_20170904.json', 'r') as f:  # label file
    label_raw = json.load(f)


def file_name2(file_dir):
    """Recursively collect files with a ``.jpg`` extension under *file_dir*.

    Returns:
        (paths, names): full paths and bare file names, in matching order.
    """
    paths = []
    names = []
    for root, _dirs, files in os.walk(file_dir):
        for name in files:
            if os.path.splitext(name)[1] == '.jpg':
                paths.append(os.path.join(root, name))
                names.append(name)
    return paths, names


path, image = file_name2('ai_challenger_scene_train_20170904/scene_train_images_20170904')  # image directory

# Write everything to the TFRecord file.
# Map image file name -> integer label id.
label = {item['image_id']: int(item['label_id']) for item in label_raw}

starttime = datetime.datetime.now()
# Long-running. ``writer`` is bound at module level so write_to_tfrecord()
# can see it; the with-statement flushes and closes the file even if an
# image fails to decode or has no label.
with tf.python_io.TFRecordWriter(tfrecord_file) as writer:
    for i, (image_path, image_name) in enumerate(zip(path, image)):
        write_tfrecord(label[image_name], image_path, tfrecord_file)
        if i % 1000 == 0:
            print(i)
endtime = datetime.datetime.now()
# Parenthesize the whole expression so this also works under Python 3;
# total_seconds() additionally counts whole days, unlike .seconds.
print(int((endtime - starttime).total_seconds()))
阅读全文
0 0