
Deep Learning with TensorFlow, Part 20: An Introduction to the CIFAR-10 Dataset

2017-12-13 16:55

1. CIFAR-10

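The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. 50000 of the images form the training set and the remaining 10000 form the test set.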

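The dataset is divided into 5 training batches and 1 test batch, each containing 10000 images. The test batch contains exactly 1000 randomly selected images from each class. The training batches contain the remaining 50000 images in random order, so an individual training batch may contain more images from one class than from another. Taken together, however, the five training batches contain exactly 5000 images from each class, 50000 training images in total.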

The figure below shows the classes in the dataset, together with 10 randomly selected images from each class:



2. Parsing the CIFAR-10 dataset

The official site provides several versions of the CIFAR-10 dataset, listed below:

Version                                             Size     md5sum
CIFAR-10 python version                             163 MB   c58f30108f718f92721af3b95e74349a
CIFAR-10 Matlab version                             175 MB   70270af85842c9e89bb428ec9976c926
CIFAR-10 binary version (suitable for C programs)   162 MB   c32a1d4ab5d03f1284b67883e8d87530

Here we download the Python version.
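The md5sum listed for the Python version can be used to check that the download is intact. The sketch below uses only the standard library; the helper names `md5_of_file` and `is_valid_download` are hypothetical, and the archive name in the usage note is an assumption about what the downloaded file is called.

```python
import hashlib

# md5sum of the Python version, as listed on the official site
EXPECTED_MD5 = "c58f30108f718f92721af3b95e74349a"

def md5_of_file(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def is_valid_download(path, expected=EXPECTED_MD5):
    """Return True if the file at path has the expected MD5 digest."""
    return md5_of_file(path) == expected
```

For example, `is_valid_download("cifar-10-python.tar.gz")` should return True for an uncorrupted archive, assuming that is the name of the downloaded file.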

After the download finishes, extract the archive. The resulting folder contains the following files:

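Name            Purpose
batches.meta    Not used by the programs in this article
data_batch_1    First training batch, 10000 images
data_batch_2    Second training batch, 10000 images
data_batch_3    Third training batch, 10000 images
data_batch_4    Fourth training batch, 10000 images
data_batch_5    Fifth training batch, 10000 images
readme.html     Web page, not used by the programs in this article
test_batch      Test batch, 10000 images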
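Each batch file listed above contains a Python dictionary (dict) with the following structure:

Key               Content
b'data'           A 10000x3072 array; each row holds one 32x32 3-channel image, 10000 images in total
b'labels'         A list of length 10000 holding the label of each image in data
b'batch_label'    The name of this batch
b'filenames'      A list of length 10000 holding the file name of each image in data

The two keys that really matter are data and labels; the other two are of little importance.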
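The labels are integers from 0 to 9. Although the programs in this article do not use batches.meta, that file stores a label_names list that maps each integer to a class name. The sketch below shows the lookup with a hand-written stand-in for the loaded dictionary; `label_to_name` is a hypothetical helper, and the bytes keys assume the file is loaded with encoding='bytes' like the batch files.

```python
# Stand-in for the dictionary stored in batches.meta (written by hand here,
# assuming it was loaded with encoding='bytes' so keys and names are bytes).
meta = {
    b'label_names': [b'airplane', b'automobile', b'bird', b'cat', b'deer',
                     b'dog', b'frog', b'horse', b'ship', b'truck'],
}

def label_to_name(label, meta):
    """Map an integer label (0-9) to its class name as a str."""
    return meta[b'label_names'][label].decode('utf-8')

print(label_to_name(3, meta))  # cat
```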

Following the method given on the official site, read the dictionary stored in each batch file:

import numpy as np
import pickle

def unpickle(file):
    # encoding='bytes' makes the dictionary keys bytes objects such as b'data'
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch
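To see what unpickle returns without downloading anything, the sketch below writes a tiny synthetic batch to a temporary file and reads it back. The contents of `fake_batch` are made up for illustration (2 all-zero images) and only mimic the structure described above.

```python
import os
import pickle
import tempfile

import numpy as np

def unpickle(file):
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

# Synthetic batch with the same keys as a real batch file (assumption: 2 images only).
fake_batch = {
    b'batch_label': b'synthetic batch',
    b'labels': [0, 1],
    b'data': np.zeros((2, 3072), dtype=np.uint8),
    b'filenames': [b'a.png', b'b.png'],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'fake_batch')
    with open(path, 'wb') as fo:
        pickle.dump(fake_batch, fo)
    loaded = unpickle(path)

print(sorted(loaded.keys()))
print(loaded[b'data'].shape, loaded[b'data'].dtype)
```

With a real batch file the shape of loaded[b'data'] would be (10000, 3072).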


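In the dictionary, each image is stored in flattened form: a 32x32 3-channel image becomes a row of 3072 values of type uint8. The first 1024 values are the red channel, the next 1024 the green channel, and the last 1024 the blue channel. Within each channel the pixels are stored row by row.

The following function extracts the data of each channel, rearranges it, and returns a 32x32 3-channel image: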

def GetPhoto(pixel):
    assert len(pixel) == 3072
    # Slice out each colour channel, then reshape it to 32x32x1
    r = pixel[0:1024]; r = np.reshape(r, [32, 32, 1])
    g = pixel[1024:2048]; g = np.reshape(g, [32, 32, 1])
    b = pixel[2048:3072]; b = np.reshape(b, [32, 32, 1])

    # Stack the three channels along the last axis: 32x32x3
    photo = np.concatenate([r, g, b], -1)

    return photo
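The same rearrangement can be applied to a whole batch at once with a reshape followed by a transpose, which avoids a Python loop over the images. This is a sketch of an alternative, not the article's method; `rows_to_images` is a hypothetical name and the input here is random synthetic data.

```python
import numpy as np

def rows_to_images(data):
    """Convert an (N, 3072) array of flattened images to (N, 32, 32, 3)."""
    data = np.asarray(data)
    assert data.ndim == 2 and data.shape[1] == 3072
    # (N, 3072) -> (N, channel, row, column) -> (N, row, column, channel)
    return data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

# Random synthetic rows standing in for a batch's b'data' array.
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(4, 3072), dtype=np.uint8)

images = rows_to_images(fake)
print(images.shape, images.dtype)  # (4, 32, 32, 3) uint8
```

The red channel of the first image, images[0, :, :, 0], equals the first 1024 values of the first row reshaped to 32x32, which is what the per-channel slicing produces as well.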


The following function extracts data from the training batches:

# Return the training data stored under the given key
def GetTrainDataByLabel(label):
    batch_label = []
    labels = []
    data = []
    filenames = []
    for i in range(1, 1+5):
        # Load each batch file once and reuse the dictionary
        batch = unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)
        batch_label.append(batch[b'batch_label'])
        labels += batch[b'labels']
        data.append(batch[b'data'])
        filenames += batch[b'filenames']

    data = np.concatenate(data, 0)

    label = str.encode(label)
    if label == b'data':
        # Rebuild every flattened row into a 32x32x3 image
        array = np.ndarray([len(data), 32, 32, 3], dtype=np.int32)
        for i in range(len(data)):
            array[i] = GetPhoto(data[i])
        return array
    elif label == b'labels':
        return labels
    elif label == b'batch_label':
        return batch_label
    elif label == b'filenames':
        return filenames
    else:
        raise NameError("unknown key: %r" % label)


The following function extracts data from the test batch:

def GetTestDataByLabel(label):
    # Load the test batch once and read every field from the same dictionary
    batch = unpickle("cifar-10-python/cifar-10-batches-py/test_batch")
    batch_label = [batch[b'batch_label']]
    labels = batch[b'labels']
    data = batch[b'data']
    filenames = batch[b'filenames']

    label = str.encode(label)
    if label == b'data':
        # Rebuild every flattened row into a 32x32x3 image
        array = np.ndarray([len(data), 32, 32, 3], dtype=np.int32)
        for i in range(len(data)):
            array[i] = GetPhoto(data[i])
        return array
    elif label == b'labels':
        return labels
    elif label == b'batch_label':
        return batch_label
    elif label == b'filenames':
        return filenames
    else:
        raise NameError("unknown key: %r" % label)
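The two functions above hard-code the dataset path and return one field per call. As a possible variation, the sketch below takes the folder and the batch file names as arguments and returns images and labels together. `load_batches` and `write_fake_batch` are hypothetical helpers that do not appear in the original code, and the demo runs on small synthetic batch files rather than the real dataset.

```python
import os
import pickle
import tempfile

import numpy as np

def load_batches(root, names):
    """Hypothetical helper: load the named batch files found under root.

    Returns (images, labels); images has shape (N, 32, 32, 3) and dtype uint8.
    """
    data, labels = [], []
    for name in names:
        with open(os.path.join(root, name), 'rb') as fo:
            batch = pickle.load(fo, encoding='bytes')
        data.append(batch[b'data'])
        labels += list(batch[b'labels'])
    data = np.concatenate(data, 0)
    # Each row is R, G, B planes of 1024 values; move the channel axis last.
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    return images, labels

def write_fake_batch(path, n, label):
    """Write a synthetic batch of n all-zero images (for this demo only)."""
    batch = {b'data': np.zeros((n, 3072), dtype=np.uint8), b'labels': [label] * n}
    with open(path, 'wb') as fo:
        pickle.dump(batch, fo)

with tempfile.TemporaryDirectory() as root:
    write_fake_batch(os.path.join(root, 'data_batch_1'), 2, 0)
    write_fake_batch(os.path.join(root, 'data_batch_2'), 3, 1)
    images, labels = load_batches(root, ['data_batch_1', 'data_batch_2'])

print(images.shape, labels)  # (5, 32, 32, 3) [0, 0, 1, 1, 1]
```

With the real dataset, root would be the extracted cifar-10-batches-py folder and names would be the batch file names listed earlier.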


Put all of the code above in a single file.

The complete cifar_data_load.py file is as follows:

import numpy as np
import pickle

def unpickle(file):
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch

def GetPhoto(pixel):
    assert len(pixel) == 3072
    r = pixel[0:1024]; r = np.reshape(r, [32, 32, 1])
    g = pixel[1024:2048]; g = np.reshape(g, [32, 32, 1])
    b = pixel[2048:3072]; b = np.reshape(b, [32, 32, 1])

    photo = np.concatenate([r, g, b], -1)

    return photo

def GetTrainDataByLabel(label):
    batch_label = []
    labels = []
    data = []
    filenames = []
    for i in range(1, 1+5):
        batch = unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)
        batch_label.append(batch[b'batch_label'])
        labels += batch[b'labels']
        data.append(batch[b'data'])
        filenames += batch[b'filenames']

    data = np.concatenate(data, 0)

    label = str.encode(label)
    if label == b'data':
        array = np.ndarray([len(data), 32, 32, 3], dtype=np.int32)
        for i in range(len(data)):
            array[i] = GetPhoto(data[i])
        return array
    elif label == b'labels':
        return labels
    elif label == b'batch_label':
        return batch_label
    elif label == b'filenames':
        return filenames
    else:
        raise NameError("unknown key: %r" % label)

def GetTestDataByLabel(label):
    batch = unpickle("cifar-10-python/cifar-10-batches-py/test_batch")
    batch_label = [batch[b'batch_label']]
    labels = batch[b'labels']
    data = batch[b'data']
    filenames = batch[b'filenames']

    label = str.encode(label)
    if label == b'data':
        array = np.ndarray([len(data), 32, 32, 3], dtype=np.int32)
        for i in range(len(data)):
            array[i] = GetPhoto(data[i])
        return array
    elif label == b'labels':
        return labels
    elif label == b'batch_label':
        return batch_label
    elif label == b'filenames':
        return filenames
    else:
        raise NameError("unknown key: %r" % label)


==========================================

Update, 19 December 2017

Added a filelist parameter that specifies which data files to use.

import numpy as np
import pickle
import cv2

def unpickle(file):
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch

def GetPhoto(pixel):
    assert len(pixel) == 3072
    r = pixel[0:1024]; r = np.reshape(r, [32, 32, 1])
    g = pixel[1024:2048]; g = np.reshape(g, [32, 32, 1])
    b = pixel[2048:3072]; b = np.reshape(b, [32, 32, 1])

    photo = np.concatenate([r, g, b], -1)

    return photo

def getTrainDataByKeyword(keyword, size=(32, 32), normalized=False, filelist=None):
    '''
    :param keyword: 'data', 'labels', 'batch_label' or 'filenames'; the item to return
    :param size: when keyword is 'data', the (height, width) of the returned images
    :param normalized: when keyword is 'data', whether to scale the pixel values to [0, 1]
    :param filelist: a list of the batch files to use; only 1, 2, 3, 4 and 5 are valid, other numbers are ignored
    :return: the requested data: pixel data for 'data', labels for 'labels', batch names for 'batch_label', image file names for 'filenames'
    '''

    keyword = str.encode(keyword)

    assert keyword in [b'data', b'labels', b'batch_label', b'filenames']
    assert type(filelist) is list and len(filelist) != 0
    assert type(normalized) is bool
    assert type(size) is tuple

    # Keep only valid batch numbers, without duplicates
    files = []
    for i in filelist:
        if 1 <= i <= 5 and i not in files:
            files.append(i)

    if len(files) == 0:
        raise ValueError("No valid input files!")

    if keyword == b'data':
        data = []
        for i in files:
            data.append(unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)[b'data'])
        data = np.concatenate(data, 0)
        array = np.ndarray([len(data), size[0], size[1], 3], dtype=np.float32)
        for i in range(len(data)):
            # cv2.resize takes the target size as (width, height)
            array[i] = cv2.resize(GetPhoto(data[i]), (size[1], size[0]))
        if normalized:
            array /= 255
        return array
    elif keyword == b'labels':
        labels = []
        for i in files:
            labels += unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)[b'labels']
        return labels
    elif keyword == b'batch_label':
        batch_label = []
        for i in files:
            batch_label.append(unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)[b'batch_label'])
        return batch_label
    elif keyword == b'filenames':
        filenames = []
        for i in files:
            filenames += unpickle("cifar-10-python/cifar-10-batches-py/data_batch_%d" % i)[b'filenames']
        return filenames

def getTestDataByKeyword(keyword, size=(32, 32), normalized=False):
    '''
    :param keyword: 'data', 'labels', 'batch_label' or 'filenames'; the item to return
    :param size: when keyword is 'data', the (height, width) of the returned images
    :param normalized: when keyword is 'data', whether to scale the pixel values to [0, 1]
    :return: the requested data: pixel data for 'data', labels for 'labels', batch names for 'batch_label', image file names for 'filenames'
    '''
    # Encode once; calling str.encode again on the resulting bytes would raise a TypeError
    keyword = str.encode(keyword)

    assert keyword in [b'data', b'labels', b'batch_label', b'filenames']
    assert type(size) is tuple
    assert type(normalized) is bool

    # Load the test batch once and read every field from the same dictionary
    batch = unpickle("cifar-10-python/cifar-10-batches-py/test_batch")

    if keyword == b'data':
        data = batch[b'data']
        array = np.ndarray([len(data), size[0], size[1], 3], dtype=np.float32)
        for i in range(len(data)):
            # cv2.resize takes the target size as (width, height)
            array[i] = cv2.resize(GetPhoto(data[i]), (size[1], size[0]))
        if normalized:
            array /= 255
        return array
    elif keyword == b'labels':
        return batch[b'labels']
    elif keyword == b'batch_label':
        return [batch[b'batch_label']]
    elif keyword == b'filenames':
        return batch[b'filenames']
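When the images are kept at their original 32x32 size, the normalization step does not need OpenCV at all. The sketch below shows that case with plain NumPy; `normalize_images` is a hypothetical helper and the input is a single hand-made pixel, not real CIFAR-10 data.

```python
import numpy as np

def normalize_images(images):
    """Scale uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    images = np.asarray(images)
    return images.astype(np.float32) / 255.0

# One synthetic 1x1 RGB image with the extreme and a middle pixel value.
fake = np.array([[[[0, 128, 255]]]], dtype=np.uint8)

out = normalize_images(fake)
print(out.dtype, out.min(), out.max())  # float32 0.0 1.0
```

Converting to float32 before dividing keeps the result in float32, which matches the dtype of the arrays returned by the functions above.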