Code for generating a TFRecords file (final version, tested and working)

Source: Internet  Editor: 程序博客网  Date: 2024/06/06 13:11

Code first, with usage notes below. Tested and working.


# coding: utf-8
# Written against the TensorFlow 1.x API (tf.Session, tf.python_io, tf.gfile).
import os

import tensorflow as tf

# Use the second GPU only; remove this line on a single-GPU machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


rootdir = "G:\\ZhangSG\\TFRecords\\indoor_scene"
TFRfilename = "G:\\ZhangSG\\TFRecords\\indoor_scene.tfrecords"
defined_label = [
    'airport_inside', 'artstudio', 'auditorium', 'bakery', 'bar', 'bathroom',
    'bedroom', 'bookstore', 'bowling', 'buffet', 'casino', 'children_room',
    'church_inside', 'classroom', 'cloister', 'closet', 'clothingstore',
    'computerroom', 'concert_hall', 'corridor', 'deli', 'dentaloffice',
    'dining_room', 'elevator', 'fastfood_restaurant', 'florist', 'gameroom',
    'garage', 'greenhouse', 'grocerystore', 'gym', 'hairsalon', 'hospitalroom',
    'inside_bus', 'inside_subway', 'jewelleryshop', 'kindergarden', 'kitchen',
    'laboratorywet', 'laundromat', 'library', 'livingroom', 'lobby',
    'locker_room', 'mall', 'meeting_room', 'movietheater', 'museum', 'nursery',
    'office', 'operating_room', 'pantry', 'poolinside', 'prisoncell',
    'restaurant', 'restaurant_kitchen', 'shoeshop', 'stairscase', 'studiomusic',
    'subway', 'toystore', 'trainstation', 'tv_studio', 'videostore',
    'waitingroom', 'warehouse', 'winecellar']


def convert_filename_to_labelID(filename, defined_label):
    """Return the label ID (0 .. category_num - 1), or -1 if no label matches."""
    labelid = -1
    best_len = 0
    for i, label in enumerate(defined_label):
        # Prefer the longest matching label, so that a "restaurant_kitchen"
        # file is not assigned to "kitchen" or "restaurant".
        if label in filename and len(label) > best_len:
            labelid, best_len = i, len(label)
    return labelid


# Build the preprocessing graph once. Creating new ops for every image
# would make the graph grow with each file.
jpeg_bytes = tf.placeholder(tf.string)
img_data_jpg = tf.image.decode_jpeg(jpeg_bytes)
img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.float32)  # values in [0, 1]
resized_image = tf.image.resize_images(img_data_jpg, [200, 200])  # [height, width]
# Scale back to [0, 255]. A plain tf.cast to uint8 would truncate the
# [0, 1] floats to 0 or 1 and store an almost black image.
image_uint8 = tf.image.convert_image_dtype(resized_image, dtype=tf.uint8, saturate=True)

writer = tf.python_io.TFRecordWriter(TFRfilename)
count = 0
with tf.Session() as sess:
    # os.walk yields: 1. the parent directory, 2. subdirectory names (no path), 3. file names
    for parent, dirnames, filenames in os.walk(rootdir):
        for filename in filenames:
            if "jpg" not in filename:
                continue
            labelID = convert_filename_to_labelID(filename, defined_label)
            if labelID < 0:
                continue  # no known label in the filename: skip the file
            image_dir = os.path.join(parent, filename)
            image_raw_data_jpg = tf.gfile.FastGFile(image_dir, 'rb').read()
            image_raw_data = sess.run(
                image_uint8, feed_dict={jpeg_bytes: image_raw_data_jpg}).tobytes()
            if len(image_raw_data) == 0:
                continue
            # Wrap the data as an Example that can be used for training
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(labelID),
                'image_raw': _bytes_feature(image_raw_data),
            }))
            writer.write(example.SerializeToString())
            count = count + 1
            print("Converted %s; %d files written so far" % (filename, count))
writer.close()
print("TFRecord file saved. %d files in total" % count)

To use this code, you need to change a few things:

1. os.environ["CUDA_VISIBLE_DEVICES"] = "1" selects the second GPU (device indices start at 0). If you have only one GPU, remove this line.

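2. rootdir is the folder holding the training-set (or test-set) images, that is, the images to be packed into the tfrecords file. TFRfilename is the path and file name of the tfrecords file to generate. I wrote both as absolute paths; change them to suit your setup.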

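3. Make sure every image's file name contains its label. If it does not, a simple fix in Windows Explorer is to select all files of that category, right-click, choose Rename, and type the category name, e.g. kitchen. The files are then renamed automatically to kitchen (1), kitchen (2), and so on.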

4. Change defined_label to the categories of your own dataset.

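5. In resized_image = tf.image.resize_images(img_data_jpg, [200, 200]), change the two 200s to the height and width you want to resize to. The size argument is [new_height, new_width], in that order; if in doubt, see the previous article, http://blog.csdn.net/zsg2063/article/details/75646394, and resize an image to check the result.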

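6. Some examples write r instead of rb in image_raw_data_jpg = tf.gfile.FastGFile(image_dir, 'rb').read(). In my tests r caused problems, so I recommend rb. I did not investigate the cause; a likely one is that a JPEG is binary data, and tf.image.decode_jpeg needs the raw bytes that binary mode returns.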

7. This code currently handles JPEG files only; I have not looked into PNG files.
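On note 1: CUDA_VISIBLE_DEVICES has to be set before TensorFlow first initialises the GPU, and device indices start at 0. The sketch below is one way to make that choice explicit. select_gpu is a hypothetical helper written for this note, not part of TensorFlow.

```python
import os


def select_gpu(index=None):
    """Restrict CUDA to one GPU, or leave all GPUs visible when index is None.

    Hypothetical helper for illustration; call it before importing tensorflow.
    """
    if index is None:
        # Drop any earlier restriction so every GPU is visible again.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(index)
    return os.environ.get("CUDA_VISIBLE_DEVICES")


# On a single-GPU machine the only valid index is 0.
select_gpu(0)
```

On a single-GPU machine the value "1" points at a device that does not exist, so no GPU is visible and TensorFlow will typically run on the CPU instead.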
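On note 3: the renaming can also be scripted instead of done by hand. This is a minimal sketch, assuming the images are stored as rootdir/<label>/<image files> with one folder per category. prefix_files_with_label is a hypothetical helper, and it renames files in place, so try it on a copy of the data first.

```python
import os


def prefix_files_with_label(rootdir):
    """Rename files so each name starts with its category folder's name.

    Assumes the layout rootdir/<label>/<image files> (an assumption, not
    something the original script requires). Returns the number of renames.
    """
    renamed = 0
    for label in sorted(os.listdir(rootdir)):
        folder = os.path.join(rootdir, label)
        if not os.path.isdir(folder):
            continue
        for name in sorted(os.listdir(folder)):
            if label in name:
                continue  # the file name already carries the label
            src = os.path.join(folder, name)
            dst = os.path.join(folder, label + "_" + name)
            if os.path.isfile(src) and not os.path.exists(dst):
                os.rename(src, dst)
                renamed += 1
    return renamed
```

A file such as kitchen/001.jpg becomes kitchen/kitchen_001.jpg, which the label lookup in the script can then match.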
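On note 4: if each category has its own subfolder, the label list can be read from the folder names rather than typed by hand. A small sketch under that assumption; labels_from_subfolders is a hypothetical name.

```python
import os


def labels_from_subfolders(rootdir):
    """Return the sorted names of rootdir's immediate subfolders.

    Sorting keeps the label IDs (list positions) stable between runs.
    Assumes one subfolder per category, which is an assumption of this sketch.
    """
    return sorted(
        name for name in os.listdir(rootdir)
        if os.path.isdir(os.path.join(rootdir, name))
    )
```

The result could be assigned to defined_label. Keep a copy of the list with the tfrecords file, since the stored label IDs are only meaningful together with it.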
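On note 6: the difference between 'rb' and 'r' can be reproduced without TensorFlow. The sketch below uses Python 3's built-in open as a stand-in for tf.gfile.FastGFile, which is an assumption: the exact failure inside TensorFlow may differ by version. It writes the first bytes of a JPEG to a temporary file and reads them back in both modes.

```python
import os
import tempfile

# First bytes of a JPEG file (start-of-image marker), used as sample binary data.
jpeg_header = b"\xff\xd8\xff\xe0"

with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
    tmp.write(jpeg_header)
    path = tmp.name

# Binary mode returns the bytes exactly as stored.
with open(path, "rb") as f:
    raw = f.read()

# Text mode tries to decode the bytes as text, which fails for image data.
try:
    with open(path, "r", encoding="utf-8") as f:
        f.read()
    text_mode_ok = True
except UnicodeDecodeError:
    text_mode_ok = False

os.remove(path)
print(raw == jpeg_header, text_mode_ok)
```

Binary mode hands back the original bytes, while text mode raises UnicodeDecodeError because 0xFF is not valid UTF-8.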
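On note 7: one possible starting point for PNG support is to detect the format from the file's first bytes and pick the decoder accordingly (tf.image.decode_png is TensorFlow's counterpart to tf.image.decode_jpeg). The sketch below covers only the detection step. sniff_image_format is a hypothetical helper, and the PNG path has not been tested with the script above.

```python
JPEG_MAGIC = b"\xff\xd8\xff"        # JPEG start-of-image marker
PNG_MAGIC = b"\x89PNG\r\n\x1a\n"    # fixed 8-byte PNG signature


def sniff_image_format(data):
    """Return "jpeg", "png", or None based on the leading bytes of data."""
    if data.startswith(JPEG_MAGIC):
        return "jpeg"
    if data.startswith(PNG_MAGIC):
        return "png"
    return None
```

Checking the content this way is more reliable than testing for "jpg" in the file name, which misses names such as photo.jpeg or PHOTO.JPG.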
