tensorflow tf.train.batch之数据批量读取
2017-10-21 21:10
876 查看
在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇博文的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。然后发现这篇博文说slice_input_producer()这个函数有一个形参num_epochs,通过设置它的值就可以控制全部数据循环输出几次。于是我设置之后出现以下的报错:
找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value input_producer/input_producer/limit_epochs/epochs [[Node: input_producer/input_producer/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@input_producer/input_producer/limit_epochs/epochs"], limit=2, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer/input_producer/limit_epochs/epochs)]]
找了好久,都不知道为什么会错,于是只好去看看slice_input_producer()函数的源码,结果在源码中发现作者说这个num_epochs如果不是空的话,就是一个局部变量,需要先调用global_variables_initializer()函数初始化。于是我调用了之后,一切就正常了,特此记录下来,希望其他人遇到的时候能够及时找到原因。哈哈,这是笔者第一次通过阅读源码解决了问题,心情还是有点小激动。啊啊,扯远了,上最终成功的代码:
import pandas as pd import numpy as np import tensorflow as tf def generate_data(): num = 25 label = np.asarray(range(0, num)) images = np.random.random([num, 5]) print('label size :{}, image size {}'.format(label.shape, images.shape)) return images,label def get_batch_data(): label, images = generate_data() input_queue = tf.train.slice_input_producer([images, label], shuffle=False,num_epochs=2) image_batch, label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False) return image_batch,label_batch images,label = get_batch_data() sess = tf.Session() sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer())#就是这一行 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess,coord) try: while not coord.should_stop(): i,l = sess.run([images,label]) print(i) print(l) except tf.errors.OutOfRangeError: print('Done training') finally: coord.request_stop() coord.join(threads) sess.close()
相关文章推荐
- 第五课 Tensorflow TFRecord读取数据
- tensorflowxun训练自己的数据集之从tfrecords读取数据
- Different Readers for different file types(Tensorflow 的几种读取数据的方式)
- tensorflow TFRecords文件的生成和读取的方法
- TensorFlow使用next_batch()读取/tensorflow.python.framework.errors_impl.InvalidArgumentError: Expect 3 fi
- 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集
- 《TensorFlow学习笔记》完美解决 pip3 install tensorflow 没有models库,读取PTB数据
- tensorflow tf.nn.embedding_lookup(embeddings, train_inputs)解释
- tensorflow tf.train.Supervisor作用
- 清新脱俗的TensorFlow CIFAR10例程的代码重构——更简明更快的数据读取、loss accuracy实时输出
- 向MySQL数据库中批量读取数据
- MATLAB中批量读取处理数据文件
- tensorflow slim(TF-Slim)介绍
- [置顶] 【R语言 数据合并】批量读取数据文件合并为一个excel表格
- faster-rcnn tensorflow windows python 训练自己数据
- SQL 读取csv 文件批量插入数据
- 详解Tensorflow数据读取有三种方式(next_batch)
- tensorflow tf.reduce_mean
- fetch bulk collect into 批量效率的读取游标数据【转载】
- 批量读取不同类型数据并存入不同数据库表