tensorFlow数据输入

来源:互联网 发布:linux 文件上传权限 编辑:程序博客网 时间:2024/06/06 01:47

一、TensorFlow 数据的输入

  1. preloaded data : 预加载数据
  2. Feeding : pyhthon 产生数据,再把数据喂给后端
  3. Reading from file : 从文件中直接读取

Preload:

import tensorflow as tf#define Graphx1 = tf.constant([2,3,4])x2 = tf.constant([4,0,1])y  = tf.add(x1,x2)#define sessionwith tf.Session() as sess:    print sess.run(y)

Feeding:

import tensorflow as tf#define Graphx1 = tf.placeholder(tf.int16)x2 = tf.placeholder(tf.int16)y  = tf.add(x1,x2)# python generate datali1 = [2,3,4]li2 = [4,0,1]#sessionwith tf.Session() as sess:    sess.run(y,feed_dict = {x1:li1,x2:li2})

Read from file :

一个典型的文件读取管线会包含下列的步骤:

  1. 文件名列表 [“file0”,”file1” …]
  2. 可配置的文件名乱序
  3. 可配置的最大迭代次数
  4. 文件名对列
  5. 针对输入文件格式的阅读器
  6. 记录的解析器
  7. 可配置的预处理器
  8. 样本队列
import tensorflow as tffilename = os.path.join(os.getcwd(), file_name)#产生文件队列, 可配置文件名和乱序filename_queue = tf.train.string_input_producer([filenames],shuffle = True)reader = tf.TextLineReader(skip_header_lines = 1)#每一次 read 都会从文件中读取一行内容。key , value = reader.read(filename_queue)record_defaults = [[0],[0],[0],[0]]#会解析这一行内容并将其转换为张量列表decoded = tf.decode_csv(value , record_defaults = record_defaults)with Session as sess:    #coordinate这是负责在收到任何关闭信号的时候让所有的线程都知道    coord = tf.train.Coordinator()    #在调用run或者eval去执行read之前,必须先调用一下将文件名填充到队列中,否则read将会堵塞    threads = tf.train.start_queue_runners(coord = coord)

批处理

在数据输入管线的末端,我们需要有另一个队列来执行输入样本的训练,评价和推理,因此我们使用

一下语句对队列中的样本进行乱序的处理。

 #min_after_dequeue defines how big a buffer we will randomly sample  #   from -- bigger means better shuffling but slower start up and more  #   memory used.  # capacity must be larger than min_after_dequeue and the amount larger  #   determines the maximum we will prefetch.  Recommendation:  #   min_after_dequeue + (num_threads + a small safety margin) * batch_sizemin_after_dequeue = 10000capacity = min_after_dequeue + 3*batch_sizetf.train.shuffle_batch(decoded , batch_size = batch_size , capacity = capcity ,min_after_dequeue = batch_size)
原创粉丝点击