Reading data in TensorFlow with Dataset


I. Three ways to read data into TensorFlow
1 Feeding: Python code provides the data when running each step.
2 Reading from files: an input pipeline reads the data from files at the beginning of a TensorFlow graph.
3 Preloaded data: a constant or variable in the TensorFlow graph holds all the data (for small data sets).
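
To make the difference between the first and third approaches concrete, here is a minimal sketch. It assumes TensorFlow 1.14 or later (including 2.x) and goes through `tf.compat.v1` so that the graph-mode calls run on either; that choice, and the toy array `data`, are assumptions for illustration rather than part of the original notes.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TensorFlow >= 1.14 or 2.x

# Toy data set invented for this sketch: three rows of two values.
data = np.arange(6, dtype=np.float32).reshape(3, 2)

with tf.Graph().as_default():
    # Feeding: a placeholder receives its data from Python at each run call.
    x_fed = tf1.placeholder(tf.float32, shape=[None, 2])
    sum_fed = tf.reduce_sum(x_fed, axis=1)

    # Preloaded data: a constant embeds the whole (small) data set in the graph.
    x_const = tf.constant(data)
    sum_const = tf.reduce_sum(x_const, axis=1)

    with tf1.Session() as sess:
        fed_result = sess.run(sum_fed, feed_dict={x_fed: data})
        const_result = sess.run(sum_const)

print(fed_result.tolist())    # [1.0, 5.0, 9.0]
print(const_result.tolist())  # [1.0, 5.0, 9.0]
```

Both paths compute the same row sums; the only difference is whether the data enters through `feed_dict` at run time or is stored inside the graph.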
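II. Dataset
1 The Dataset API belongs to the second approach (reading from files). It makes reading data and applying complex format transformations easier.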
2 A tf.data.Dataset represents a sequence of elements, in which each element contains one or more tf.Tensor objects. For example, in an image pipeline, an element might be a single training example, with a pair of tensors representing the image data and a label. There are two distinct ways to create a dataset: from a source such as Dataset.from_tensor_slices(), or by applying a transformation such as Dataset.batch() to one or more existing datasets.
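3 A source dataset can be created from tensors or from files
dataset1 = tf.data.Dataset.from_tensor_slices(tensors) creates a dataset from tensors
dataset1 = tf.contrib.data.TextLineDataset(src_file) creates a dataset from the lines of a text file (from TensorFlow 1.4 onward the same class is available as tf.data.TextLineDataset)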
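
The two kinds of source can be compared side by side. The sketch below is an illustration only: it assumes TensorFlow 1.14 or later (including 2.x), uses `tf.data.TextLineDataset` rather than the older `tf.contrib.data` path, and writes a hypothetical two-line file to a temporary directory so that `src_file` points at something real.

```python
import os
import tempfile
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TensorFlow >= 1.14 or 2.x

# Hypothetical input file, created here only to keep the sketch self-contained.
src_file = os.path.join(tempfile.mkdtemp(), "lines.txt")
with open(src_file, "w") as f:
    f.write("first line\nsecond line\n")

with tf.Graph().as_default():
    # From in-memory tensors: one element per slice along the first dimension.
    from_tensors = tf.data.Dataset.from_tensor_slices([10, 20, 30])
    # From a text file: one string element per line.
    from_file = tf.data.TextLineDataset(src_file)

    next_number = tf1.data.make_one_shot_iterator(from_tensors).get_next()
    next_line = tf1.data.make_one_shot_iterator(from_file).get_next()

    with tf1.Session() as sess:
        numbers = [int(sess.run(next_number)) for _ in range(3)]
        # String elements come back as bytes, so decode them for display.
        lines = [sess.run(next_line).decode("utf-8") for _ in range(2)]

print(numbers)  # [10, 20, 30]
print(lines)    # ['first line', 'second line']
```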
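4 Main Dataset APIs, used mostly for transforming data
tf.data.Dataset.zip combines several datasets element by element
dataset1.map applies a function to every element
dataset1.padded_batch groups elements into batches, padding them to a common shape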
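
Padding is easiest to see with elements of different lengths. The following sketch chains the three transformations above on a toy dataset; the sequences, the variable names, and the use of `tf.compat.v1` (assuming TensorFlow 1.14 or later, including 2.x) are all choices made for this illustration.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TensorFlow >= 1.14 or 2.x

with tf.Graph().as_default():
    lengths = tf.data.Dataset.range(1, 4)           # 1, 2, 3
    seqs = lengths.map(lambda n: tf.range(n))       # [0], [0 1], [0 1 2]

    # zip pairs each sequence with its length.
    pairs = tf.data.Dataset.zip((seqs, lengths))
    # map receives one argument per component of the element.
    pairs = pairs.map(lambda seq, n: (seq + 1, n))  # [1], [1 2], [1 2 3]
    # padded_batch pads each sequence with zeros up to the longest in its batch;
    # the scalar length component needs no padding, hence the empty shape.
    batches = pairs.padded_batch(2, padded_shapes=([None], []))

    next_batch = tf1.data.make_one_shot_iterator(batches).get_next()

    with tf1.Session() as sess:
        first_seqs, first_lengths = sess.run(next_batch)
        second_seqs, second_lengths = sess.run(next_batch)

print(first_seqs.tolist(), first_lengths.tolist())    # [[1, 0], [1, 2]] [1, 2]
print(second_seqs.tolist(), second_lengths.tolist())  # [[1, 2, 3]] [3]
```

The last batch holds only one element because the dataset has three elements and the batch size is two.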
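5 Creating an iterator
dataset1.make_initializable_iterator()
6 Usage flow
Construct a Dataset object
Create an iterator
Run the iterator's initializer, then run get_next() in a session to read elements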
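
The flow can be sketched end to end as follows. This is an illustration under assumptions: it targets TensorFlow 1.14 or later (including 2.x) through `tf.compat.v1`, where `tf.compat.v1.data.make_initializable_iterator(dataset)` plays the role of `dataset.make_initializable_iterator()`; the placeholder `limit` and the helper `read_all` are hypothetical names introduced here.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API; assumes TensorFlow >= 1.14 or 2.x

with tf.Graph().as_default():
    # Step 1: construct a Dataset object, parameterised by a placeholder.
    limit = tf1.placeholder(tf.int64, shape=[])
    dataset = tf.data.Dataset.range(limit)

    # Step 2: create an iterator and the op that yields the next element.
    iterator = tf1.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    def read_all(sess, n):
        """Initialise the iterator with limit=n and read until exhausted."""
        sess.run(iterator.initializer, feed_dict={limit: n})
        values = []
        while True:
            try:
                values.append(int(sess.run(next_element)))
            except tf.errors.OutOfRangeError:
                # Raised once the dataset has no more elements.
                return values

    # Step 3: initialise and read inside a session.
    with tf1.Session() as sess:
        first_pass = read_all(sess, 3)
        second_pass = read_all(sess, 5)

print(first_pass)   # [0, 1, 2]
print(second_pass)  # [0, 1, 2, 3, 4]
```

An initializable iterator can be re-initialised with different fed values, which is why the same graph serves both passes here.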
III. Code snippet

import tensorflow as tf

# An element contains one or more tf.Tensor objects, called components
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
print("dataset1.output_types", dataset1.output_types)
print("dataset1.output_shapes", dataset1.output_shapes)

dataset2 = tf.data.Dataset.from_tensor_slices(
    {"a": tf.random_uniform([4]),
     "b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print("dataset2.output_types", dataset2.output_types)
print("dataset2.output_shapes", dataset2.output_shapes)

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
print("dataset3.output_types", dataset3.output_types)
print("dataset3.output_shapes", dataset3.output_shapes)

# map applies a function to each element; the element structure determines the arguments of the function
dataset1 = dataset1.map(lambda x: x + 1)
# dataset1 = dataset1.padded_batch(2, padded_shapes=[11])
dataset1 = dataset1.padded_batch(2, padded_shapes=[None])

iterator = dataset1.make_initializable_iterator()
next_element = iterator.get_next()
init_op = iterator.initializer

with tf.Session() as sess:
    # The initializer returns nothing useful, so run it without printing
    sess.run(init_op)
    print("batched data 1:", sess.run(next_element))
    print("batched data 2:", sess.run(next_element))