您的位置:首页 > 其它

TFRecord数据集

2017-10-24 19:51 344 查看
tensorflow标准的读取数据格式:TFRecord

       可以允许你讲任意的数据转换为TensorFlow所支持的格式, 这种方法可以使TensorFlow的数据集更容易与网络应用架构相匹配。这种建议的方法就是使用TFRecords文件,TFRecords文件包含了
tf.train.Example

协议内存块(protocol buffer)(协议内存块包含了字段
Features
)。你可以写一段代码获取你的数据, 将数据填入到
Example
协议内存块(protocol
buffer),将协议内存块序列化为一个字符串, 并且通过
tf.python_io.TFRecordWriter
class
写入到TFRecords文件。

1.生成TFRecord文件

(1)获取图片数据

(2)填入Example。

example = tf.train.Example(features = tf.train.Features(feature = {
"label":_int64_feature(index),
"img_raw":_bytes_feature(img_raw)
}))
(3)要写入到文件中,先定义writer:
writer = tf.python_io.TFRecordWriter("./xx.tfrecords 要生成的tfrecord文件名")
将协议内存块序列化为一个字符串, 并且通过
tf.python_io.TFRecordWriter
写入到TFRecords文件。
writer.write(example.SerializeToString()) #协议内存块转换为字符串,再用writer写入TFRecord中

2.读取TFRecord文件

tensorflow读取数据时,将文件名列表交给
tf.train.string_input_producer
函数
.
string_input_producer
来生成一个先入先出的队列, 文件阅读器会需要它来读取数据。
string_input_producer
提供的可配置参数来设置文件名乱序和最大的训练迭代数,
QueueRunner
会为每次迭代(epoch)将所有的文件名加入文件名队列中,

如果
shuffle=True
的话, 会对文件名进行乱序处理。这一过程是比较均匀的,因此它可以产生均衡的文件名队列。
即 将文件列表交给该函数,该函数生成队列,文件阅读器从队列中读数据。根据文件的格式选择阅读器的种类,每个阅读器都有对应的read方法,
read(文件名队列),返回字符串标量,返回的量可以被解析器解析,变成张量。

从TFRecords文件中读取数据, 可以使用[code]tf.TFRecordReader
tf.parse_single_example
解析器。
这个
parse_single_example
操作可以将
Example
协议内存块(protocol buffer)解析为张量
[/code]

(1)定义队列
filename_queue = tf.train.string_input_producer(["catVSdog_train.tfrecords"])
(2)读数据,定义阅读器,使用相应的read方法。
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
(3)解析返回的量
im_features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
(4)获取数据
image = tf.decode_raw(im_features['img_raw'],tf.uint8)
image = tf.reshape(image,[128,128,3])
label = tf.cast(im_features['label'],tf.int32)

取batch,乱序(shufflle_batch)和不乱序(.batch),这样下面image,label= sess.run(image_batch,label_batch)
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=30, capacity=2000,
min_after_dequeue=1000)
(5)创建线程并使用
QueueRunner
对象来预取 的模板
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
#操作
coord.request_stop()
coord.join(threads)

with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
#模板
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
#
for i in range(20):
example, l = sess.run([image, label])
img = Image.fromarray(example, 'RGB')
img.save(image_path +"\\"+ str(i) + '_''Label_' + str(l) + '.jpg')
print(example, l)
#模板
coord.request_stop()
coord.join(threads)


参考:
 (1)tensorflow数据读取:http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/reading_data.html
(2)数据集制作: http://www.cnblogs.com/upright/p/6136265.html https://www.2cto.com/kf/201702/604326.html https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/how_tos/reading_data/convert_to_records.py https://github.com/kevin28520/My-TensorFlow-tutorials/blob/master/03%20TFRecord/notMNIST_input.py

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: