使用 TensorFlow 读取 SVHN 数据集并转换为 TFRecords 格式

来源:互联网 发布:单片机经典项目 编辑:程序博客网 时间:2024/06/04 23:36

这里默认将 Python 脚本文件和 SVHN 数据集(train_32x32.mat、test_32x32.mat)放在同一目录下,也可以通过 --directory 参数(代码中对应 FLAGS.directory)指定数据集所在的目录。由于 SVHN 没有单独的 validation 数据集,因此从 train 数据中分割出一部分作为 validation。

注:num_sample_size 默认为 10000,读取的样本数可以根据需要自行设置;这里读取训练集时设置为 20000,其中前 15000 个样本用于训练,后 5000 个样本用于验证。

"""Convert the SVHN cropped-digit ``.mat`` files into TFRecords files.

The script reads ``train_32x32.mat`` and ``test_32x32.mat`` from
``--directory``, carves a validation set off the end of the training
samples, and writes ``train.tfrecords``, ``validation.tfrecords`` and
``test.tfrecords`` back into the same directory.

NOTE: this uses the TensorFlow 1.x API (``tf.python_io``, ``tf.app``).
"""
import argparse
import os
import sys

import tensorflow as tf
from scipy.io import loadmat

# Populated by the argument parser in the ``__main__`` guard below.
FLAGS = None


def data_set(data_dir, name, num_sample_size=10000):
    """Load the first ``num_sample_size`` samples of one SVHN split.

    Args:
        data_dir: Directory that contains the ``<name>_32x32.mat`` file.
        name: Split name used as the file prefix, e.g. ``'train'`` or ``'test'``.
        num_sample_size: Maximum number of leading samples to keep.

    Returns:
        A ``(images, labels)`` tuple. ``images`` has the sample axis first
        (``num_samples, rows, cols, depth``); ``labels`` is a flat array in
        which the label value 10 has been replaced by 0.

    Raises:
        ValueError: If the ``.mat`` file does not exist.
    """
    filename = os.path.join(data_dir, name + '_32x32.mat')
    if not os.path.isfile(filename):
        raise ValueError('Please supply the file: {}'.format(filename))

    datadict = loadmat(filename)
    # 'X' is stored with the sample axis last; move it to the front so that
    # images[i] is a single (rows, cols, depth) image.
    images = datadict['X'].transpose((3, 0, 1, 2))
    labels = datadict['y'].flatten()
    # The .mat file uses class 10 for the digit 0; remap it to 0.
    labels[labels == 10] = 0
    return images[:num_sample_size], labels[:num_sample_size]


def _int64_feature(value):
    """Wrap a single integer in an ``int64_list`` ``tf.train.Feature``."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes object in a ``bytes_list`` ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to_tfrecords(images, labels, fileName):
    """Serialize ``images`` and ``labels`` as ``tf.train.Example`` records.

    Each record stores the image dimensions (``height``, ``width``,
    ``depth``), the integer ``label`` and the raw image bytes
    (``image_raw``).

    Args:
        images: 4-D array shaped ``(num_examples, rows, cols, depth)``.
        labels: Sequence with one integer label per image.
        fileName: Path of the TFRecords file to write.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    num_examples, rows, cols, depth = images.shape
    if len(labels) != num_examples:
        raise ValueError(
            'Images size {} does not match labels size {}.'.format(
                num_examples, len(labels)))

    print('Writing', fileName)
    # The context manager guarantees the writer is closed (and the file
    # flushed) even if serializing one of the examples fails.
    with tf.python_io.TFRecordWriter(fileName) as writer:
        for image, label in zip(images, labels):
            # tobytes() replaces the deprecated ndarray.tostring().
            image_raw = image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(label)),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())


def split_dataset(train_x, train_y, validation_size):
    """Split the trailing ``validation_size`` samples off as a validation set.

    Args:
        train_x: Training images, sample axis first.
        train_y: Training labels aligned with ``train_x``.
        validation_size: Number of trailing samples to move to validation.

    Returns:
        A ``(train_x, train_y, valid_x, valid_y)`` tuple.

    Raises:
        ValueError: If ``validation_size`` is negative.
    """
    if validation_size < 0:
        raise ValueError('validation_size must be non-negative.')
    # Use an explicit split index: slicing with ``[:-0]`` would yield an
    # empty training set when validation_size is 0.
    split = max(len(train_x) - validation_size, 0)
    return (train_x[:split],
            train_y[:split],
            train_x[split:],
            train_y[split:])


def main(unused_argv):
    """Load SVHN, split off validation data and write the TFRecords files."""
    train_x, train_y = data_set(
        FLAGS.directory, 'train', num_sample_size=FLAGS.num_sample_size)
    test_x, test_y = data_set(FLAGS.directory, 'test')
    train_x, train_y, valid_x, valid_y = split_dataset(
        train_x, train_y, FLAGS.validation_size)

    train_file = os.path.join(FLAGS.directory, 'train.tfrecords')
    validation_file = os.path.join(FLAGS.directory, 'validation.tfrecords')
    test_file = os.path.join(FLAGS.directory, 'test.tfrecords')

    # Convert to Examples and write the result to TFRecords.
    convert_to_tfrecords(train_x, train_y, train_file)
    convert_to_tfrecords(valid_x, valid_y, validation_file)
    convert_to_tfrecords(test_x, test_y, test_file)
    print('over')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='.',
        help='Directory to download data files and write the converted result'
    )
    parser.add_argument(
        '--validation_size',
        type=int,
        default=5000,
        help="""\
        Number of examples to separate from the training data for the validation
        set.\
        """
    )
    parser.add_argument(
        '--num_sample_size',
        type=int,
        default=20000,
        help='Number of leading training samples to load before the '
             'train/validation split.'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)


0 0