tensorflow tf.train.batch之数据批量读取
来源:互联网 发布:怎么通过mac地址查ip 编辑:程序博客网 时间:2024/05/18 17:03
在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇博文的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次。于是我设置之后出现以下的报错:
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@input_producer/input_producer/limit_epochs/epochs"], limit=2, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]
找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:
import pandas as pdimport numpy as npimport tensorflow as tfdef generate_data(): num = 25 label = np.asarray(range(0, num)) images = np.random.random([num, 5]) print('label size :{}, image size {}'.format(label.shape, images.shape)) return images,labeldef get_batch_data(): label, images = generate_data() input_queue = tf.train.slice_input_producer([images, label], shuffle=False,num_epochs=2) image_batch, label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return image_batch,label_batchimages,label = get_batch_data()sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())#就是这一行coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try: while not coord.should_stop(): i,l = sess.run([images,label]) print(i) print(l)except tf.errors.OutOfRangeError: print('Done training')finally: coord.request_stop()coord.join(threads)sess.close()
阅读全文
0 0
- tensorflow tf.train.batch之数据批量读取
- 关于Tensorflow中的tf.train.batch函数
- tensorflow学习——tf.floor与tf.train.batch
- tf.train.batch()
- tensorflow学习——tfreader格式,队列读取数据tf.train.shuffle_batch()
- Tensorflow:tf.train.SyncReplicasOptimizer
- tensorflow tf.train.SummaryWriter()
- Tensorflow的模型保存和读取tf.train.Saver
- tf.train.batch和tf.train.shuffle_batch的用法
- tf.train.batch和tf.train.shuffle_batch的用法
- tf.train.batch和tf.train.shuffle_batch的理解
- tf.train.batch()和tf.train.shuffle_batch()函数
- tensorflow 1.0之tf.train.Saver 文档翻译
- 【Tensorflow】tf.train.AdamOptimizer函数
- tensorflow关于tf.train.Saver()
- TensorFlow中tf.train.exponential_decay的用法
- tensorflow学习——批量读取数据
- tensorflow数据读取之tfrecords
- 机器学习数据标准和归一化
- java 数值类型和字符串的相互转换
- sql-server基础知识四(视图和索引)
- Shell编程(Shell Script)
- 1042. 字符统计(20)
- tensorflow tf.train.batch之数据批量读取
- DOM扩展
- linux-Centos7安装python3并与python2共存
- 红楼解梦五--饥饿疗法
- hdu 3938 Portal ( 离线并查集)
- PAT乙级1007. 素数对猜想 (20)
- 区间求差 hihocoder 1305
- 在c++中,有哪4个与类型转换相关的关键字,这些关键字各有什么特点,应该在什么场合下使用?
- 理解数据库 1NF 2NF 3NF BCNF