机器学习: TensorFlow 的数据读取与TFRecords 格式
2017-03-22 11:24
609 查看
最近学习tensorflow,发现其读取数据的方式看起来有些不同,所以又重新系统地看了一下文档,总得来说,tensorflow 有三种主流的数据读取方式:
1) 传送 (feeding): Python 可以在程序的运行过程中,将数据传送进定义好的 tensor 变量中
2) 从文件读取 (reading from files): 一个输入流从文件中直接读取数据
3) 预加载数据 (preloaded data): 这个很好理解,就是将所有的数据一次性全部读进内存里。
对于第三种方式,在数据量小的时候,是非常高效的,但是如果数据量很大的时候,这种方法显然非常耗内存,所以在数据量很大的时候,一般选择第二种读取方式,即从文件读取。在利用第二种方式读取的时候,我们常常会用到一种 TFRecords 的格式来保存读取的文件。TFRecords 是一种二进制文件。可以在TensorFlow 中方便的进行各种存取操作以及预处理。
我们先来看看,如何将一张图片转换成字符流
接下来,我们看看,如何生成 TFRecords 文件:
最后,我们看看如何从 TFrecords 文件中读数据,并且做批处理:
参考来源:
http://codecloud.net/16485.html
http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
https://www.tensorflow.org/programmers_guide/reading_data
1) 传送 (feeding): Python 可以在程序的运行过程中,将数据传送进定义好的 tensor 变量中
2) 从文件读取 (reading from files): 一个输入流从文件中直接读取数据
3) 预加载数据 (preloaded data): 这个很好理解,就是将所有的数据一次性全部读进内存里。
对于第三种方式,在数据量小的时候,是非常高效的,但是如果数据量很大的时候,这种方法显然非常耗内存,所以在数据量很大的时候,一般选择第二种读取方式,即从文件读取。在利用第二种方式读取的时候,我们常常会用到一种 TFRecords 的格式来保存读取的文件。TFRecords 是一种二进制文件。可以在TensorFlow 中方便的进行各种存取操作以及预处理。
我们先来看看,如何将一张图片转换成字符流
import os import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import skimage.io as io dir_path = 'Face' file_list = os.listdir(dir_path) print file_list for f in file_list: print (dir_path + os.sep + f) img_1 = io.imread(dir_path + os.sep + file_list[0]) #plt.imshow(img_1, cmap='gray') #plt.show() # 将图像转换成字符 img_str = img_1.tostring() # 将字符流还原成图像 img_rec_vec = np.fromstring(img_str, dtype=np.uint8) img_rec = img_rec_vec.reshape(img_1.shape) #plt.imshow(img_rec, cmap='gray') #plt.show()
接下来,我们看看,如何生成 TFRecords 文件:
def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) tfrecords_filename = 'Face.tfrecords' writer = tf.python_io.TFRecordWriter(tfrecords_filename) for img_path in file_list: img = np.array(io.imread(dir_path + os.sep + img_path)) # 从文件夹里读取图像 # 获取图像的宽和高,图像的维度需要存入 TFRecords 文件中 # 以方便后续的处理 # height = img.shape[0] width = img.shape[1] # 将图像转换成字符流 img_raw = img.tostring() # 将字符流以及图像的尺度信息存入TFRecords 文件 example = tf.train.Example(features=tf.train.Features(feature={ 'height': _int64_feature(height), 'width': _int64_feature(width), 'image_raw': _bytes_feature(img_raw)})) writer.write(example.SerializeToString()) writer.close()
最后,我们看看如何从 TFrecords 文件中读数据,并且做批处理:
# 可以重新定义图像的宽和高, IMAGE_HEIGHT = 224 IMAGE_WIDTH = 224 # 定义读取与解码函数 def read_and_decode(filename_queue): reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # 获取 features,包含图像,以及图像宽和高 features = tf.parse_single_example( serialized_example, features={ 'height': tf.FixedLenFeature([], tf.int64), 'width': tf.FixedLenFeature([], tf.int64), 'image_raw': tf.FixedLenFeature([], tf.string), }) # 获取图像信息 image = tf.decode_raw(features['image_raw'], tf.uint8) height = tf.cast(features['height'], tf.int32) width = tf.cast(features['width'], tf.int32) # 将图像转换成多维数组的形式 image_shape = [height, width, 1] image = tf.reshape(image, image_shape) # 重新定义图像的尺度 image_size_const = tf.constant((IMAGE_HEIGHT, IMAGE_WIDTH, 1), dtype=tf.int32) # Random transformations can be put here: right before you crop images # to predefined size. To get more information look at the stackoverflow # question linked above. # 对图像进行预处理,包括裁剪,增边等 resized_image = tf.image.resize_image_with_crop_or_pad(image=image, target_height=IMAGE_HEIGHT, target_width=IMAGE_WIDTH) return resized_image # filename_queue = tf.train.string_input_producer( [tfrecords_filename], num_epochs=10) # Even when reading in multiple threads, share the filename # queue. train_images = read_and_decode(filename_queue) # 要注意 min_after_dequeue 不能超过 capacity image = tf.train.shuffle_batch([train_images], batch_size=1, capacity=5, num_threads=2, min_after_dequeue=1) # The op for initializing the variables. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session() as sess: sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # Let's read off 3 batches just for example for i in xrange(1): img = sess.run([image]) img_batch = img[0] img_1 = tf.reshape(img_batch[0, :, :, :], [IMAGE_HEIGHT, IMAGE_WIDTH]) print (img_1.shape) plt.imshow(sess.run(img_1), cmap='gray') # coord.request_stop() # coord.join(threads) plt.show() print 'all is well'
参考来源:
http://codecloud.net/16485.html
http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
https://www.tensorflow.org/programmers_guide/reading_data
相关文章推荐
- Tensorflow学习教程------tfrecords数据格式生成与读取
- 由浅入深之Tensorflow(3)----数据读取之TFRecords
- [TFRecord格式数据]利用TFRecords存储与读取带标签的图片
- tensorflow数据读取之tfrecords
- TensorFlow读取tfrecords数据
- tensorflow读取数据-tfrecord格式
- tensorflow读取SVHN数据集转为TFrecords格式
- 由浅入深之Tensorflow(3)----数据读取之TFRecords
- tensorflow读取数据-tfrecord格式
- 由浅入深之Tensorflow(3)----数据读取之TFRecords
- TensorFlow全新的数据读取方式:Dataset API——tf.data.Dataset
- 制作tensorflow标准数据集即制作.tfrecords格式文件
- Tensorlfow 数据读取之TFRecords
- tensorflow读取数据之CSV格式
- Tensorflow中创建自己的TFRecord格式数据集
- tensorflow读取数据之CSV格式
- TFRecord —— tensorflow 下的统一数据存储格式
- 制作TFrecords格式数据
- tensorflow中的TFRecord格式文件的写入和读取
- TensorFlow制作、读取TFRecord格式数据集