用tensorflow DataSet高效加载变长文本输入

来源:互联网 发布:阿里云备案在哪里 编辑:程序博客网 时间:2024/06/18 07:16

DataSet是tensorflow 1.3版本推出的一个high-level的api,在1.3版本还只是处于测试阶段,1.4版本已经正式推出。在网上搜了一遍,发现关于使用DataSet加载文本的资料比较少,官方举的例子只是csv格式的,要求csv文件中所有样本必须具有相同的维度,也就是padding必须在写入csv文件之前做掉,这会增加文件的大小。经过一番折腾试验,这里给出一个DataSet+TFRecords加载变长样本的范例。

首先先把变长的数据写入到TFRecords文件:

def writedata():    xlist = [[1,2,3],[4,5,6,8]]    ylist = [1,2]#这里的数据只是举个例子来说明样本的文本长度不一样,第一个样本3个词标签1,第二个样本4个词标签2    writer = tf.python_io.TFRecordWriter("train.tfrecords")    for i in range(2):        x = xlist[i]        y = ylist[i]        example = tf.train.Example(features=tf.train.Features(feature={            "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),            'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x))        }))        writer.write(example.SerializeToString())    writer.close()


然后用DataSet加载:

feature_names = ['x']def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):    def parse(example_proto):        features = {"x": tf.VarLenFeature(tf.int64),              "y": tf.FixedLenFeature([1], tf.int64)}        parsed_features = tf.parse_single_example(example_proto, features)        x = tf.sparse_tensor_to_dense(parsed_features["x"])        x = tf.cast(x, tf.int32)        x = dict(zip(feature_names, [x]))        y = tf.cast(parsed_features["y"], tf.int32)        return x, y    dataset = (tf.contrib.data.TFRecordDataset(file_path)               .map(parse))    if perform_shuffle:        dataset = dataset.shuffle(buffer_size=256)    dataset = dataset.repeat(repeat_count)    dataset = dataset.padded_batch(2, padded_shapes=({'x':[6]},[1]))  #batch size为2,并且x按maxlen=6来做padding    iterator = dataset.make_one_shot_iterator()    batch_features, batch_labels = iterator.get_next()    return batch_features, batch_labelsnext_batch = my_input_fn('train.tfrecords', True)init = tf.initialize_all_variables()with tf.Session() as sess:    sess.run(init)    for i in range(1):        xs, y =sess.run(next_batch)        print(xs['x'])        print(y)

注意变长的数据TFRecords解析要用VarLenFeature,然后用sparse_tensor_to_dense转换。