TensorFlow学习之tfrecords_reader

来源:互联网 发布:org.apache.spark jar 编辑:程序博客网 时间:2024/06/07 06:22
# Build the queue of input file names consumed by read_and_decode().
# NOTE(review): `filename` and `num_epochs` are expected to be defined by the
# enclosing code (not visible here); per the TF 1.x docs, num_epochs limits
# how many times the name is emitted before OutOfRange is raised.
filename_queue = tf.train.string_input_producer(
    [filename],
    num_epochs=num_epochs,
)

def read_and_decode(filename_queue):
    """Read one serialized example from a TFRecord file queue and decode it.

    Uses the TF 1.x reader API (``tf.TFRecordReader`` /
    ``tf.parse_single_example``).

    Args:
        filename_queue: queue of TFRecord file names, e.g. the one produced by
            ``tf.train.string_input_producer``.

    Returns:
        A ``(image, label)`` tuple: ``image`` is a float32 tensor of shape
        ``[mnist.IMAGE_PIXELS]`` whose values are rescaled from [0, 255] to
        [-0.5, 0.5]; ``label`` is an int32 scalar tensor.
    """
    reader = tf.TFRecordReader()
    # read() yields a (key, value) pair; only the serialized record is needed.
    # (No trailing comma here: it would wrap the pair in a 1-tuple and break
    # the unpacking.)
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        },
    )

    # Convert from a scalar string tensor (whose single string has
    # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
    # [mnist.IMAGE_PIXELS].
    # NOTE(review): `mnist` must be imported by the enclosing module
    # (not visible here).
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image.set_shape([mnist.IMAGE_PIXELS])

    # OPTIONAL: Could reshape into a 28x28 image and apply distortions
    # here. Since we are not applying any distortions in this
    # example, and the next step expects the image to be flattened
    # into a vector, we don't bother.

    # Convert from [0, 255] -> [-0.5, 0.5] floats.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

    # Convert label from a scalar int64 tensor (as parsed above) to an
    # int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label
0 0