Reading the SVHN dataset with TensorFlow and converting it to TFRecords format
Source: Internet | Published by: 单片机经典项目 | Editor: 程序博客网 | Date: 2024/06/04 23:36
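By default, the Python script and the SVHN dataset files are assumed to sit in the same directory; the FLAGS.directory argument can point to a different dataset directory. SVHN does not ship a validation set, so part of the train set is split off and used as validation.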
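Note: num_sample_size defaults to 10000. You can choose the number of training samples yourself; here I set it to 20000, with the first 15000 used for training and the last 5000 for validation.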
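The split described in the note can be checked with a small sketch on synthetic arrays. The helper name `split_tail` and the tiny 4x4 placeholder images are assumptions made for illustration only; the slicing itself mirrors the tail split used by the script.

```python
import numpy as np


def split_tail(x, y, validation_size):
    """Keep the last validation_size samples for validation, the rest for training."""
    return (x[:-validation_size], y[:-validation_size],
            x[-validation_size:], y[-validation_size:])


# Synthetic stand-ins for SVHN images and labels (hypothetical data, not the real set).
# Tiny 4x4 images keep the demo light; real SVHN crops are 32x32x3.
num_sample_size = 20000
images = np.zeros((num_sample_size, 4, 4, 3), dtype=np.uint8)
labels = np.arange(num_sample_size) % 10

train_x, train_y, valid_x, valid_y = split_tail(images, labels, 5000)
print(train_x.shape[0], valid_x.shape[0])  # 15000 5000
```

This slicing assumes a positive validation_size: with 0, `x[:-0]` is an empty slice, so the training set would end up empty.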
import argparse
import os
import sys

import tensorflow as tf  # the script uses the TensorFlow 1.x API
from scipy.io import loadmat

FLAGS = None


def data_set(data_dir, name, num_sample_size=10000):
    """Load the first num_sample_size images and labels from <name>_32x32.mat."""
    filename = os.path.join(data_dir, name + '_32x32.mat')
    if not os.path.isfile(filename):
        raise ValueError('Please supply the file ' + filename)
    datadict = loadmat(filename)
    # X is stored as (height, width, channels, N); move N to the front.
    train_x = datadict['X'].transpose((3, 0, 1, 2))
    train_y = datadict['y'].flatten()
    # SVHN stores the digit 0 as label 10; map it back to 0.
    train_y[train_y == 10] = 0
    return train_x[:num_sample_size], train_y[:num_sample_size]


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to_tfrecords(images, labels, fileName):
    """Write every (image, label) pair to fileName as one tf.train.Example."""
    num_examples, rows, cols, depth = images.shape
    print('Writing', fileName)
    writer = tf.python_io.TFRecordWriter(fileName)
    for index in range(num_examples):
        # tobytes() replaces the deprecated tostring().
        image_raw = images[index].tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _int64_feature(int(labels[index])),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
    writer.close()


def split_dataset(train_x, train_y, validation_size):
    """Use the last validation_size samples as the validation set."""
    return (train_x[:-validation_size], train_y[:-validation_size],
            train_x[-validation_size:], train_y[-validation_size:])


def main(unused_argv):
    # num_sample_size belongs to data_set, not to convert_to_tfrecords.
    train_x, train_y = data_set(FLAGS.directory, 'train', num_sample_size=20000)
    test_x, test_y = data_set(FLAGS.directory, 'test')
    train_x, train_y, valid_x, valid_y = split_dataset(
        train_x, train_y, FLAGS.validation_size)

    trainFileName = os.path.join(FLAGS.directory, 'train.tfrecords')
    validationFileName = os.path.join(FLAGS.directory, 'validation.tfrecords')
    testFileName = os.path.join(FLAGS.directory, 'test.tfrecords')

    # Convert to Examples and write the result to TFRecords.
    # The validation split goes to validation.tfrecords and the test set to
    # test.tfrecords (the original listing had these two swapped).
    convert_to_tfrecords(train_x, train_y, trainFileName)
    convert_to_tfrecords(valid_x, valid_y, validationFileName)
    convert_to_tfrecords(test_x, test_y, testFileName)
    print('over')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory',
        type=str,
        default='.',
        help='Directory that holds the data files and receives the converted result'
    )
    parser.add_argument(
        '--validation_size',
        type=int,
        default=5000,
        help='Number of examples to separate from the training data '
             'for the validation set.'
    )
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
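The image_raw feature holds only raw bytes, so the array shape and dtype are not carried along with it; that is presumably why the script also stores height, width and depth. The following is a minimal NumPy-only sketch of that round trip on synthetic data. It assumes uint8 pixels and uses a random array in place of the real .mat contents; it does not read or write TFRecords files.

```python
import numpy as np

# Synthetic batch in the SVHN on-disk layout: (height, width, channels, N).
# Random uint8 values stand in for real pixels (an assumption for this demo).
rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(32, 32, 3, 5), dtype=np.uint8)

# Same transpose as the script: move the sample axis to the front.
images = raw.transpose((3, 0, 1, 2))
rows, cols, depth = images.shape[1:]

# Serialize one image the way the script does, then restore it
# from the bytes plus the separately stored height/width/depth.
image_raw = images[0].tobytes()
restored = np.frombuffer(image_raw, dtype=np.uint8).reshape(rows, cols, depth)

print(len(image_raw))  # 32 * 32 * 3 = 3072 bytes
print(np.array_equal(restored, images[0]))
```

A reader of the TFRecords files would need the same dtype and the three stored dimensions to rebuild each image in the same way.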