tensorflow框架输入数据格式
来源:互联网 发布:经传证券炒股软件 编辑:程序博客网 时间:2024/05/19 03:45
关于TensorFlow的输入数据格式已经研究了几天了,官方的格式是TFRecords格式,但是依然很困惑,仍然不知道用于自己的数据怎么搞。
下面总结一下使用自己的图片的数据格式代码。
1. 写入TFRecord文件
首先,将自己的图片统一格式,写入到 TFRecord 文件中。其中images保存图片的路径,label为相应图片的标签。代码如下:
#coding=utf-8import tensorflow as tfimport cv2filename = "train.tfrecords" #要存入的文件writer = tf.python_io.TFRecordWriter(filename)for i in range(len(images)): image = tf.gfile.FastGFile(images[i], 'rb').read() #image type is string label = label[i] image_shape = cv2.imread(images[i]).shape width = image_shape[1] height = image_shape[0] example = tf.train.Example( features = tf.train.Features(feature = { 'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])), 'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])), 'height': tf.train.Feature(int64_list = tf.train.Int64List(value = [height])), 'width': tf.train.Feature(int64_list = tf.train.Int64List(value = [width])), }) ) #将Example写入TFRecord文件 writer.write(example.SerializeToString())
以上是将所有图片数据存储到一个 TFRecord文件中,当数据量比较大时,也可以将图片数据写入到多个TFRecord文件中。
2. 读取TFRecord文件
首先读取TFRecord文件,对读取到TFRecords文件进行解码,根据保存的serialize example和feature字典返回feature所对应的值。此时获得的值都是string,
需要进一步解码为所需的数据类型。下面是读取文件的代码。
import tensorflow as tfimport matplotlib.pyplot as pltfilename = "train.tfrecords"filename_queue = tf.train.string_input_producer([filename], shuffle = False)reader = tf.TFRecordReader()# read in serialized example data_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example( serialized_example, features = { 'image' : tf.FixedLenFeature([], tf.string), 'label' : tf.FixedLenFeature([], tf.int64), 'height' : tf.FixedLenFeature([], tf.int64), 'width' : tf.FixedLenFeature([], tf.int64) })image = tf.image.decode_jpeg(features['image'])image = tf.image.convert_image_dtype(image, dtype=tf.float32)label = tf.cast(features['label'], tf.int64)height = tf.cast(features['height'], tf.int32)width = tf.cast(features['width'], tf.int32)image = tf.reshape(image, [height, width, 3])image = tf.image.resize_images(image, [512, 512])image_batch, label_batch = tf.train.shuffle_batch( [image, label], batch_size = 10, capacity = 10, min_after_dequeue = 10)with tf.Session() as sess: tf.initialize_all_variables().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess = sess, coord=coord)
for i in range(TRAINING_ROUDNS): ... coord.request_stop() coord.join(threads) sess.close()
阅读全文
0 0
- tensorflow框架输入数据格式
- tensorflow--TFRecord输入数据格式
- Tensorflow基础:多线程输入数据处理框架
- Tensorflow学习笔记-输入数据处理框架
- 【TensorFlow】数据处理(输入数据处理框架)
- imshow的输入数据格式
- 页面输入的数据格式转换类:BaseAction(常用于Struts框架中)
- Tensorflow持久化原理及数据格式
- 输入框输入数据格式合法性验证
- tensorflow的数据输入
- tensorflow的输入
- tensorflow数据输入
- tensorFlow数据输入
- tensorflow mnist实战笔记(一)了解官方mnist数据格式
- TensorFlow的代码框架
- TensorFlow深度学习框架
- TensorFlow深度学习框架
- TensorFlow深度学习框架
- 使用过滤器对象来对登录成功界面设置权限保护
- Activiti工作流数据库表详细介绍(23张表)
- Map学习
- 带头节点单链表操作
- elasticsearch httpclient认证机制
- tensorflow框架输入数据格式
- 1, unicode,窗口,消息
- 组织在项目管理过程中的影响
- SlidingMenu,Tablayout,ViewPager,Fragment结合
- “NetworkOnMainThreadException”异常
- 面向切面编程(AOP)
- SQL Server 游标基础使用
- spring配置JDBCTemplate
- 面试题06:一串英文数字转换成阿拉伯数字