Lesson 5: Reading Data with TensorFlow TFRecord

Source: Internet   Editor: 程序博客网   Time: 2024/05/17 01:38


You can feed TensorFlow data in common formats such as CSV, but the TFRecord format is a better choice:

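  1. Records are stored as binary protobuf, which is faster to transfer.
  2. The structure is uniform, which hides the underlying data layout from the code that reads it.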
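On disk, a TFRecord file is a sequence of length-prefixed, checksummed binary records. In this lesson each record holds one serialized Example protobuf. The following pure-Python sketch writes and reads that framing, assuming the layout given in the TensorFlow TFRecord documentation: an 8-byte little-endian length, a masked CRC-32C of the length, the payload, and a masked CRC-32C of the payload. The function names (`crc32c`, `masked_crc`, `write_record`, `read_records`) are hypothetical and only meant to show how the format works. The lesson code itself uses `tf.python_io.TFRecordWriter` and `tf.TFRecordReader`.

```python
import io
import struct


def _crc32c_table():
    # Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    # TFRecord stores a "masked" CRC: rotate right by 15 bits, then add a constant
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f, data: bytes) -> None:
    header = struct.pack("<Q", len(data))           # uint64 length, little-endian
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))  # uint32 masked CRC of the length
    f.write(data)                                   # payload, e.g. a serialized Example
    f.write(struct.pack("<I", masked_crc(data)))    # uint32 masked CRC of the payload


def read_records(f):
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", f.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        data = f.read(length)
        (data_crc,) = struct.unpack("<I", f.read(4))
        if data_crc != masked_crc(data):
            raise ValueError("corrupted record payload")
        yield data


buf = io.BytesIO()
for payload in [b"first record", b"second"]:
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # [b'first record', b'second']
```

Each record adds 16 bytes of framing. In exchange, the checksums let a reader detect a truncated or corrupted file instead of silently parsing garbage.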
# This lesson uses the TensorFlow 1.x API (tf.python_io, queue runners, tf.Session)
import tensorflow as tf
import numpy as np
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import pandas as pd

plt.rcParams["figure.figsize"] = (20, 10)
train_df = pd.read_csv('train.csv')
display(train_df.head())
  label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 pixel8 … pixel774 pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 pixel781 pixel782 pixel783
0     1      0      0      0      0      0      0      0      0      0 …        0        0        0        0        0        0        0        0        0        0
1     0      0      0      0      0      0      0      0      0      0 …        0        0        0        0        0        0        0        0        0        0
2     1      0      0      0      0      0      0      0      0      0 …        0        0        0        0        0        0        0        0        0        0
3     4      0      0      0      0      0      0      0      0      0 …        0        0        0        0        0        0        0        0        0        0
4     0      0      0      0      0      0      0      0      0      0 …        0        0        0        0        0        0        0        0        0        0

5 rows × 785 columns

# Split off the label column; the remaining 784 columns are the pixels
label_df = train_df.pop(item='label')
train_values = train_df.values
train_labels = label_df.values
display(type(train_values))
display(train_values.shape)
display(type(train_labels))
display(train_labels.shape)
numpy.ndarray
(42000, 784)
numpy.ndarray
(42000,)
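Check the dtype before these rows are serialized. pandas typically parses integer CSV columns as int64, so a row of `train_values` becomes 784 × 8 = 6272 bytes when turned into raw bytes. The reading code decodes `image_raw` with `tf.decode_raw(..., tf.uint8)`, which treats every byte as one pixel. The following NumPy sketch illustrates the mismatch using a made-up row (the name `row` is hypothetical) and `np.frombuffer` as a stand-in for `tf.decode_raw`. It also shows that casting to uint8 first keeps one byte per pixel. The cast assumes pixel values lie between 0 and 255, as in this MNIST-style digit data.

```python
import numpy as np

# Hypothetical stand-in for one CSV row: 784 pixel values between 0 and 255,
# stored as int64 the way pandas typically parses integer columns.
row = np.arange(784, dtype=np.int64) % 256

as_int64 = row.tobytes()                   # 8 bytes per pixel
as_uint8 = row.astype(np.uint8).tobytes()  # 1 byte per pixel

# Reinterpret every byte as one uint8 value, as decode_raw(..., tf.uint8) does
decoded_wrong = np.frombuffer(as_int64, dtype=np.uint8)
decoded_right = np.frombuffer(as_uint8, dtype=np.uint8)

print(len(as_int64), len(as_uint8))              # 6272 784
print(decoded_wrong.shape, decoded_right.shape)  # (6272,) (784,)
print(np.array_equal(decoded_right, row))        # True
```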

The Example protobuf (tf.train.Example) is defined as:

message Example {
  Features features = 1;
}

message Features {
  map<string, Feature> feature = 1;
}

message Feature {
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}
# Create a TFRecord writer
writer = tf.python_io.TFRecordWriter('csv_train.tfrecords')
for i in range(train_values.shape[0]):
    # Pixels are 0-255: cast to uint8 so the bytes match tf.decode_raw(..., tf.uint8) when reading
    image_raw = train_values[i].astype(np.uint8).tobytes()
    # Build an Example protobuf
    example = tf.train.Example(
        features=tf.train.Features(feature={
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(train_labels[i])]))
            }))
    writer.write(record=example.SerializeToString())
writer.close()

Reading data from the TFRecord file

reader = tf.TFRecordReader()
filename_queue = tf.train.string_input_producer(['csv_train.tfrecords'])
_, serialized_record = reader.read(filename_queue)
features = tf.parse_single_example(serialized_record,
    features={
        # tf.FixedLenFeature returns a Tensor
        # tf.VarLenFeature returns a SparseTensor
        "image_raw": tf.FixedLenFeature([], tf.string),
        "label": tf.FixedLenFeature([], tf.int64),
    })
# Decode the raw bytes back into a flat uint8 pixel vector
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)

with tf.Session() as session:
    session.run(tf.local_variables_initializer())
    session.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    # Start the background threads that fill the filename queue
    threads = tf.train.start_queue_runners(sess=session, coord=coord)
    for i in range(2):
        image, label = session.run([images, labels])
        display(label)
        display(image)
        print('-' * 40)
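1
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------
0
array([0, 0, 0, ..., 0, 0, 0], dtype=uint8)
----------------------------------------
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.CancelledError'>, Enqueue operation was cancelled
     [[Node: input_producer_3/input_producer_3_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer_3, input_producer_3/RandomShuffle)]]

The final INFO message does not indicate a read error. When the with block exits, the session closes while the queue-runner threads started by tf.train.start_queue_runners are still enqueuing filenames, so their pending enqueue is cancelled. To shut those threads down cleanly, call coord.request_stop() and then coord.join(threads) before leaving the with block.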
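Each image comes back as a flat uint8 vector. To plot a digit with the matplotlib import at the top, the 784 values have to be reshaped to 28 × 28. The following NumPy sketch shows that decode-and-reshape step on made-up bytes. `np.frombuffer` stands in for `tf.decode_raw`, and the 28 × 28 shape is an assumption based on the 784 pixel columns of this MNIST-style data. Inside the TensorFlow graph, the equivalent would be `tf.reshape(images, [28, 28])`.

```python
import numpy as np

# Hypothetical stand-in for features['image_raw'] when each pixel
# was stored as a single uint8 byte (784 values of a 28x28 digit)
original = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)
image_raw = original.tobytes()

# np.frombuffer plays the role of tf.decode_raw(image_raw, tf.uint8): a flat vector
flat = np.frombuffer(image_raw, dtype=np.uint8)
print(flat.shape)   # (784,)

# Restore the 2-D layout so it can be shown with plt.imshow(image, cmap='gray')
image = flat.reshape(28, 28)
print(image.shape)  # (28, 28)
print(np.array_equal(image, original))  # True
```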