用python读取cifar-10与cifar-100图像数据
2017-08-08 08:43
483 查看
很多机器学习的公开数据集都需要自己编写代码来读取。当然,自己写代码读取数据是机器学习应用的基本能力;这里给出现成的读取代码,方便大家开发,避免重复发明轮子。
关于cifar数据集的介绍,可以点击这里查看;因为其下载比较慢,也可以用csdn的下载地址下载cifar-10:cifar-10 csdn地址
下载后将其解压,如路径为: /xxx/cifar-10-batches-py/
代码很简单没有写注释,读取代码如下:
[python]
view plain
copy
import cPickle
import numpy as np
import os
class Cifar10DataReader(object):
    """Mini-batch reader for the python version of the CIFAR-10 dataset.

    Training samples are read lazily from ``data_batch_1`` ... ``data_batch_5``
    inside ``cifar_folder``, one file at a time. Each file is shuffled once
    when it is loaded and then handed out in consecutive mini-batches; after
    the fifth file the reader wraps around to the first one.

    Args:
        cifar_folder: directory that holds the extracted batch files.
        onehot: if True, labels are returned as length-10 one-hot vectors,
            otherwise as plain ints.
    """

    _NUM_CLASSES = 10
    _NUM_TRAIN_FILES = 5

    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1  # number N of the next data_batch_N file to load
        self.read_next = True  # True when a new training file must be loaded
        self.data_label_train = None  # shuffled (pixels, label) pairs, current file
        self.data_label_test = None  # shuffled (pixels, label) pairs of test_batch
        self.batch_index = 0  # next mini-batch position inside the current file

    def unpickle(self, f):
        """Load and return the object pickled in the file at path ``f``.

        NOTE: unpickling can execute arbitrary code, so only point this at
        dataset files obtained from a trusted source.
        """
        try:
            import cPickle as pickle  # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle  # Python 3
            # latin1 lets Python 3 read pickles written by Python 2
            # (str keys and numpy arrays).
            load_kwargs = {"encoding": "latin1"}
        # The context manager closes the file even when loading fails.
        with open(f, "rb") as fo:
            return pickle.load(fo, **load_kwargs)

    def _load_next_train_file(self):
        """Load and shuffle the ``data_batch_N`` file selected by ``data_index``."""
        f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
        print('read: %s' % f)
        dic_train = self.unpickle(f)
        # list() keeps the pairs shuffleable and sliceable on Python 3, where
        # zip() returns a one-shot iterator. Labels are 0~9.
        self.data_label_train = list(zip(dic_train['data'], dic_train['labels']))
        np.random.shuffle(self.data_label_train)
        self.read_next = False
        # Advance to the following file, wrapping around after the last one.
        if self.data_index == self._NUM_TRAIN_FILES:
            self.data_index = 1
        else:
            self.data_index += 1

    def next_train_data(self, batch_size=100):
        """Return the next training mini-batch as ``(images, labels)``.

        Args:
            batch_size: number of samples per batch; must divide 10000.

        Returns:
            A tuple of two lists: 32x32x3 image arrays and their labels
            (one-hot vectors or ints, depending on ``self.onehot``).

        Raises:
            AssertionError: if 10000 is not a multiple of ``batch_size``.
            ValueError: if a loaded file holds fewer than ``batch_size``
                samples (previously this recursed without end).
        """
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        if not self.read_next and (
                self.batch_index >= len(self.data_label_train) // batch_size):
            # The current file is used up: restart on the next file.
            self.batch_index = 0
            self.read_next = True
        if self.read_next:
            self._load_next_train_file()
            if len(self.data_label_train) < batch_size:
                raise ValueError(
                    "training file holds %d samples, fewer than batch_size=%d"
                    % (len(self.data_label_train), batch_size))
        start = self.batch_index * batch_size
        datum = self.data_label_train[start:start + batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def _decode(self, datum, onehot):
        """Convert ``(pixels, label)`` pairs into image and label lists.

        Each ``pixels`` entry is reshaped to ``[3, 1024]``, transposed and
        reshaped again to a ``[32, 32, 3]`` array.
        """
        rdata = []
        rlabel = []
        for d, l in datum:
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(self._NUM_CLASSES)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return a test mini-batch as ``(images, labels)``.

        ``test_batch`` is loaded and shuffled on the first call only.
        NOTE: every call returns the first ``batch_size`` samples of that
        shuffled list; the reader does not advance through the test set.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            # Labels are 0~9.
            self.data_label_test = list(zip(dic_test['data'], dic_test['labels']))
            np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]
        return self._decode(datum, self.onehot)
if __name__ == "__main__":
    # Demo: show one test image, then pull 600 training batches and print
    # the shape of each batch. The folder path is a placeholder.
    dr = Cifar10DataReader(cifar_folder="/xxx/cifar-10-batches-py/")
    import matplotlib.pyplot as plt

    images, labels = dr.next_test_data()
    # Formatting through '%s %s' prints the same text on Python 2 and 3.
    print('%s %s' % (np.shape(images), np.shape(labels)))
    plt.imshow(images[0])
    plt.show()
    for i in range(600):
        images, labels = dr.next_train_data(batch_size=100)
        print('%s %s' % (np.shape(images), np.shape(labels)))
cifar-100的数据读取代码如下(测试代码和cifar-10的一样,就不再写了;数据中还有coarse_labels,即大类别标签,需要的话可以自己添加):
[python]
view plain
copy
import cPickle
import numpy as np
import os
class Cifar100DataReader(object):
    """Mini-batch reader for the python version of the CIFAR-100 dataset.

    The ``train`` file inside ``cifar_folder`` is loaded and shuffled when
    the reader is constructed; the ``test`` file is loaded on first use.
    Only ``fine_labels`` (0~99) are used as labels.

    Args:
        cifar_folder: directory that holds the ``train`` and ``test`` files.
        onehot: if True, labels are returned as length-100 one-hot vectors,
            otherwise as plain ints.
    """

    _NUM_CLASSES = 100

    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_label_train = None  # shuffled (pixels, label) pairs of train
        self.data_label_test = None  # shuffled (pixels, label) pairs of test
        self.batch_index = 0  # next mini-batch position in the training list
        f = os.path.join(self.cifar_folder, "train")
        print('read: %s' % f)
        dic_train = self.unpickle(f)
        # list() keeps the pairs shuffleable and sliceable on Python 3, where
        # zip() returns a one-shot iterator. Labels are 0~99.
        self.data_label_train = list(
            zip(dic_train['data'], dic_train['fine_labels']))
        np.random.shuffle(self.data_label_train)

    def unpickle(self, f):
        """Load and return the object pickled in the file at path ``f``.

        NOTE: unpickling can execute arbitrary code, so only point this at
        dataset files obtained from a trusted source.
        """
        try:
            import cPickle as pickle  # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle  # Python 3
            # latin1 lets Python 3 read pickles written by Python 2
            # (str keys and numpy arrays).
            load_kwargs = {"encoding": "latin1"}
        # The context manager closes the file even when loading fails.
        with open(f, "rb") as fo:
            return pickle.load(fo, **load_kwargs)

    def next_train_data(self, batch_size=100):
        """Return the next training mini-batch as ``(images, labels)``.

        cifar100 data content:
        {
            "coarse_labels": [0,...,19],  # 0~19 super category
            "filenames": ["volcano_s_000012.png",...],
            "batch_label": "",
            "fine_labels": [0,1...99]  # 0~99 category
        }

        Once all full batches have been handed out, the training list is
        reshuffled and reading restarts from its beginning.

        Returns:
            A tuple of two lists with ``batch_size`` entries each: 32x32x3
            image arrays and their labels (one-hot vectors or ints,
            depending on ``self.onehot``).
        """
        if self.batch_index >= len(self.data_label_train) // batch_size:
            # Epoch finished: reshuffle and start again from the first batch.
            self.batch_index = 0
            np.random.shuffle(self.data_label_train)
        start = self.batch_index * batch_size
        datum = self.data_label_train[start:start + batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def _decode(self, datum, onehot):
        """Convert ``(pixels, label)`` pairs into image and label lists.

        Each ``pixels`` entry is reshaped to ``[3, 1024]``, transposed and
        reshaped again to a ``[32, 32, 3]`` array.
        """
        rdata = []
        rlabel = []
        for d, l in datum:
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(self._NUM_CLASSES)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return a test mini-batch as ``(images, labels)``.

        The ``test`` file is loaded and shuffled on the first call only.
        NOTE: every call returns the first ``batch_size`` samples of that
        shuffled list; the reader does not advance through the test set.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            # Labels are 0~99.
            self.data_label_test = list(
                zip(dic_test['data'], dic_test['fine_labels']))
            np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]
        return self._decode(datum, self.onehot)
相关文章推荐
- 用python读取cifar-10与cifar-100图像数据
- Numpy学习(2):将cifar10/100数据文件读入到python数据结构(字典)中
- 如何使用GIST+LIBLINEAR分类器提取CIFAR-10 dataset数据集中图像特征,并用测试数据进行实验
- python使用h5py读取mat文件数据,并保存图像
- opencv-python图像数据的读取
- Python读入CIFAR-10数据库
- CIFAR-10模型训练python版cifar10数据集
- PYTHON读取三维点云球坐标数据并动态生成三维图像与着色
- CIFAR-10和CIFAR-100数据集读取显示
- 【python图像处理】txt文件数据的读取与写入
- TensorFlow学习——CIFAR-10(python实现数据可视化)
- 清新脱俗的TensorFlow CIFAR10例程的代码重构——更简明更快的数据读取、loss accuracy实时输出
- Python3读取深度学习CIFAR-10数据集出现的若干问题解决
- 如何使用GIST+LIBLINEAR分类器提取CIFAR-10 dataset数据集中图像特征,并用测试数据进行实验
- Python实现从excel读取数据并绘制成精美图像
- 深度学习,制作类似cifar10图像数据集
- python pandas 读取.txt .dat 文件时,跳读头文件,并把数据读成数组
- 读取BMP图像每一像素点RGB数据
- python opencv —— io(帧、图像、视频的读取与保存)
- python读取外部数据之excel数据获取及参数说明