theano-xnor-net代码注释9 pylearn2/cifar10.py
2017-06-20 10:48
316 查看
"""
CIFAR-10 dataset wrapper for pylearn2.

Loads the Python-pickle release of CIFAR-10 (``data_batch_1`` ..
``data_batch_5`` and ``test_batch``) from
``${PYLEARN2_DATA_PATH}/cifar10/cifar-10-batches-py`` and exposes the
requested split as a ``DenseDesignMatrix``.
"""
import os
import logging

import numpy
from theano.compat.six.moves import xrange
from pylearn2.datasets import cache, dense_design_matrix
from pylearn2.expr.preprocessing import global_contrast_normalize
from pylearn2.utils import contains_nan
from pylearn2.utils import serial
from pylearn2.utils import string_utils

_logger = logging.getLogger(__name__)


class CIFAR10(dense_design_matrix.DenseDesignMatrix):
    """
    The CIFAR-10 image classification dataset (32x32 RGB images, 10 classes).

    Parameters
    ----------
    which_set : str
        One of 'train' (50000 examples) or 'test' (10000 examples).
    center : bool, optional
        If True, subtract 127.5 from every pixel value.
    rescale : bool, optional
        If True, divide every pixel value by 127.5. This is applied after
        ``center``; using both maps [0, 255] onto [-1, 1].
    gcn : float, optional
        Multiplicative constant to use for global contrast normalization.
        No global contrast normalization is applied if None.
    start : int, optional
        Index of the first example (inclusive) to keep.
    stop : int, optional
        Index of the last example (exclusive) to keep. If either ``start``
        or ``stop`` is given, rows ``[start, stop)`` are kept; a missing
        ``start`` defaults to 0 and a missing ``stop`` to the number of
        examples. Slicing happens after all preprocessing, so it does not
        change the means used by ``toronto_prepro``.
    axes : tuple, optional
        Axis order of the topological view, passed to
        ``DefaultViewConverter``.
    toronto_prepro : bool, optional
        If True, divide pixels by 255 and subtract the per-feature mean of
        the (equally scaled) full training set. Must not be combined with
        ``center`` or ``gcn``.
    preprocessor : object, optional
        If given, ``preprocessor.apply(self)`` is called at the end of
        construction.
    """

    def __init__(self, which_set, center=False, rescale=False, gcn=None,
                 start=None, stop=None, axes=('b', 0, 1, 'c'),
                 toronto_prepro=False, preprocessor=None):
        # note: there is no such thing as the cifar10 validation set;
        # pylearn1 defined one but really it should be user-configurable
        # (as it is here, through start / stop)

        self.axes = axes

        # we define here:
        dtype = 'uint8'
        ntrain = 50000
        nvalid = 0  # artefact, we won't use it
        ntest = 10000
        # Number of examples stored in each CIFAR-10 batch file.
        batch_size = 10000

        # we also expose the following details:
        self.img_shape = (3, 32, 32)
        # Number of values per image: 3 * 32 * 32 = 3072.
        self.img_size = numpy.prod(self.img_shape)
        self.n_classes = 10
        # Names of labels 0..9, in label order.
        self.label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                            'dog', 'frog', 'horse', 'ship', 'truck']

        # prepare loading: five training batch files plus one test file
        fnames = ['data_batch_%i' % i for i in range(1, 6)]
        datasets = {}
        # string_utils.preprocess presumably expands ${PYLEARN2_DATA_PATH}
        # from the environment.
        datapath = os.path.join(
            string_utils.preprocess('${PYLEARN2_DATA_PATH}'),
            'cifar10',
            'cifar-10-batches-py')
        for name in fnames + ['test_batch']:
            fname = os.path.join(datapath, name)
            if not os.path.exists(fname):
                raise IOError(fname + " was not found. You probably need to "
                              "download the CIFAR-10 dataset by using the "
                              "download script in "
                              "pylearn2/scripts/datasets/download_cifar10.sh "
                              "or manually from "
                              "http://www.cs.utoronto.ca/~kriz/cifar.html")
            # cache_file returns the path to read from; presumably a local
            # cached copy of fname -- see pylearn2.datasets.cache.
            datasets[name] = cache.datasetCache.cache_file(fname)

        # Training rows, rounded up to a whole number of batch files (50000).
        lenx = int(numpy.ceil((ntrain + nvalid) / float(batch_size)) *
                   batch_size)
        x = numpy.zeros((lenx, self.img_size), dtype=dtype)
        y = numpy.zeros((lenx, 1), dtype=dtype)

        # load train data: each loaded batch is indexed by 'data' (rows of
        # pixel values) and 'labels'; batches are stacked in file order.
        for i, fname in enumerate(fnames):
            _logger.info('loading file %s', datasets[fname])
            data = serial.load(datasets[fname])
            rows = slice(i * batch_size, (i + 1) * batch_size)
            x[rows, :] = data['data']
            y[rows, 0] = data['labels']

        # load test data
        _logger.info('loading file %s', datasets['test_batch'])
        data = serial.load(datasets['test_batch'])

        # process this data
        Xs = {'train': x[0:ntrain],
              'test': data['data'][0:ntest]}
        Ys = {'train': y[0:ntrain],
              'test': data['labels'][0:ntest]}

        # numpy.cast was removed in NumPy 2.0; asarray(...).astype(...) is
        # exactly what numpy.cast['float32'] did (always returns a copy).
        X = numpy.asarray(Xs[which_set]).astype('float32')
        y = Ys[which_set]

        # The test labels may come back as a plain list; convert them to an
        # array with the same dtype as the training labels.
        if isinstance(y, list):
            y = numpy.asarray(y).astype(dtype)

        if which_set == 'test':
            # Validate the full test split here, before any start/stop
            # slicing (checking afterwards rejected every test subset).
            assert X.shape[0] == 10000
            assert y.shape[0] == 10000
            y = y.reshape((y.shape[0], 1))

        if center:
            X -= 127.5
        self.center = center

        if rescale:
            X /= 127.5
        self.rescale = rescale

        if toronto_prepro:
            assert not center
            assert not gcn
            X = X / 255.
            if which_set == 'test':
                # Subtract the *training* mean so train and test are shifted
                # by the same amount.
                other = CIFAR10(which_set='train')
                oX = other.X
                oX /= 255.
                X = X - oX.mean(axis=0)
            else:
                X = X - X.mean(axis=0)
        self.toronto_prepro = toronto_prepro

        self.gcn = gcn
        if gcn is not None:
            gcn = float(gcn)
            X = global_contrast_normalize(X, scale=gcn)

        if start is not None or stop is not None:
            # This needs to come after the prepro so that it doesn't
            # change the pixel means computed above for toronto_prepro
            if start is None:
                start = 0
            if stop is None:
                stop = X.shape[0]
            assert start >= 0
            assert stop > start
            assert stop <= X.shape[0]
            X = X[start:stop, :]
            y = y[start:stop, :]
            assert X.shape[0] == y.shape[0]

        view_converter = dense_design_matrix.DefaultViewConverter((32, 32, 3),
                                                                  axes)

        super(CIFAR10, self).__init__(X=X, y=y, view_converter=view_converter,
                                      y_labels=self.n_classes)

        assert not contains_nan(self.X)

        if preprocessor:
            preprocessor.apply(self)

    def _patch_old_pickle(self):
        """
        Fill in preprocessing flags missing from objects loaded from old pkls.

        A missing flag is treated as "that preprocessing was not applied".
        ``gcn`` is set to None (not False): the viewer methods test
        ``self.gcn is not None``, so False would wrongly select the GCN path.
        """
        if not hasattr(self, 'center'):
            self.center = False
        if not hasattr(self, 'rescale'):
            self.rescale = False
        if not hasattr(self, 'gcn'):
            self.gcn = None

    def _raw_to_viewer_range(self, rval):
        """
        Apply, in place where possible, whichever of the center (-127.5) and
        rescale (/127.5) steps was not applied at load time, then clip to
        [-1, 1]. Returns the clipped array.
        """
        if not self.center:
            rval -= 127.5
        if not self.rescale:
            rval /= 127.5
        return numpy.clip(rval, -1., 1.)

    def adjust_for_viewer(self, X):
        """
        Return a copy of ``X`` scaled into [-1, 1] for display.

        Parameters
        ----------
        X : numpy.ndarray
            2-D array with one example per row.

        Returns
        -------
        numpy.ndarray
            A rescaled copy; ``X`` itself is not modified.
        """
        # assumes no preprocessing. need to make preprocessors mark the
        # new ranges
        self._patch_old_pickle()

        rval = X.copy()
        if self.gcn is not None:
            # GCN output has no fixed range: scale each row by its own
            # maximum absolute value.
            for i in xrange(rval.shape[0]):
                rval[i, :] /= numpy.abs(rval[i, :]).max()
            return rval

        return self._raw_to_viewer_range(rval)

    def __setstate__(self, state):
        """Restore from a pickle, upgrading state saved by older versions."""
        super(CIFAR10, self).__setstate__(state)
        # Patch old pkls: labels used to be stored 1-D.
        if self.y is not None and self.y.ndim == 1:
            self.y = self.y.reshape((self.y.shape[0], 1))
        if 'y_labels' not in state:
            self.y_labels = 10

    def adjust_to_be_viewed_with(self, X, orig, per_example=False):
        """
        Return a copy of ``X`` scaled into [-1, 1] using the scale of ``orig``.

        Parameters
        ----------
        X : numpy.ndarray
            2-D array with one example per row.
        orig : numpy.ndarray
            Array whose magnitude sets the scale when GCN was applied.
        per_example : bool, optional
            With GCN, divide row ``i`` of ``X`` by the max absolute value of
            row ``i`` of ``orig``; otherwise divide everything by the max
            absolute value of all of ``orig``.

        Returns
        -------
        numpy.ndarray
            A rescaled, clipped copy; ``X`` itself is not modified.
        """
        # if the scale is set based on the data, display X using the
        # scale determined by orig
        # assumes no preprocessing. need to make preprocessors mark
        # the new ranges
        self._patch_old_pickle()

        rval = X.copy()
        if self.gcn is not None:
            if per_example:
                for i in xrange(rval.shape[0]):
                    rval[i, :] /= numpy.abs(orig[i, :]).max()
            else:
                rval /= numpy.abs(orig).max()
            return numpy.clip(rval, -1., 1.)

        return self._raw_to_viewer_range(rval)

    def get_test_set(self):
        """
        Return the CIFAR-10 test split built with this object's options.

        ``start``, ``stop`` and ``preprocessor`` are not forwarded.
        """
        return CIFAR10(which_set='test', center=self.center,
                       rescale=self.rescale, gcn=self.gcn,
                       toronto_prepro=self.toronto_prepro,
                       axes=self.axes)
相关文章推荐
- theano-xnor-net代码注释6 fxp_helper.py
- theano-xnor-net代码注释3 xnor_net.py
- theano-xnor-net代码注释8 xnornet_layers.py
- theano-xnor-net代码注释7 inf_layers.py
- theano-xnor-net代码注释 cifar10_train.py
- theano-xnor-net代码注释4 bnn_utils.py
- theano-xnor-net代码注释5 cifar10_test.py
- theano-xnor-net代码注释2 cnn_utils.py
- CSLA.Net 3.0.5 版本 教学程序,代码附教学注释
- ASP.NET MVC 巧用代码注释做权限控制以及后台导航
- ASP.Net 2.0 窗体身份验证机制-转+自己代码注释示例与更详细的说明(网上转)
- ubuntu14.04安装theano的二进制网络theano-xnor-net
- CSLA.Net 3.0.5 版本 教学程序,代码附教学注释
- ASP.Net 2.0 窗体身份验证机制-转+自己代码注释示例与更详细的说明
- CSLA.Net 3.0.5 版本 教学程序,代码附教学注释
- .NET开发需要养成一种良好的注释代码习惯篇
- CSLA.Net 3.0.5 版本 教学程序,代码附教学注释
- CSLA.Net 3.0.5 版本 教学程序,代码附教学注释
- Asp.Net 读取xml文件中Key的值,并且过滤掉注释内容代码
- C#.NET中代码注释提示