tensorflow使用range_input_producer多线程读取数据

来源:互联网 发布:ui设计网站知乎 编辑:程序博客网 时间:2024/06/08 18:12

原文:http://blog.csdn.net/lyg5623/article/details/69387917


先放关键代码:

[python] view plain copy
  1. i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()  
  2. inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])  
原理解析:

第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;

0,1,2,0,1,2

队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。

如果num_epochs不指定,则队列内容是这样子:

0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。

下面是完整的演示代码。

数据文件test.txt内容:

[html] view plain copy
  1. 1  
  2. 2  
  3. 3  
  4. 4  
  5. 5  
  6. 6  
  7. 7  
  8. 8  
  9. 9  
  10. 10  
  11. 11  
  12. 12  
  13. 13  
  14. 14  
  15. 15  
  16. 16  
  17. 17  
  18. 18  
  19. 19  
  20. 20  
  21. 21  
  22. 22  
  23. 23  
  24. 24  
  25. 25  
  26. 26  
  27. 27  
  28. 28  
  29. 29  
  30. 30  
  31. 31  
  32. 32  
  33. 33  
  34. 34  
  35. 35  
main.py内容:

[python] view plain copy
  1. import tensorflow as tf  
  2. import codecs  
  3.   
  4. BATCH_SIZE = 6  
  5. NUM_EXPOCHES = 5  
  6.   
  7.   
  8. def input_producer():  
  9.     array = codecs.open("test.txt").readlines()  
  10.     array = map(lambda line: line.strip(), array)  
  11.     i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()  
  12.     inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])  
  13.     return inputs  
  14.   
  15.   
  16. class Inputs(object):  
  17.     def __init__(self):  
  18.         self.inputs = input_producer()  
  19.   
  20.   
  21. def main(*args, **kwargs):  
  22.     inputs = Inputs()  
  23.     init = tf.group(tf.initialize_all_variables(),  
  24.                     tf.initialize_local_variables())  
  25.     sess = tf.Session()  
  26.     coord = tf.train.Coordinator()  
  27.     threads = tf.train.start_queue_runners(sess=sess, coord=coord)  
  28.     sess.run(init)  
  29.     try:  
  30.         index = 0  
  31.         while not coord.should_stop() and index<10:  
  32.             datalines = sess.run(inputs.inputs)  
  33.             index += 1  
  34.             print("step: %d, batch data: %s" % (index, str(datalines)))  
  35.     except tf.errors.OutOfRangeError:  
  36.         print("Done traing:-------Epoch limit reached")  
  37.     except KeyboardInterrupt:  
  38.         print("keyboard interrput detected, stop training")  
  39.     finally:  
  40.         coord.request_stop()  
  41.     coord.join(threads)  
  42.     sess.close()  
  43.     del sess  
  44.       
  45. if __name__ == "__main__":  
  46.     main()  

输出:

[html] view plain copy
  1. step: 1, batch data: ['1' '2' '3' '4' '5' '6']  
  2. step: 2, batch data: ['7' '8' '9' '10' '11' '12']  
  3. step: 3, batch data: ['13' '14' '15' '16' '17' '18']  
  4. step: 4, batch data: ['19' '20' '21' '22' '23' '24']  
  5. step: 5, batch data: ['25' '26' '27' '28' '29' '30']  
  6. Done traing:-------Epoch limit reached  

如果range_input_producer去掉参数num_epochs=1,则输出:
[html] view plain copy
  1. step: 1, batch data: ['1' '2' '3' '4' '5' '6']  
  2. step: 2, batch data: ['7' '8' '9' '10' '11' '12']  
  3. step: 3, batch data: ['13' '14' '15' '16' '17' '18']  
  4. step: 4, batch data: ['19' '20' '21' '22' '23' '24']  
  5. step: 5, batch data: ['25' '26' '27' '28' '29' '30']  
  6. step: 6, batch data: ['1' '2' '3' '4' '5' '6']  
  7. step: 7, batch data: ['7' '8' '9' '10' '11' '12']  
  8. step: 8, batch data: ['13' '14' '15' '16' '17' '18']  
  9. step: 9, batch data: ['19' '20' '21' '22' '23' '24']  
  10. step: 10, batch data: ['25' '26' '27' '28' '29' '30']  

有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:

[html] view plain copy
  1. InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6  
  2.      [[Node: Slice = Slice[Index=DT_INT32T=DT_STRING_device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]  

错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。