7.2 TensorFlow笔记(基础篇): 生成TFRecords文件
来源:互联网 发布:天津网络推广优化 编辑:程序博客网 时间:2024/05/17 23:27
前言
在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步:
1. 把样本数据写入TFRecords二进制文件
2. 从队列中读取
TFRecords二进制文件,能够更好的利用内存,更方便的移动和复制,并且不需要单独的标记文件
下面官网给出的,对mnist文件进行操作的code,具体代码请参考:tensorflow-master\tensorflow\examples\how_tos\reading_data\convert_to_records.py
CODE
源码与解析
解析主要在注释里
import tensorflow as tfimport osimport argparseimport sysos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'#1.0 生成TFRecords 文件from tensorflow.contrib.learn.python.learn.datasets import mnistFLAGS = None# 编码函数如下:def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def convert_to(data_set, name): """Converts a dataset to tfrecords.""" images = data_set.images labels = data_set.labels num_examples = data_set.num_examples if images.shape[0] != num_examples: raise ValueError('Images size %d does not match label size %d.' % (images.shape[0], num_examples)) rows = images.shape[1] # 28 cols = images.shape[2] # 28 depth = images.shape[3] # 1. 是黑白图像,所以是单通道 filename = os.path.join(FLAGS.directory, name + '.tfrecords') print('Writing', filename) writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples): image_raw = images[index].tostring() # 写入协议缓存区,height,width,depth,label编码成int64类型,image_raw 编码成二进制 example = tf.train.Example(features=tf.train.Features(feature={ 'height': _int64_feature(rows), 'width': _int64_feature(cols), 'depth': _int64_feature(depth), 'label': _int64_feature(int(labels[index])), 'image_raw': _bytes_feature(image_raw)})) writer.write(example.SerializeToString()) # 序列化为字符串 writer.close()def main(unused_argv): # Get the data. data_sets = mnist.read_data_sets(FLAGS.directory, dtype=tf.uint8, reshape=False, validation_size=FLAGS.validation_size) # Convert to Examples and write the result to TFRecords. convert_to(data_sets.train, 'train') convert_to(data_sets.validation, 'validation') convert_to(data_sets.test, 'test')if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '--directory', type=str, default='MNIST_data/', help='Directory to download data files and write the converted result' ) parser.add_argument( '--validation_size', type=int, default=5000, help="""\ Number of examples to separate from the training data for the validation set.\ """ ) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
运行结果
打印输出
Extracting MNIST_data/train-images-idx3-ubyte.gzExtracting MNIST_data/train-labels-idx1-ubyte.gzExtracting MNIST_data/t10k-images-idx3-ubyte.gzExtracting MNIST_data/t10k-labels-idx1-ubyte.gzWriting MNIST_data/train.tfrecordsWriting MNIST_data/validation.tfrecordsWriting MNIST_data/test.tfrecords
文件
相关
- argparse是python用于解析命令行参数和选项的标准模块,用于代替已经过时的optparse模块。argparse模块的作用是用于解析命令行参数,详情请参见这里:python中的argparse模块:http://blog.csdn.net/fontthrone/article/details/76735591
- 把样本数据写入TFRecords二进制文件 : http://blog.csdn.net/fontthrone/article/details/76727412
- TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据:http://blog.csdn.net/fontthrone/article/details/76727466
- TensorFlow笔记(基础篇):加载数据之从队列中读取:http://blog.csdn.net/fontthrone/article/details/76728083
阅读全文
1 0
- 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件
- TensorFlow .tfrecords训练文件生成、使用
- Tensorflow 生成tfrecords
- tensorflow将CSV文件转为TFrecords文件
- TensorFlow基础5:TFRecords文件的存储与读取讲解及代码实现
- TFRecords 文件的生成和读取
- 生成tfrecords文件(29)---《深度学习》
- tensorflow中tfrecords文件的save和read
- TensorFlow读取tfrecords数据
- Tensorflow 构建 TFrecords
- 生成TFRecords文件代码(最终版,亲测可用)
- tensorflow数据读取之tfrecords
- tensorflow中tfrecords使用介绍
- Tensorflow分批量读取tfrecords
- Tensorflow使用笔记(2): 如何构建TFRecords并进行Mini Batch训练
- tensorflow:使用tfrecords时的注意事项
- tensorflow系列(4)tfrecords的使用
- tensorflow中tfrecords格式的读写
- Struts2 官方教程:使用标签
- Qt学习笔记day01
- Windows下搭建Apache, PHP, MySQL (试了一下, 靠谱, 写得非常清晰, 赞一个!)
- [SMOJ1762&2136]放假/假期
- 用 TestPMD 测试 DPDK 性能和功能
- 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件
- POJ 2230 Watchcow (欧拉路径 dfs 邻接表)
- 给自己的信
- 分类算法之决策树(Decision tree)
- Python3.6安装NLTK
- find命令
- 客户端状态的存储空间——Session
- keepalived+nginx实现高可用(三)
- IO流学习总结