您的位置:首页 > 编程语言

TensorFlow学习笔记(二十四)自制TFRecord数据集 读取、显示及代码详解

2017-08-18 10:29 435 查看
在跑通了官网的mnist和cifar10数据之后,笔者尝试着制作自己的数据集,并保存,读入,显示。 TensorFlow可以支持cifar10的数据格式, 也提供了标准的TFRecord 格式,而关于 tensorflow 读取数据, 官网提供了3中方法
1 Feeding: 在tensorflow程序运行的每一步, 用python代码在线提供数据

2 Reader : 在一个计算图(tf.graph)的开始前,将文件读入到流(queue)中

3 在声明tf.variable变量或numpy数组时保存数据。受限于内存大小,适用于数据较小的情况

在本文,主要介绍第二种方法,利用tf.record标准接口来读入文件

准备图片数据

笔者找了2类狗的图片, 哈士奇和吉娃娃, 全部 resize成128 * 128大小

如下图, 保存地址为D:\Python\data\dog



每类中有10张图片





现在利用这2 类 20张图片制作TFRecord文件

制作TFRECORD文件

1 先聊一下tfrecord, 这是一种将图像数据和标签放在一起的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储 等等..

这里注意,tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签

如在本例中,只有0,1 两类

2 先上“制作TFRecord文件”的代码,注释附详解

?
运行完这段代码后,会生成dog_train.tfrecords 文件,如下图



tf.train.Example 协议内存块包含了Features字段,通过feature将图片的二进制数据和label进行统一封装, 然后将example协议内存块转化为字符串, tf.python_io.TFRecordWriter 写入到TFRecords文件中。

读取TFRECORD文件

在制作完tfrecord文件后, 将该文件读入到数据流中。

代码如下

?
注意,feature的属性“label”和“img_raw”名称要和制作时统一 ,返回的img数据和label数据一一对应。返回的img和label是2个 tf 张量,print出来 如下图



显示tfrecord格式的图片

有些时候我们希望检查分类是否有误,或者在之后的网络训练过程中可以监视,输出图片,来观察分类等操作的结果,那么我们就可以session回话中,将tfrecord的图片从流中读取出来,再保存。 紧跟着一开始的代码写:

?
代码运行完后, 从tfrecord中取出的文件被保存了。如下图:



在这里我们可以看到,图片文件名的第一个数字表示在流中的顺序(这里没有用shuffle), 第二个数字则是 每个图片的label,吉娃娃都为0,哈士奇都为1。 由此可见,我们一开始制作tfrecord文件时,图片分类正确。

下面给出一些常见图片处理方式:

# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt

import tensorflow as tf   

import numpy as np

import os

os.getcwd()

image_raw_data = tf.gfile.FastGFile("E:\\testData\\images\\cat.jpg",'rb').read()

with tf.Session() as sess:

    img_data = tf.image.decode_jpeg(image_raw_data)

    

    # 输出解码之后的三维矩阵。

    #print(img_data.eval())

    #print(img_data.get_shape())

    img_data.set_shape([1797, 2673, 3])

    print(img_data.get_shape())

 

#### 2. 打印图片    

with tf.Session() as sess:

    plt.imshow(img_data.eval())

    plt.show()

#### 3. 重新调整图片大小    

with tf.Session() as sess:    

    resized = tf.image.resize_images(img_data, [300, 300], method=0)

    

    # TensorFlow的函数处理图片后存储的数据是float32格式的,需要转换成uint8才能正确打印图片。

    print("Digital type: ", resized.dtype)

    cat = np.asarray(resized.eval(), dtype='uint8')

    # tf.image.convert_image_dtype(rgb_image, tf.float32)

    plt.imshow(cat)

    plt.show()    

    

#### 4. 裁剪和填充图片    

with tf.Session() as sess:    

    croped = tf.image.resize_image_with_crop_or_pad(img_data, 1000, 1000)

    padded = tf.image.resize_image_with_crop_or_pad(img_data, 3000, 3000)

    plt.imshow(croped.eval())

    plt.show()

    plt.imshow(padded.eval())

    plt.show()    

    

#### 5. 截取中间50%的图片    

with tf.Session() as sess:   

    central_cropped = tf.image.central_crop(img_data, 0.5)

    plt.imshow(central_cropped.eval())

    plt.show()    

    

#### 6. 翻转图片    

with tf.Session() as sess:

    # 上下翻转

    #flipped1 = tf.image.flip_up_down(img_data)

    # 左右翻转

    #flipped2 = tf.image.flip_left_right(img_data)

    

    #对角线翻转

    transposed = tf.image.transpose_image(img_data)

    plt.imshow(transposed.eval())

    plt.show()

    

    # 以一定概率上下翻转图片。

    #flipped = tf.image.random_flip_up_down(img_data)

    # 以一定概率左右翻转图片。

    #flipped = tf.image.random_flip_left_right(img_data)    

    

#### 7. 图片色彩调整    

with tf.Session() as sess:     

    # 将图片的亮度-0.5。

    #adjusted = tf.image.adjust_brightness(img_data, -0.5)

    

    # 将图片的亮度-0.5

    #adjusted = tf.image.adjust_brightness(img_data, 0.5)

    

    # 在[-max_delta, max_delta)的范围随机调整图片的亮度。

    adjusted = tf.image.random_brightness(img_data, max_delta=0.5)

    

    # 将图片的对比度-5

    #adjusted = tf.image.adjust_contrast(img_data, -5)

    

    # 将图片的对比度+5

    #adjusted = tf.image.adjust_contrast(img_data, 5)

    

    # 在[lower, upper]的范围随机调整图的对比度。

    #adjusted = tf.image.random_contrast(img_data, lower, upper)

    plt.imshow(adjusted.eval())

    plt.show()    

 

#### 8. 添加色相和饱和度    

with tf.Session() as sess:         

    adjusted = tf.image.adjust_hue(img_data, 0.1)

    #adjusted = tf.image.adjust_hue(img_data, 0.3)

    #adjusted = tf.image.adjust_hue(img_data, 0.6)

    #adjusted = tf.image.adjust_hue(img_data, 0.9)

    

    # 在[-max_delta, max_delta]的范围随机调整图片的色相。max_delta的取值在[0, 0.5]之间。

    #adjusted = tf.image.random_hue(image, max_delta)

    

    # 将图片的饱和度-5。

    #adjusted = tf.image.adjust_saturation(img_data, -5)

    # 将图片的饱和度+5。

    #adjusted = tf.image.adjust_saturation(img_data, 5)

    # 在[lower, upper]的范围随机调整图的饱和度。

    #adjusted = tf.image.random_saturation(img_data, lower, upper)

    

    # 将代表一张图片的三维矩阵中的数字均值变为0,方差变为1。

    #adjusted = tf.image.per_image_whitening(img_data)

    

    plt.imshow(adjusted.eval())

    plt.show()    

    

#### 9. 添加标注框并裁减。    

with tf.Session() as sess:         

    boxes = tf.constant([[[0.05, 0.05, 0.9, 0.7], [0.35, 0.47, 0.5, 0.56]]])

    begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(

        tf.shape(img_data), bounding_boxes=boxes)

    batched = tf.expand_dims(tf.image.convert_image_dtype(img_data, tf.float32), 0)

    image_with_box = tf.image.draw_bounding_boxes(batched, bbox_for_draw)

    

    distorted_image = tf.slice(img_data, begin, size)

    plt.imshow(distorted_image.eval())

    plt.show()    

   
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: