TensorFlow数据读取模块调用过程(cifar10)
2017-06-15 19:52
429 查看
最近在看TensorFlow数据读取模块,有了一点思路,先把读取部分的调用过程写下来,以cifar10为例。
入口 cifar10_train.py
distorted_inputs() 函数执行数据读取
1 --------------> cifar10.py
cifar10.distorted_inputs() 是将 data_dir 定义后再调用 cifar10_input.distorted_inputs()
cifar10_input.distorted_inputs 是执行数据读取的主要函数
2 --------------> cifar10_input.py
cifar10_input.distorted_inputs() 主要有几部分组成
1. 生成文件名队列。使用 tf.train.string_input_producer()函数生成文件名队列,通过调用分支 3 进行具体的调用过程分析。
2. 文件读取与解析。通过在函数 read_cifar10() 中定义了对应文件类型的阅读器及解析器,并通过对应的 read 及 decode 方法得到样本数据,通过调用分支 4 进行具体的调用过程分析。
3. 样本处理(包括裁剪、翻转等)
4. 样本批处理。通过在函数 _generate_image_and_label_batch() 中设置线程数,并调用不同的批处理函数,进行数据的批处理,通过调用分支 5 进行具体的调用过程分析。
===================================================================================================
3 --------------> tensorflow/python/training/input.py
string_input_producer() 函数是将字符串(比如文件名)入队到一个队列中,并且添加该队的 QueueRunner 到当前图的 QUEUE_RUNNER collection 中。
其中有几个主要的参数:
num_epochs: 限制 string_tensor 中字符串入队的次数,如果没有定义的话,就是无限次将 string_tensor 中的字符串入队到队列中。
shuffle: 表示是否乱序,如果是 True, 表示字符串入队到队列中是以乱序的形式。
除了 string_input_producer() 之外还有两个函数,实现不同对象的入队操作
# 将 0 - (limit-1) 的整数入队到队列中
range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
# 将 tensor_list 中各 Tensor 的切片入队到队列中
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
string_input_producer() 会调用 input_producer() 进行具体的操作。
3.1 --------------> tensorflow/python/training/input.py
input_producer() 函数主要做了以下操作:
1. 根据参数 shuffle 和 num_epochs 确定 input_tensor,分别通过分支 3.1.1 和 3.1.2 进行具体的分析。
2. 创建队列及入队操作,分别通过分支 3.1.3 和 3.1.4 进行具体的分析。
3. 创建 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中,通过分支 3.1.5 进行具体的分析。
注意:
queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op)
将队列 q 的操作列表 [enq] 添加到一个 QueueRunner,
这里的操作列表 [enq] 会影响后续训练过程中创建线程的个数。(QueueRunner.create_threads() 函数)
3.1.1 --------------> tensorflow/python/ops/random_ops.py
gen_random_ops.py 文件是由自动生成的,在文档的开头有如下标注:
"""Python wrappers around Brain.
This file is MACHINE GENERATED! Do not edit.
"""
通过 _op_def_lib 调用注册的 OP 操作,具体 OP 的定义及 OP Kernel 的定义均在 C++ 后端。
3.1.1.1 --------------> tensorflow/python/ops/gen_random_ops.py
3.1.1.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
具体 OP 及 OP Kernel 的定义可见上述对应的 C++ 文件
OP(比如入队,出队,读取,解析等操作)均为类似的过程:
通过 gen_xxx_ops.py 的 _op_def_lib 调用 OP 及 OP Kernel, 具体的定义在 C++ 后端。
同类型的不再赘述
tensorflow/core/ops/random_ops.cc
REGISTER_OP("RandomShuffle")
tensorflow/core/kernels/random_shuffle_op.cc
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("RandomShuffle").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RandomShuffleOp<T>);
3.1.2 --------------> tensorflow/python/training/input.py
3.1.3 --------------> tensorflow/python/ops/data_flow_ops.py
3.1.3.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
3.1.3.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("FIFOQueueV2")
tensorflow/core/kernels/fifo_queue_op.cc
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
3.1.4 --------------> tensorflow/python/ops/data_flow_ops.py
3.1.4.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
3.1.4.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueEnqueueManyV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU), EnqueueManyOp);
3.1.5 --------------> tensorflow/python/training/queue_runner_impl.py
queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
该句有两步调用,首先创建了一个 QueueRunner, 然后将其添加到图的收集中。
3.1.5.1 --------------> tensorflow/python/training/queue_runner_impl.py
这里注意几个参数:
enqueue_ops: List of enqueue ops to run in threads later. 这是一个入队操作的列表,每个操作会在一个线程中执行。
close_op: Op to close the queue. Pending enqueue ops are preserved. 关闭队列,但保留正在等待的入队操作。
cancel_op: Op to close the queue and cancel pending enqueue ops. 关闭队列,同时取消正在等待的入队操作。
3.1.5.2 --------------> tensorflow/python/training/queue_runner_impl.py
4 --------------> cifar10_input.py
在 read_cifar10 函数了主要是进行了文件的读取,以及解码。
将文件名队列提供给阅读器的 read 方法。阅读器的 read 方法会输出一个 key来表征输入的文件和其中的纪录,同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。(截取自官方文件)
数据读方法和解码:
4.1 --------------> tensorflow/python/ops/io_ops.py
4.1.1 --------------> tensorflow/python/ops/gen_io_ops.py
4.1.1.1 -------------->
tensorflow/core/ops/io_ops.cc
REGISTER_OP("FixedLengthRecordReaderV2")
tensorflow/core/kernels/fixed_length_record_reader_op.cc
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU), FixedLengthRecordReaderOp);
4.2 --------------> tensorflow/python/ops/io_ops.py
4.2.1 --------------> tensorflow/python/ops/gen_io_ops.py
4.2.1.1 -------------->
tensorflow/core/ops/io_ops.cc
REGISTER_OP("ReaderReadV2")
tensorflow/core/kernels/reader_ops.cc
REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp);
数据读取的执行过程涉及到session运行过程,是一个异步的操作。
class ReaderVerbAsyncOpKernel : public AsyncOpKernel{
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
ReaderInterface* reader;
OP_REQUIRES_OK_ASYNC(
context, GetResourceFromContext(context, "reader_handle", &reader),
done);
thread_pool_->Schedule([this, context, reader, done]() {
ComputeWithReader(context, reader);
reader->Unref();
done();
});
}
};
class ReaderReadOp : public ReaderVerbAsyncOpKernel{
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
reader->Read(queue, &key_scalar(), &value_scalar(), context);
}
};
reader->Read()是调用 tensorflow/core/framework/reader_base.cc 中的
void ReaderBase::Read(QueueInterface* queue, string* key, string* value, OpKernelContext* context)
进行数据读取的操作,这部分之后再写。
## 4.3 --------------> tensorflow/python/ops/gen_parsing_ops.py
### 4.3.1 -------------->
tensorflow/core/ops/parsing_ops.cc
REGISTER_OP("DecodeRaw")
tensorflow/core/kernels/decode_raw_op.cc
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("DecodeRaw").Device(DEVICE_CPU).TypeConstraint<type>("out_type"), \
DecodeRawOp<type>)
这里比较特殊的是,decode_raw()是直接定义在 gen_xxx_ops.py 文件中的,不是像其他的 OP 是通过 xxx_ops.py 调用的。
5 --------------> cifar10_input.py
_generate_image_and_label_batch 函数通过 shuffle_batch 函数或 batch 函数进行样本的批处理,可以通过 num_preprocess_threads 设置所需要的线程数。
下边分别看一下 shuffle_batch 函数和 batch 函数的调用过程,其实所做的操作基本是一样的,
区别在于 shuffle_batch 函数要进行乱序处理,所创建的队列是 RandomShuffleQueue, 而 batch 函数创建的是 FIFOQueue。
5.1 --------------> tensorflow/python/training/input.py
5.1.1 --------------> tensorflow/python/training/input.py
这个函数做了以下操作:
1. 创建一个样本队列,RandomShuffleQueue
2. 调用 _enqueue 函数进行入队的相关操作。包括创建一个 QueueRunner, 并将其加入图的集合中。
3. 定义批量出队操作
5.1.1.1 --------------> tensorflow/python/ops/data_flow_ops.py
5.1.1.1.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
5.1.1.1.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("RandomShuffleQueueV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU), RandomShuffleQueueOp);
5.1.1.2 --------------> tensorflow/python/training/input.py
_enqueue 函数定义了样本队列入队的相关操作:
1. 确定入队的方法:是 enqueue 还是 enqueue_many
2. 定义入队 OP, 定义了 threads 个入队操作的列表 enqueue_ops
3. 创建一个 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中
到这里为止就一共有了两个 QueueRunner, 一个是负责入队文件名队列的 QueueRunner, 其入队操作列表只有一个入队操作的[enq], 后续的运行过程只会创建一个线程进行入队操作。另一个是负责入队样本队列的 QueueRunner, 其入队操作列表是有 threads 个入队操作的 enqueue_ops, 后续的运行过程会创建 threads 个线程进行入队操作。
5.1.1.3 --------------> tensorflow/python/ops/data_flow_ops.py
5.1.1.3.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
5.1.1.3.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueDequeueUpToV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU), DequeueUpToOp);
5.1.1.4 --------------> tensorflow/python/ops/data_flow_ops.py
5.1.1.4.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
5.1.1.4.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueDequeueUpToV2")
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU), DequeueManyOp);
5.2 --------------> tensorflow/python/training/input.py
5.2.1 --------------> tensorflow/python/training/input.py
_batch 和 _shuffle_batch 的区别就在于创建队列时不同,这里的 queue 是在 PaddingFIFOQueue 和 FIFOQueue 中选择一个,而 _shuffle_batch 是 RandomShuffleQueue 类型。
5.2.1.1 --------------> tensorflow/python/training/input.py
==================================================================================================================
==================================================================================================================
以上所有均是创建数据流图过程中涉及的操作,并不涉及数据流图的运行过程。queue_runner.add_queue_runner() 函数添加 QueueRunner 到你的数据流图中。
在运行任何训练步骤(在调用 run 或者 eval 去执行 read)之前,需要调用 tf.train.start_queue_runners 函数去开始它的线程运行入队操作,否则数据流图将一直挂起。 tf.train.start_queue_runner 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个 tf.train.Coordinator ,这样可以在发生错误的情况下正确地关闭这些线程。如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。
tensorflow/python/training/queue_runner_impl.py
for qr in ops.get_collection(collection):
threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon, start=start))
这里获取了数据流图中的所有 queue_runners,对于每个 qr 都调用 create_threads() 函数执行创建线程的操作。
queue_runner_impl.py 是使用了 Python 创建线程的方法, python 提供了 threading 模块来实现多线程,导入 threading 模块,这是使用多线程的前提
创建了 ret_threads 系列线程,对该 QueueRunner 中的 _enqueue_ops 的所有 op 分别使用 threading.Thread() 方法创建线程。
threading.Thread() 方法中调用 QueueRunner 的 _run 方法 target=self._run, args 方法对 _run 进行传参,不同的线程所传输的 op 是不同的。入队操作线程的个数等于 _enqueue_ops 的 op 个数。
实际在使用中都会结合 coord 一起使用,来帮助多个线程协同工作,多个线程同步终止。
这一句向 ret_threads 列表中又添加了一个线程,这个线程执行 _cancel_op, 这个操作会关闭队列,同时取消正在等待的入队操作。
最终每个 QueueRunner 中一共有 len(enqueue_ops) + 1 个线程。
在 cifar10 的示例中,一共创建了两个 QueueRunner, 其中一个是在文件名入队函数 input_producer 函数中创建的:
因此在这个 QueueRunner 中执行文件名入队操作线程的个数为 len([enq]), 这里 len([enq]) = 1, 因此文件名入队操作线程为 1 个,加上执行 cancel 操作的 1 个线程,一共有 2 个线程。
另外一个 QueueRunner 是现在 batch/shuffle_batch 函数的 _enqueue 函数中创建的:
enqueue_ops 的定义为:
会创建 threads 个入队操作,添加到 enqueue_ops 列表中
因此在第二个 QueueRunner 中执行样本入队操作线程的个数为 len(enqueue_ops), 这里 len(enqueue_ops) = threads, 因此样本入队操作线程为 threads 个,加上执行 cancel 操作的 1 个线程,一共有 threads + 1 个线程。
入口 cifar10_train.py
distorted_inputs() 函数执行数据读取
def train(): with tf.Graph().as_default(): ...... # Get images and labels for CIFAR-10. # 从二进制文件中读取数据 images, labels images, labels = cifar10.distorted_inputs() # 1 --------------> ......
1 --------------> cifar10.py
cifar10.distorted_inputs() 是将 data_dir 定义后再调用 cifar10_input.distorted_inputs()
cifar10_input.distorted_inputs 是执行数据读取的主要函数
def distorted_inputs(): if not FLAGS.data_dir: raise ValueError('Please supply a data_dir') data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin') images, labels = cifar10_input.distorted_inputs(data_dir=data_dir, batch_size=FLAGS.batch_size) # 2 --------------> if FLAGS.use_fp16: images = tf.cast(images, tf.float16) labels = tf.cast(labels, tf.float16) return images, labels
2 --------------> cifar10_input.py
cifar10_input.distorted_inputs() 主要有几部分组成
1. 生成文件名队列。使用 tf.train.string_input_producer()函数生成文件名队列,通过调用分支 3 进行具体的调用过程分析。
2. 文件读取与解析。通过在函数 read_cifar10() 中定义了对应文件类型的阅读器及解析器,并通过对应的 read 及 decode 方法得到样本数据,通过调用分支 4 进行具体的调用过程分析。
3. 样本处理(包括裁剪、翻转等)
4. 样本批处理。通过在函数 _generate_image_and_label_batch() 中设置线程数,并调用不同的批处理函数,进行数据的批处理,通过调用分支 5 进行具体的调用过程分析。
===================================================================================================
def distorted_inputs(data_dir, batch_size): filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)] for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) # Create a queue that produces the filenames to read. # 生成要读取的文件名队列 filename_queue = tf.train.string_input_producer(filenames) # 3 --------------> # Read examples from files in the filename queue. read_input = read_cifar10(filename_queue) # 4 --------------> reshaped_image = tf.cast(read_input.uint8image, tf.float32) height = IMAGE_SIZE width = IMAGE_SIZE # Image processing for training the network. Note the many random # distortions applied to the image. # 为训练网络进行图像处理。注意应用于图像的许多随机失真。 # Randomly crop a [height, width] section of the image. # 随机裁剪图像为 [height,width] 像素大小的图片 distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) # Randomly flip the image horizontally. # 随意地水平翻转图像。 distorted_image = tf.image.random_flip_left_right(distorted_image) # Because these operations are not commutative, consider randomizing # the order their operation. # 因为这些操作是不可交换的,所以请考虑将它们的操作随机化。 # 随机的改变图片的亮度 distorted_image = tf.image.random_brightness(distorted_image, max_delta=63) # 随机的改变图片的对比度 distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8) # Subtract off the mean and divide by the variance of the pixels. # 图像的白化:减去平均值并除以像素的方差,均值与方差的均衡,降低图像明暗、光照差异引起的影响 float_image = tf.image.per_image_standardization(distorted_image) # Set the shapes of tensors. float_image.set_shape([height, width, 3]) read_input.label.set_shape([1]) # Ensure that the random shuffling has good mixing properties. # 确保随机 shuffling 具有良好的混合性能。 min_fraction_of_examples_in_queue = 0.4 min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * min_fraction_of_examples_in_queue) print ('Filling queue with %d CIFAR images before starting to train. ' 'This will take a few minutes.' % min_queue_examples) # Generate a batch of images and labels by building up a queue of examples. # 构造 batch_size 样本集(图像+标签) return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples, batch_size, shuffle=True) # 5 -------------->
3 --------------> tensorflow/python/training/input.py
string_input_producer() 函数是将字符串(比如文件名)入队到一个队列中,并且添加该队的 QueueRunner 到当前图的 QUEUE_RUNNER collection 中。
其中有几个主要的参数:
num_epochs: 限制 string_tensor 中字符串入队的次数,如果没有定义的话,就是无限次将 string_tensor 中的字符串入队到队列中。
shuffle: 表示是否乱序,如果是 True, 表示字符串入队到队列中是以乱序的形式。
除了 string_input_producer() 之外还有两个函数,实现不同对象的入队操作
# 将 0 - (limit-1) 的整数入队到队列中
range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
# 将 tensor_list 中各 Tensor 的切片入队到队列中
slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
capacity=32, shared_name=None, name=None)
string_input_producer() 会调用 input_producer() 进行具体的操作。
def string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None): with ops.name_scope(name, "input_producer", [string_tensor]) as name: string_tensor = ops.convert_to_tensor(string_tensor, dtype=dtypes.string) with ops.control_dependencies([control_flow_ops.Assert( math_ops.greater(array_ops.size(string_tensor), 0), [not_null_err])]): string_tensor = array_ops.identity(string_tensor) return input_producer(input_tensor=string_tensor, element_shape=[], num_epochs=num_epochs, shuffle=shuffle, seed=seed, capacity=capacity, shared_name=shared_name, name=name, summary_name="fraction_of_%d_full" % capacity, cancel_op=cancel_op) # 3.1 -------------->
3.1 --------------> tensorflow/python/training/input.py
input_producer() 函数主要做了以下操作:
1. 根据参数 shuffle 和 num_epochs 确定 input_tensor,分别通过分支 3.1.1 和 3.1.2 进行具体的分析。
2. 创建队列及入队操作,分别通过分支 3.1.3 和 3.1.4 进行具体的分析。
3. 创建 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中,通过分支 3.1.5 进行具体的分析。
注意:
queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op)
将队列 q 的操作列表 [enq] 添加到一个 QueueRunner,
这里的操作列表 [enq] 会影响后续训练过程中创建线程的个数。(QueueRunner.create_threads() 函数)
def input_producer(input_tensor, element_shape=None, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, summary_name=None, name=None, cancel_op=None): with ops.name_scope(name, "input_producer", [input_tensor]): input_tensor = ops.convert_to_tensor(input_tensor, name="input_tensor") element_shape = input_tensor.get_shape()[1:].merge_with(element_shape) # 是否乱序乱序 if shuffle: input_tensor = random_ops.random_shuffle(input_tensor, seed=seed) # 3.1.1 --------------> # 限制迭代次数 input_tensor = limit_epochs(input_tensor, num_epochs) # 3.1.2 --------------> # 创建队列及入队操作 q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[input_tensor.dtype.base_dtype], shapes=[element_shape], shared_name=shared_name, name=name) # 3.1.3 --------------> enq = q.enqueue_many([input_tensor]) # 3.1.4 --------------> # 创建 QueueRunner,并将其加入图的集合中 queue_runner.add_queue_runner(queue_runner.QueueRunner (q, [enq], cancel_op=cancel_op)) # 3.1.5 --------------> if summary_name is not None: summary.scalar(summary_name, math_ops.cast(q.size(), dtypes.float32) * (1. / capacity)) return q
3.1.1 --------------> tensorflow/python/ops/random_ops.py
gen_random_ops.py 文件是由自动生成的,在文档的开头有如下标注:
"""Python wrappers around Brain.
This file is MACHINE GENERATED! Do not edit.
"""
通过 _op_def_lib 调用注册的 OP 操作,具体 OP 的定义及 OP Kernel 的定义均在 C++ 后端。
def random_shuffle(value, seed=None, name=None): seed1, seed2 = random_seed.get_seed(seed) return gen_random_ops._random_shuffle( value, seed=seed1, seed2=seed2, name=name) # 3.1.1.1 -------------->
3.1.1.1 --------------> tensorflow/python/ops/gen_random_ops.py
def _random_shuffle(value, seed=None, seed2=None, name=None): result = _op_def_lib.apply_op("RandomShuffle", value=value, seed=seed, seed2=seed2, name=name) ##### 3.1.1.1.1 --------------> return result
3.1.1.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
具体 OP 及 OP Kernel 的定义可见上述对应的 C++ 文件
OP(比如入队,出队,读取,解析等操作)均为类似的过程:
通过 gen_xxx_ops.py 的 _op_def_lib 调用 OP 及 OP Kernel, 具体的定义在 C++ 后端。
同类型的不再赘述
tensorflow/core/ops/random_ops.cc
REGISTER_OP("RandomShuffle")
tensorflow/core/kernels/random_shuffle_op.cc
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER( \
Name("RandomShuffle").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RandomShuffleOp<T>);
3.1.2 --------------> tensorflow/python/training/input.py
# Returns tensor `num_epochs` times and then raises an `OutOfRange` error. # 限制tensor的迭代次数 def limit_epochs(tensor, num_epochs=None, name=None): if num_epochs is None: return tensor if num_epochs <= 0: raise ValueError("num_epochs must be > 0 not %d." % num_epochs) with ops.name_scope(name, "limit_epochs", [tensor]) as name: zero64 = constant_op.constant(0, dtype=dtypes.int64) epochs = vs.variable( zero64, name="epochs", trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES]) counter = epochs.count_up_to(num_epochs) with ops.control_dependencies([counter]): return array_ops.identity(tensor, name=name)
3.1.3 --------------> tensorflow/python/ops/data_flow_ops.py
class FIFOQueue(QueueBase): def __init__(self, capacity, dtypes, shapes=None, names=None, shared_name=None, name="fifo_queue"): dtypes = _as_type_list(dtypes) shapes = _as_shape_list(shapes, dtypes) names = _as_name_list(names, dtypes) queue_ref = gen_data_flow_ops._fifo_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, shared_name=shared_name, name=name) ### 3.1.3.1 --------------> super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
3.1.3.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
def _fifo_queue_v2(component_types, shapes=None, capacity=None, container=None, shared_name=None, name=None): result = _op_def_lib.apply_op("FIFOQueueV2", component_types=component_types, shapes=shapes, capacity=capacity, container=container, shared_name=shared_name, name=name) ##### 3.1.3.1.1 --------------> return result
3.1.3.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("FIFOQueueV2")
tensorflow/core/kernels/fifo_queue_op.cc
REGISTER_KERNEL_BUILDER(Name("FIFOQueueV2").Device(DEVICE_CPU), FIFOQueueOp);
3.1.4 --------------> tensorflow/python/ops/data_flow_ops.py
class QueueBase(object): def enqueue_many(self, vals, name=None): with ops.name_scope(name, "%s_EnqueueMany" % self._name, self._scope_vals(vals)) as scope: vals = self._check_enqueue_dtypes(vals) batch_dim = vals[0].get_shape().with_rank_at_least(1)[0] for val, shape in zip(vals, self._shapes): batch_dim = batch_dim.merge_with( val.get_shape().with_rank_at_least(1)[0]) val.get_shape()[1:].assert_is_compatible_with(shape) return gen_data_flow_ops._queue_enqueue_many_v2( self._queue_ref, vals, name=scope ### 3.1.4.1 -------------->
3.1.4.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
def _queue_enqueue_many_v2(handle, components, timeout_ms=None, name=None): result = _op_def_lib.apply_op("QueueEnqueueManyV2", handle=handle, components=components, timeout_ms=timeout_ms, name=name) ##### 3.1.4.1.1 --------------> return result
3.1.4.1.1 --------------> 进入 c++ 后端进行具体 OP 的定义
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueEnqueueManyV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("QueueEnqueueManyV2").Device(DEVICE_CPU), EnqueueManyOp);
3.1.5 --------------> tensorflow/python/training/queue_runner_impl.py
queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
该句有两步调用,首先创建了一个 QueueRunner, 然后将其添加到图的收集中。
3.1.5.1 --------------> tensorflow/python/training/queue_runner_impl.py
这里注意几个参数:
enqueue_ops: List of enqueue ops to run in threads later. 这是一个入队操作的列表,每个操作会在一个线程中执行。
close_op: Op to close the queue. Pending enqueue ops are preserved. 关闭队列,但保留正在等待的入队操作。
cancel_op: Op to close the queue and cancel pending enqueue ops. 关闭队列,同时取消正在等待的入队操作。
class QueueRunner(object): def __init__(self, queue=None, enqueue_ops=None, close_op=None, cancel_op=None, queue_closed_exception_types=None, queue_runner_def=None, import_scope=None): if queue_runner_def: if queue or enqueue_ops: raise ValueError("queue_runner_def and queue are mutually exclusive.") self._init_from_proto(queue_runner_def, import_scope=import_scope) else: self._init_from_args( queue=queue, enqueue_ops=enqueue_ops, close_op=close_op, cancel_op=cancel_op, queue_closed_exception_types=queue_closed_exception_types) self._lock = threading.Lock() self._runs_per_session = weakref.WeakKeyDictionary() self._exceptions_raised = []
3.1.5.2 --------------> tensorflow/python/training/queue_runner_impl.py
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS): ops.add_to_collection(collection, qr)
4 --------------> cifar10_input.py
在 read_cifar10 函数了主要是进行了文件的读取,以及解码。
将文件名队列提供给阅读器的 read 方法。阅读器的 read 方法会输出一个 key来表征输入的文件和其中的纪录,同时得到一个字符串标量, 这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。(截取自官方文件)
def read_cifar10(filename_queue): class CIFAR10Record(object): pass result = CIFAR10Record() label_bytes = 1 # 2 for CIFAR-100 result.height = 32 result.width = 32 result.depth = 3 image_bytes = result.height * result.width * result.depth # Every record consists of a label followed by the image, with a # fixed number of bytes for each. # 每个记录都包含标签信息和图片信息,每个记录都有固定的字节数(3073 = 1 + 3072)。 record_bytes = label_bytes + image_bytes # TensorFlow 使用 tf.FixedLengthRecordReader 读取固定长度格式的数据,与 tf.decode_raw 配合使用 reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) ## 4.1 --------------> result.key, value = reader.read(filename_queue) ## 4.2 --------------> # Convert from a string to a vector of uint8 that is record_bytes long. # 从一个字符串转换为一个 uint8 的向量,即 record_bytes 长。 record_bytes = tf.decode_raw(value, tf.uint8) # The first bytes represent the label, which we convert from uint8->int32. # 采用 tf.strided_slice 方法在 record_bytes 中提取第一个 bytes 作为标签,从 uint8 转换为 int32。 result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32) # The remaining bytes after the label represent the image, which we reshape # from [depth * height * width] to [depth, height, width]. # 记录中标签后的剩余字节代表图像,从 label 起,在 record_bytes 中提取 self.image_bytes = 3072 长度为图像, # 从 [depth * height * width] 转化为 [depth,height,width],图片转化成 3*32*32。 depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]), [result.depth, result.height, result.width]) # Convert from [depth, height, width] to [height, width, depth]. # 从 [depth, height, width] 转化为 [height, width, depth],图片转化成 32*32*3。 result.uint8image = tf.transpose(depth_major, [1, 2, 0]) return result
数据读方法和解码:
4.1 --------------> tensorflow/python/ops/io_ops.py
class FixedLengthRecordReader(ReaderBase): # A Reader that outputs fixed-length records from a file. def __init__(self, record_bytes, header_bytes=None, footer_bytes=None, hop_bytes=None, name=None): rr = gen_io_ops._fixed_length_record_reader_v2( record_bytes=record_bytes, header_bytes=header_bytes, footer_bytes=footer_bytes, hop_bytes=hop_bytes, name=name) # 4.1.1 --------------> super(FixedLengthRecordReader, self).__init__(rr)
4.1.1 --------------> tensorflow/python/ops/gen_io_ops.py
def _fixed_length_record_reader_v2(record_bytes, header_bytes=None, footer_bytes=None, hop_bytes=None, container=None, shared_name=None, name=None): result = _op_def_lib.apply_op("FixedLengthRecordReaderV2", record_bytes=record_bytes, header_bytes=header_bytes, footer_bytes=footer_bytes, hop_bytes=hop_bytes, container=container, shared_name=shared_name, name=name) #### 4.1.1.1 --------------> return result
4.1.1.1 -------------->
tensorflow/core/ops/io_ops.cc
REGISTER_OP("FixedLengthRecordReaderV2")
tensorflow/core/kernels/fixed_length_record_reader_op.cc
REGISTER_KERNEL_BUILDER(Name("FixedLengthRecordReaderV2").Device(DEVICE_CPU), FixedLengthRecordReaderOp);
4.2 --------------> tensorflow/python/ops/io_ops.py
class ReaderBase(object): def read(self, queue, name=None): if isinstance(queue, ops.Tensor): queue_ref = queue else: queue_ref = queue.queue_ref if self._reader_ref.dtype == dtypes.resource: return gen_io_ops._reader_read_v2(self._reader_ref, queue_ref, name=name) ## 4.2.1 --------------> else: old_queue_op = gen_data_flow_ops._fake_queue(queue_ref) return gen_io_ops._reader_read(self._reader_ref, old_queue_op, name=name) ## 4.2.2 (类似上)-------------->
4.2.1 --------------> tensorflow/python/ops/gen_io_ops.py
def _reader_read_v2(reader_handle, queue_handle, name=None): result = _op_def_lib.apply_op("ReaderReadV2", reader_handle=reader_handle, queue_handle=queue_handle, name=name) #### 4.2.1.1 --------------> return _ReaderReadV2Output._make(result)
4.2.1.1 -------------->
tensorflow/core/ops/io_ops.cc
REGISTER_OP("ReaderReadV2")
tensorflow/core/kernels/reader_ops.cc
REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp);
数据读取的执行过程涉及到session运行过程,是一个异步的操作。
class ReaderVerbAsyncOpKernel : public AsyncOpKernel{
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
ReaderInterface* reader;
OP_REQUIRES_OK_ASYNC(
context, GetResourceFromContext(context, "reader_handle", &reader),
done);
thread_pool_->Schedule([this, context, reader, done]() {
ComputeWithReader(context, reader);
reader->Unref();
done();
});
}
};
class ReaderReadOp : public ReaderVerbAsyncOpKernel{
void ComputeWithReader(OpKernelContext* context,
ReaderInterface* reader) override {
reader->Read(queue, &key_scalar(), &value_scalar(), context);
}
};
reader->Read()是调用 tensorflow/core/framework/reader_base.cc 中的
void ReaderBase::Read(QueueInterface* queue, string* key, string* value, OpKernelContext* context)
进行数据读取的操作,这部分之后再写。
## 4.3 --------------> tensorflow/python/ops/gen_parsing_ops.py
def decode_raw(bytes, out_type, little_endian=None, name=None): result = _op_def_lib.apply_op("DecodeRaw", bytes=bytes, out_type=out_type, little_endian=little_endian, name=name) ### 4.3.1--------------> return result
### 4.3.1 -------------->
tensorflow/core/ops/parsing_ops.cc
REGISTER_OP("DecodeRaw")
tensorflow/core/kernels/decode_raw_op.cc
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("DecodeRaw").Device(DEVICE_CPU).TypeConstraint<type>("out_type"), \
DecodeRawOp<type>)
这里比较特殊的是,decode_raw()是直接定义在 gen_xxx_ops.py 文件中的,不是像其他的 OP 是通过 xxx_ops.py 调用的。
5 --------------> cifar10_input.py
_generate_image_and_label_batch 函数通过 shuffle_batch 函数或 batch 函数进行样本的批处理,可以通过 num_preprocess_threads 设置所需要的线程数。
下边分别看一下 shuffle_batch 函数和 batch 函数的调用过程,其实所做的操作基本是一样的,
区别在于 shuffle_batch 函数要进行乱序处理,所创建的队列是 RandomShuffleQueue, 而 batch 函数创建的是 FIFOQueue。
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle): num_preprocess_threads = 16 if shuffle: # 当 shuffle = true 时,打乱了样本的原有顺序,每次从队列中 dequeue 取数据是随机的 images, label_batch = tf.train.shuffle_batch( ## 5.1 --------------> [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size, min_after_dequeue=min_queue_examples) else: # 当 shuffle = false 时,每次 dequeue 是从队列中按顺序取数据,遵从先入先出的原则 images, label_batch = tf.train.batch( ## 5.2 --------------> [image, label], batch_size=batch_size, num_threads=num_preprocess_threads, capacity=min_queue_examples + 3 * batch_size) # Display the training images in the visualizer. tf.summary.image('images', images) return images, tf.reshape(label_batch, [batch_size])
5.1 --------------> tensorflow/python/training/input.py
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): # Creates batches by randomly shuffling tensors. return _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, keep_input=True, num_threads=num_threads, seed=seed, enqueue_many=enqueue_many, shapes=shapes, allow_smaller_final_batch=allow_smaller_final_batch, shared_name=shared_name, name=name) ### 5.1.1 -------------->
5.1.1 --------------> tensorflow/python/training/input.py
这个函数做了以下操作:
1. 创建一个样本队列,RandomShuffleQueue
2. 调用 _enqueue 函数进行入队的相关操作。包括创建一个 QueueRunner, 并将其加入图的集合中。
3. 定义批量出队操作
def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, keep_input, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `shuffle_batch` and `maybe_shuffle_batch`.""" tensor_list = _as_tensor_list(tensors) with ops.name_scope(name, "shuffle_batch", list(tensor_list) + [keep_input]) as name: tensor_list = _validate(tensor_list) keep_input = _validate_keep_input(keep_input, enqueue_many) tensor_list, sparse_info = _store_sparse_tensors( tensor_list, enqueue_many, keep_input) types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) queue = data_flow_ops.RandomShuffleQueue( #### 5.1.1.1 --------------> capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, dtypes=types, shapes=shapes, shared_name=shared_name) _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input) #### 5.1.1.2 (同 5.2.1.2) --------------> full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), dtypes.float32) * (1. / (capacity - min_after_dequeue))) # Note that name contains a '/' at the end so we intentionally do not place # a '/' after %s below. summary_name = ("fraction_over_%d_of_%d_full" % (min_after_dequeue, capacity - min_after_dequeue)) summary.scalar(summary_name, full) if allow_smaller_final_batch: dequeued = queue.dequeue_up_to(batch_size, name=name) #### 5.1.1.3 (同 5.2.1.3) --------------> else: dequeued = queue.dequeue_many(batch_size, name=name) #### 5.1.1.4 (同 5.2.1.4) --------------> dequeued = _restore_sparse_tensors(dequeued, sparse_info) return _as_original_type(tensors, dequeued)
5.1.1.1 --------------> tensorflow/python/ops/data_flow_ops.py
class RandomShuffleQueue(QueueBase): def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None, names=None, seed=None, shared_name=None, name="random_shuffle_queue"): queue_ref = gen_data_flow_ops._random_shuffle_queue_v2( component_types=dtypes, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed1, seed2=seed2, shared_name=shared_name, name=name) ##### 5.1.1.1.1 --------------> super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
5.1.1.1.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
def _random_shuffle_queue_v2(component_types, shapes=None, capacity=None, min_after_dequeue=None, seed=None, seed2=None, container=None, shared_name=None, name=None): result = _op_def_lib.apply_op("RandomShuffleQueueV2", component_types=component_types, shapes=shapes, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, seed2=seed2, container=container, shared_name=shared_name, name=name) ###### 5.1.1.1.1.1 --------------> return result
5.1.1.1.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("RandomShuffleQueueV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("RandomShuffleQueueV2").Device(DEVICE_CPU), RandomShuffleQueueOp);
5.1.1.2 --------------> tensorflow/python/training/input.py
_enqueue 函数定义了样本队列入队的相关操作:
1. 确定入队的方法:是 enqueue 还是 enqueue_many
2. 定义入队 OP, 定义了 threads 个入队操作的列表 enqueue_ops
3. 创建一个 QueueRunner, 并将其加入图的 QUEUE_RUNNER 集合中
到这里为止就一共有了两个 QueueRunner, 一个是负责入队文件名队列的 QueueRunner, 其入队操作列表只有一个入队操作的[enq], 后续的运行过程只会创建一个线程进行入队操作。另一个是负责入队样本队列的 QueueRunner, 其入队操作列表是有 threads 个入队操作的 enqueue_ops, 后续的运行过程会创建 threads 个线程进行入队操作。
def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input): """Enqueue `tensor_list` in `queue`.""" if enqueue_many: enqueue_fn = queue.enqueue_many else: enqueue_fn = queue.enqueue # enqueue_ops 的列表, 有 threads 个 enqueue_op if keep_input.get_shape().ndims == 1: enqueue_ops = [enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads else: enqueue_ops = [_smart_cond( keep_input, lambda: enqueue_fn(tensor_list), control_flow_ops.no_op)] * threads queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
5.1.1.3 --------------> tensorflow/python/ops/data_flow_ops.py
class QueueBase(object): def dequeue_up_to(self, n, name=None): if name is None: name = "%s_DequeueUpTo" % self._name ret = gen_data_flow_ops._queue_dequeue_up_to_v2( ##### 5.1.1.3.1 --------------> self._queue_ref, n=n, component_types=self._dtypes, name=name) op = ret[0].op for output, shape in zip(op.values(), self._shapes): output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) return self._dequeue_return_value(ret)
5.1.1.3.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
def _queue_dequeue_up_to_v2(handle, n, component_types, timeout_ms=None, name=None): result = _op_def_lib.apply_op("QueueDequeueUpToV2", handle=handle, n=n, component_types=component_types, timeout_ms=timeout_ms, name=name) ###### 5.1.1.3.1.1 --------------> return result
5.1.1.3.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueDequeueUpToV2")
tensorflow/core/kernels/queue_ops.cc
REGISTER_KERNEL_BUILDER(Name("QueueDequeueUpToV2").Device(DEVICE_CPU), DequeueUpToOp);
5.1.1.4 --------------> tensorflow/python/ops/data_flow_ops.py
class QueueBase(object): def dequeue_many(self, n, name=None): if name is None: name = "%s_DequeueMany" % self._name ret = gen_data_flow_ops._queue_dequeue_many_v2( ##### 5.1.1.4.1 --------------> self._queue_ref, n=n, component_types=self._dtypes, name=name) op = ret[0].op batch_dim = tensor_shape.Dimension(tensor_util.constant_value(op.inputs[1])) for output, shape in zip(op.values(), self._shapes): output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape)) return self._dequeue_return_value(ret)
5.1.1.4.1 --------------> tensorflow/python/ops/gen_data_flow_ops.py
def _queue_dequeue_many_v2(handle, n, component_types, timeout_ms=None, name=None): result = _op_def_lib.apply_op("QueueDequeueManyV2", handle=handle, n=n, component_types=component_types, timeout_ms=timeout_ms, name=name) ###### 5.1.1.4.1.1 --------------> return result
5.1.1.4.1.1 -------------->
tensorflow/core/ops/data_flow_ops.cc
REGISTER_OP("QueueDequeueUpToV2")
REGISTER_KERNEL_BUILDER(Name("QueueDequeueManyV2").Device(DEVICE_CPU), DequeueManyOp);
5.2 --------------> tensorflow/python/training/input.py
def batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): return _batch(tensors, batch_size, keep_input=True, num_threads=num_threads, capacity=capacity, enqueue_many=enqueue_many, shapes=shapes, dynamic_pad=dynamic_pad, allow_smaller_final_batch=allow_smaller_final_batch, shared_name=shared_name, name=name) ### 5.2.1 -------------->
5.2.1 --------------> tensorflow/python/training/input.py
_batch 和 _shuffle_batch 的区别就在于创建队列时不同,这里的 queue 是在 PaddingFIFOQueue 和 FIFOQueue 中选择一个,而 _shuffle_batch 是 RandomShuffleQueue 类型。
def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): """Helper function for `batch` and `maybe_batch`.""" tensor_list = _as_tensor_list(tensors) with ops.name_scope(name, "batch", list(tensor_list) + [keep_input]) as name: tensor_list = _validate(tensor_list) keep_input = _validate_keep_input(keep_input, enqueue_many) (tensor_list, sparse_info) = _store_sparse_tensors( tensor_list, enqueue_many, keep_input) types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) # 创建一个队列 queue = _which_queue(dynamic_pad)( #### 5.2.1.1 --------------> capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name) # 入队操作 _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input) #### 5.2.1.2 (同 5.1.1.2)--------------> summary.scalar("fraction_of_%d_full" % capacity, math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity)) if allow_smaller_final_batch: dequeued = queue.dequeue_up_to(batch_size, name=name) #### 5.2.1.3 (同 5.1.1.3)--------------> else: dequeued = queue.dequeue_many(batch_size, name=name) #### 5.2.1.4 (同 5.1.1.4)--------------> dequeued = _restore_sparse_tensors(dequeued, sparse_info) return _as_original_type(tensors, dequeued)
5.2.1.1 --------------> tensorflow/python/training/input.py
def _which_queue(dynamic_pad): return (data_flow_ops.PaddingFIFOQueue if dynamic_pad else data_flow_ops.FIFOQueue)
==================================================================================================================
==================================================================================================================
以上所有均是创建数据流图过程中涉及的操作,并不涉及数据流图的运行过程。queue_runner.add_queue_runner() 函数添加 QueueRunner 到你的数据流图中。
在运行任何训练步骤(在调用 run 或者 eval 去执行 read)之前,需要调用 tf.train.start_queue_runners 函数去开始它的线程运行入队操作,否则数据流图将一直挂起。 tf.train.start_queue_runner 这个函数将会启动输入管道的线程,填充样本到队列中,以便出队操作可以从队列中拿到样本。这种情况下最好配合使用一个 tf.train.Coordinator ,这样可以在发生错误的情况下正确地关闭这些线程。如果你对训练迭代数做了限制,那么需要使用一个训练迭代数计数器,并且需要被初始化。
tensorflow/python/training/queue_runner_impl.py
def start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=ops.GraphKeys.QUEUE_RUNNERS): """ 开始 graph 中收集的所有 queue_runners。 这是`add_queue_runner()`的配套方法。 它只是为图形中收集的所有 queue_runners 启动线程。 它返回所有线程的列表。 """ if sess is None: sess = ops.get_default_session() if not sess: raise ValueError("Cannot start queue runners: No default session is " "registered. Use `with sess.as_default()` or pass an " "explicit session to tf.start_queue_runners(sess=sess)") with sess.graph.as_default(): threads = [] for qr in ops.get_collection(collection): threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon, start=start)) # --------------------> return threads
for qr in ops.get_collection(collection):
threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon, start=start))
这里获取了数据流图中的所有 queue_runners,对于每个 qr 都调用 create_threads() 函数执行创建线程的操作。
class QueueRunner(object): def create_threads(self, sess, coord=None, daemon=False, start=False): """ 创建线程来运行给定会话的入队操作。 此方法需要给定的会话中的 graph 已经启动。它创建一个线程列表,可以选择性启动它们。 在`enqueue_ops`中每个操作都有一个线程。 `coord`参数是一个可选的协调器,线程将用于终止并报告异常。 如果给定了协调器,则当协调器请求停止时,此方法启动一个附加线程来关闭队列。 """ with self._lock: try: if self._runs_per_session[sess] > 0: # Already started: no new threads to return. return [] except KeyError: # We haven't seen this session yet. pass self._runs_per_session[sess] = len(self._enqueue_ops) self._exceptions_raised = [] # 创建一系列执行入队操作的线程,线程的个数等于 _enqueue_ops 的 op 个数 ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord)) for op in self._enqueue_ops] if coord: # 如果使用了 coord 协调器,添加一个执行 _cancel_op 操作的线程 ret_threads.append(threading.Thread(target=self._close_on_stop, args=(sess, self._cancel_op, coord))) for t in ret_threads: if coord: # 注册一个线程到 coord.join, 控制线程终止 coord.register_thread(t) if daemon: # 守护线程 t.daemon = True if start: # 开启线程 t.start() return ret_threads
queue_runner_impl.py 是使用了 Python 创建线程的方法, python 提供了 threading 模块来实现多线程,导入 threading 模块,这是使用多线程的前提
import threading
ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord)) for op in self._enqueue_ops]
创建了 ret_threads 系列线程,对该 QueueRunner 中的 _enqueue_ops 的所有 op 分别使用 threading.Thread() 方法创建线程。
threading.Thread() 方法中调用 QueueRunner 的 _run 方法 target=self._run, args 方法对 _run 进行传参,不同的线程所传输的 op 是不同的。入队操作线程的个数等于 _enqueue_ops 的 op 个数。
实际在使用中都会结合 coord 一起使用,来帮助多个线程协同工作,多个线程同步终止。
ret_threads.append(threading.Thread(target=self._close_on_stop, args=(sess, self._cancel_op, coord)))
这一句向 ret_threads 列表中又添加了一个线程,这个线程执行 _cancel_op, 这个操作会关闭队列,同时取消正在等待的入队操作。
最终每个 QueueRunner 中一共有 len(enqueue_ops) + 1 个线程。
在 cifar10 的示例中,一共创建了两个 QueueRunner, 其中一个是在文件名入队函数 input_producer 函数中创建的:
queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
因此在这个 QueueRunner 中执行文件名入队操作线程的个数为 len([enq]), 这里 len([enq]) = 1, 因此文件名入队操作线程为 1 个,加上执行 cancel 操作的 1 个线程,一共有 2 个线程。
另外一个 QueueRunner 是现在 batch/shuffle_batch 函数的 _enqueue 函数中创建的:
queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
enqueue_ops 的定义为:
if keep_input.get_shape().ndims == 1: enqueue_ops = [enqueue_fn(_select_which_to_enqueue(tensor_list, keep_input))] * threads else: enqueue_ops = [_smart_cond( keep_input, lambda: enqueue_fn(tensor_list), control_flow_ops.no_op)] * threads
会创建 threads 个入队操作,添加到 enqueue_ops 列表中
因此在第二个 QueueRunner 中执行样本入队操作线程的个数为 len(enqueue_ops), 这里 len(enqueue_ops) = threads, 因此样本入队操作线程为 threads 个,加上执行 cancel 操作的 1 个线程,一共有 threads + 1 个线程。
相关文章推荐
- TensorFlow 数据读取模块调用过程(inception)
- TensorFlow学习之CNN-Cifar10代码阅读与详解(一):cifar10数据批量读取
- 定义公共的类调用存储过程获取数据
- 自编的VB6.0调用WinAPI的模块(整合了许多函数和过程)
- silverlight动态读取txt文件/解析json数据/调用wcf示例
- asp.net 读取SAP数据(rfc形式全过程)
- 存储过程返回的多结果集数据,ado 访问调用
- 在.net中如何把调用存储过程代码写入数据连接层中
- 通过调用过程把图片文存储到数据
- 程序调用查询数据存储过程的问题
- 存储过程调用DTS包实现大批量数据导入
- ASP.NET调用类连接Access数据库执行sql语句并以GridView方式读取表中数据
- ASP调用存储过程中与SQL对应的数据类型
- abatis实现从页面读取数据添加到数据库全过程
- 在 Linux 下用户空间与内核空间数据交换的方式,第 1 部分: 内核启动参数、模块参数与sysfs、sysctl、系统调用和netlink
- Net 调用SAP RFC接口来读取数据实战纪实
- 存储过程1——读取数据
- ASP调用存储过程中与SQL对应的数据类型
- ajax异步调用,当鼠标点在图片上时,显示一个新层读取数据内容
- 数据做为存储过程参数在JAVA中的调用