您的位置:首页 > 编程语言 > Python开发

用python读取cifar-10与cifar-100图像数据

2017-08-08 08:43 483 查看


版权声明:本文为博主原创文章,未经博主允许不得转载。

很多机器学习的公开数据集都需要手工编写代码来读取。当然,自己写代码读取数据是机器学习应用的基本能力;这里给出现成的读取代码,方便大家开发,避免重复发明轮子。

关于cifar数据集,点击这里,因为其下载比较慢,所以可以用csdn的下载地址下载cifar-10,cifar-10 csdn地址

下载后将其解压,如路径为: /xxx/cifar-10-batches-py/

代码很简单没有写注释,读取代码如下:

[python]
view plain
copy

import cPickle  
import numpy as np  
import os  
  
class Cifar10DataReader():
    """Mini-batch reader for the python version of the CIFAR-10 dataset.

    Training data is read one ``data_batch_N`` file at a time (N cycles
    through 1..5). Each file is shuffled once when it is loaded and then
    served in consecutive, non-overlapping slices. Test data is read
    lazily from ``test_batch`` on the first call to ``next_test_data``.

    Every image is returned as a ``(32, 32, 3)`` numpy array. Labels are
    returned either as plain ints (0~9) or as one-hot arrays of length 10,
    depending on ``onehot``.
    """

    def __init__(self, cifar_folder, onehot=True):
        """Store the configuration; no file is read until data is requested.

        Args:
            cifar_folder: directory containing ``data_batch_1`` ..
                ``data_batch_5`` and ``test_batch``.
            onehot: if true, labels are returned as one-hot arrays of
                length 10; otherwise as ints.
        """
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1            # number of the next training file to read (1..5)
        self.read_next = True          # True when a training file must be loaded first
        self.data_label_train = None   # list of (row, label) pairs of the current file
        self.data_label_test = None    # list of (row, label) pairs of the test file
        self.batch_index = 0           # index of the next slice inside the current file

    def unpickle(self, f):
        """Load and return the pickled object stored in file ``f``.

        NOTE: pickle can execute arbitrary code while loading; only use
        this on dataset files obtained from a trusted source.
        """
        try:
            import cPickle as pickle   # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle              # Python 3
            # 'latin1' lets Python 3 read pickles written by Python 2
            # (string keys and numpy array buffers).
            load_kwargs = {'encoding': 'latin1'}
        # The context manager closes the file even when loading fails.
        with open(f, 'rb') as fo:
            return pickle.load(fo, **load_kwargs)

    def _load_next_train_file(self):
        """Read and shuffle training file ``data_index``, then advance the index."""
        f = os.path.join(self.cifar_folder, "data_batch_%s" % (self.data_index))
        print('read: %s' % f)
        dic_train = self.unpickle(f)
        # list(...) so the pairs can be shuffled, measured and sliced on
        # Python 3 as well, where zip() is lazy.
        self.data_label_train = list(zip(dic_train['data'], dic_train['labels']))  # label 0~9
        np.random.shuffle(self.data_label_train)

        self.read_next = False
        # Wrap around after the fifth file.
        self.data_index = 1 if self.data_index == 5 else self.data_index + 1

    def next_train_data(self, batch_size=100):
        """Return the next training mini-batch as ``(images, labels)``.

        When the current file has been fully consumed, the next file is
        loaded (and shuffled) transparently.

        Args:
            batch_size: number of samples per batch; must divide 10000.

        Returns:
            A tuple ``(rdata, rlabel)`` of two lists, see ``_decode``.

        Raises:
            AssertionError: if ``batch_size`` does not divide 10000.
        """
        assert 10000 % batch_size == 0, "10000%batch_size!=0"
        if self.read_next:
            self._load_next_train_file()

        if self.batch_index >= len(self.data_label_train) // batch_size:
            # Current file exhausted: start over with the next file.
            self.batch_index = 0
            self._load_next_train_file()

        start = self.batch_index * batch_size
        datum = self.data_label_train[start:start + batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def _decode(self, datum, onehot):
        """Convert ``(row, label)`` pairs into image arrays and labels.

        Args:
            datum: iterable of ``(row, label)`` pairs, where ``row`` holds
                3072 values.
            onehot: if true, encode each label as a one-hot array of
                length 10; otherwise as an int.

        Returns:
            A tuple ``(rdata, rlabel)``: a list of ``(32, 32, 3)`` arrays
            and a list of labels.
        """
        rdata = list()
        rlabel = list()
        for d, l in datum:
            # The row is three consecutive 1024-value channel planes;
            # transposing to (1024, 3) makes the channel the last axis
            # before the pixels are folded into 32x32.
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return ``batch_size`` randomly chosen test samples.

        The test file is read on the first call only. The whole test set
        is reshuffled on every call and its first ``batch_size`` entries
        are returned, so successive batches may overlap.

        Returns:
            A tuple ``(rdata, rlabel)`` of two lists, see ``_decode``.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            data = dic_test['data']
            labels = dic_test['labels']  # 0~9
            self.data_label_test = list(zip(data, labels))

        np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]

        return self._decode(datum, self.onehot)
  
if __name__ == "__main__":
    # Demo: read one test batch, display its first image, then pull 600
    # training batches (the reader cycles through the five training files).
    dr = Cifar10DataReader(cifar_folder="/xxx/cifar-10-batches-py/")
    # Imported here so that matplotlib is only needed for the demo.
    import matplotlib.pyplot as plt
    d, l = dr.next_test_data()
    # A single formatted string prints the same text on Python 2 and 3.
    print('%s %s' % (np.shape(d), np.shape(l)))
    plt.imshow(d[0])
    plt.show()
    for _ in range(600):
        d, l = dr.next_train_data(batch_size=100)
        print('%s %s' % (np.shape(d), np.shape(l)))
   

cifar-100的数据读取(测试代码,即 __main__ 部分,和cifar-10一样,这里就不写了;数据中还有coarse_labels,即大类别标签,需要的话可以自己添加)

[python]
view plain
copy

import cPickle  
import numpy as np  
import os  
  
class Cifar100DataReader():
    """Mini-batch reader for the python version of the CIFAR-100 dataset.

    The whole ``train`` file is read and shuffled when the reader is
    constructed; the ``test`` file is read lazily on the first call to
    ``next_test_data``.

    cifar100 data content (besides ``"data"``):
        {
        "coarse_labels":[0,...,19],#0~19 super category
        "filenames":["volcano_s_000012.png",...],
        "batch_label":"",
        "fine_labels":[0,1...99]#0~99 category
        }
    Only ``"data"`` and ``"fine_labels"`` are used by this reader.

    Every image is returned as a ``(32, 32, 3)`` numpy array. Labels are
    returned either as plain ints (0~99) or as one-hot arrays of length
    100, depending on ``onehot``.
    """

    def __init__(self, cifar_folder, onehot=True):
        """Read and shuffle the training file found in ``cifar_folder``.

        Args:
            cifar_folder: directory containing the ``train`` and ``test``
                files.
            onehot: if true, labels are returned as one-hot arrays of
                length 100; otherwise as ints.
        """
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_label_train = None   # list of (row, fine_label) pairs
        self.data_label_test = None    # list of (row, fine_label) pairs, loaded lazily
        self.batch_index = 0           # index of the next training slice
        f = os.path.join(self.cifar_folder, "train")
        print('read: %s' % f)
        # Use the method: no module-level ``unpickle`` function exists.
        dic_train = self.unpickle(f)
        # list(...) so the pairs can be shuffled, measured and sliced on
        # Python 3 as well, where zip() is lazy.
        self.data_label_train = list(zip(dic_train['data'], dic_train['fine_labels']))  # label 0~99
        np.random.shuffle(self.data_label_train)

    def unpickle(self, f):
        """Load and return the pickled object stored in file ``f``.

        NOTE: pickle can execute arbitrary code while loading; only use
        this on dataset files obtained from a trusted source.
        """
        try:
            import cPickle as pickle   # Python 2
            load_kwargs = {}
        except ImportError:
            import pickle              # Python 3
            # 'latin1' lets Python 3 read pickles written by Python 2
            # (string keys and numpy array buffers).
            load_kwargs = {'encoding': 'latin1'}
        # The context manager closes the file even when loading fails.
        with open(f, 'rb') as fo:
            return pickle.load(fo, **load_kwargs)

    def next_train_data(self, batch_size=100):
        """Return the next training mini-batch as ``(images, labels)``.

        Batches are consecutive slices of the shuffled training set. Once
        all full slices have been served, the set is reshuffled and
        serving restarts from the beginning.

        Returns:
            A tuple ``(rdata, rlabel)`` of two lists, see ``_decode``.
        """
        if self.batch_index >= len(self.data_label_train) // batch_size:
            # Epoch finished: reshuffle and restart from the first slice.
            self.batch_index = 0
            np.random.shuffle(self.data_label_train)

        start = self.batch_index * batch_size
        datum = self.data_label_train[start:start + batch_size]
        self.batch_index += 1
        return self._decode(datum, self.onehot)

    def _decode(self, datum, onehot):
        """Convert ``(row, label)`` pairs into image arrays and labels.

        Args:
            datum: iterable of ``(row, label)`` pairs, where ``row`` holds
                3072 values.
            onehot: if true, encode each label as a one-hot array of
                length 100; otherwise as an int.

        Returns:
            A tuple ``(rdata, rlabel)``: a list of ``(32, 32, 3)`` arrays
            and a list of labels.
        """
        rdata = list()
        rlabel = list()
        for d, l in datum:
            # The row is three consecutive 1024-value channel planes;
            # transposing to (1024, 3) makes the channel the last axis
            # before the pixels are folded into 32x32.
            rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if onehot:
                hot = np.zeros(100)
                hot[int(l)] = 1
                rlabel.append(hot)
            else:
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        """Return ``batch_size`` randomly chosen test samples.

        The test file is read on the first call only. The whole test set
        is reshuffled on every call and its first ``batch_size`` entries
        are returned, so successive batches may overlap.

        Returns:
            A tuple ``(rdata, rlabel)`` of two lists, see ``_decode``.
        """
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            data = dic_test['data']
            labels = dic_test['fine_labels']  # 0~99
            self.data_label_test = list(zip(data, labels))

        np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]

        return self._decode(datum, self.onehot)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: