
The basic approach to reading a dataset (TFRecord) with an iterator in TensorFlow

2018-03-25 15:58
Reading a dataset takes three steps:
  1) define how the dataset is constructed;
  2) define an iterator over the dataset;
  3) obtain a data tensor from the iterator by calling its get_next method.
For example:
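import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

input_data = [1, 2, 3, 5, 8]
# step 1: build the dataset from an in-memory list
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# step 2: define an iterator that walks the dataset once
iterator = dataset.make_one_shot_iterator()
# step 3: get_next() returns a tensor holding the next element
x = iterator.get_next()
y = x * x
with tf.Session() as sess:
    # each run of y pulls one more element from the iterator
    for i in range(len(input_data)):
        print(sess.run(y))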
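The three steps map closely onto Python's own iterator protocol, which may help when reasoning about what each call to sess.run(y) does. The sketch below is only an analogy in plain Python, not TensorFlow code; the names make_dataset and get_next are hypothetical stand-ins chosen for illustration.

```python
def make_dataset(values):
    """Step 1 analogue: describe where the elements come from."""
    return list(values)


input_data = [1, 2, 3, 5, 8]
dataset = make_dataset(input_data)

# Step 2 analogue: create an iterator that walks the dataset once.
iterator = iter(dataset)


def get_next():
    """Step 3 analogue: fetch the next element.

    Raises StopIteration when the data is exhausted, much as a
    TensorFlow iterator signals the end with tf.errors.OutOfRangeError.
    """
    return next(iterator)


squares = []
for _ in range(len(input_data)):
    x = get_next()          # pulls exactly one element per call
    squares.append(x * x)   # downstream computation on that element

print(squares)
```

As in the TensorFlow version, asking for one more element than the dataset holds raises an end-of-data exception instead of returning a value.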
Next, tf.data.TextLineDataset can be used to read data one line at a time, which is common in natural language processing tasks. An example follows:
input_files = ['D:/path/to/flowers/input_file2.txt', 'D:/path/to/flowers/input_file2.txt']
dataset = tf.data.TextLineDataset(input_files)
iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()
with tf.Session() as sess:
    # each run returns a string tensor that holds one line of a file
    # (this loop assumes the files contain at least 25 lines in total;
    # otherwise tf.errors.OutOfRangeError is raised)
    for i in range(25):
        print(sess.run(x))
input_files is a list of strings, so it can name more than one text file, which means a single dataset can be created from several files.
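To make the reading order concrete, here is a plain-Python sketch (not TensorFlow code) of the behaviour described above: one element per line, with files consumed in the order they are listed. The helper read_lines and the file names a.txt and b.txt are hypothetical and exist only for this illustration.

```python
import os
import tempfile


def read_lines(paths):
    """Yield every line of every file, in list order, without the trailing newline."""
    for path in paths:
        with open(path, "r", encoding="utf-8") as handle:
            for line in handle:
                yield line.rstrip("\n")


# Create two small sample files in a temporary directory.
tmp_dir = tempfile.mkdtemp()
sample_files = []
for name, lines in [("a.txt", ["first", "second"]), ("b.txt", ["third"])]:
    path = os.path.join(tmp_dir, name)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write("\n".join(lines) + "\n")
    sample_files.append(path)

# All lines of a.txt come out before any line of b.txt.
all_lines = list(read_lines(sample_files))
print(all_lines)
```

Small files like these can also serve as test input for the TextLineDataset example, with input_files pointing at them instead of the D:/ paths.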
The basic approach to reading data in TFRecord format:
import tensorflow as tf  # TensorFlow 1.x API

# define a function that decodes one record of a TFRecord file
def parser(record):
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'pixels': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64)
        })
    # the raw bytes become a uint8 vector, then float32
    decoded_images = tf.decode_raw(features['image_raw'], tf.uint8)
    retyped_images = tf.cast(decoded_images, tf.float32)

    # each image is expected to hold exactly 784 values (e.g. 28 x 28)
    images = tf.reshape(retyped_images, [784])
    labels = tf.cast(features['label'], tf.int32)
    pixels = tf.cast(features['pixels'], tf.int32)
    return images, labels, pixels

# build a dataset from TFRecord files; several files can be listed here
input_files = ["D:/path/1output_test.tfrecords", ]
dataset = tf.data.TFRecordDataset(input_files)

# map() applies parser to every record in the dataset to decode it
dataset = dataset.map(parser)

# define an iterator over the dataset
iterator = dataset.make_one_shot_iterator()

# obtain the data tensors
image, label, pixels = iterator.get_next()

with tf.Session() as sess:
    # the while loop walks through all the data without knowing the exact size of the dataset
    while True:
        try:
            x, y, z = sess.run([image, label, pixels])
            print(y, z)
        except tf.errors.OutOfRangeError:
            # raised once the iterator has no more elements
            break
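To try the reading code without an existing data file, a small TFRecord file with the same three features that parser expects can be generated first. This is a minimal sketch under stated assumptions: it assumes a TensorFlow release that exposes tf.io.TFRecordWriter (on older 1.x releases the same class is tf.python_io.TFRecordWriter), the helper make_example and the file name sample.tfrecords are hypothetical, and the 'pixels' feature is assumed to store the number of pixels per image, which the original post does not state.

```python
import os
import tempfile

import tensorflow as tf


def make_example(image_bytes, label):
    """Build one tf.train.Example with the features 'image_raw', 'pixels' and 'label'."""
    return tf.train.Example(features=tf.train.Features(feature={
        "image_raw": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[image_bytes])),
        # Assumption: 'pixels' holds the pixel count of the image.
        "pixels": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[len(image_bytes)])),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }))


# Write three dummy 784-byte "images" into a temporary TFRecord file.
sample_path = os.path.join(tempfile.mkdtemp(), "sample.tfrecords")
with tf.io.TFRecordWriter(sample_path) as writer:
    for label in range(3):
        image_bytes = bytes([label] * 784)  # 784 bytes, matching the reshape to [784]
        writer.write(make_example(image_bytes, label).SerializeToString())

print(sample_path)
```

Pointing input_files at sample_path should then let the reading loop iterate over three records before tf.errors.OutOfRangeError ends it.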

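# Alternative: pass the file list through a placeholder and use an initializable iterator
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)
iterator = dataset.make_initializable_iterator()
image, label, pixels = iterator.get_next()

with tf.Session() as sess:
    # an initializable iterator must be initialized before get_next() is evaluated;
    # the placeholder is fed with the actual file list at this point
    sess.run(iterator.initializer,
             feed_dict={input_files: ["D:/path/1output_test.tfrecords"]})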
