您的位置:首页 > 编程语言

theano-xnor-net代码注释9 pylearn2/cifar10.py

2017-06-20 10:48 316 查看
"""
.. todo::

WRITEME
"""
import os
import logging

import numpy
from theano.compat.six.moves import xrange

from pylearn2.datasets import cache, dense_design_matrix
from pylearn2.expr.preprocessing import global_contrast_normalize
from pylearn2.utils import contains_nan
from pylearn2.utils import serial
from pylearn2.utils import string_utils

_logger = logging.getLogger(__name__)

class CIFAR10(dense_design_matrix.DenseDesignMatrix):

"""
.. todo::

WRITEME

Parameters
----------
which_set : str
One of 'train', 'test'
center : WRITEME
rescale : WRITEME
gcn : float, optional
Multiplicative constant to use for global contrast normalization.
No global contrast normalization is applied, if None
start : WRITEME
stop : WRITEME
axes : WRITEME
toronto_prepro : WRITEME
preprocessor : WRITEME
"""

def __init__(self, which_set, center=False, rescale=False, gcn=None,
start=None, stop=None, axes=('b', 0, 1, 'c'),
toronto_prepro = False, preprocessor = None):
# note: there is no such thing as the cifar10 validation set;
# pylearn1 defined one but really it should be user-configurable
# (as it is here)

self.axes = axes

# we define here:
dtype = 'uint8'
ntrain = 50000
nvalid = 0  # artefact, we won't use it
ntest = 10000

# we also expose the following details:
self.img_shape = (3, 32, 32)
#self.img_size存的是图片元素个数,3*32*32=3072
self.img_size = numpy.prod(self.img_shape)
#类别为10,0-9对应标签为label_names
self.n_classes = 10
self.label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']

# prepare loading
#fnames为一个列表,存的是data_batch1~5
fnames = ['data_batch_%i' % i for i in range(1, 6)]
#datasets为一个空字典
datasets = {}
#${PYLEARN2_DATA_PATH}已经在配置pylearn之后存在个人系统目录.bashrc中了,为/home/ubuntu/pylearn2-data/data
# datapath里存/home/ubuntu/pylearn2-data/data/cifar10/cifar-10-batches-py/
datapath = os.path.join(
string_utils.preprocess('${PYLEARN2_DATA_PATH}'),
'cifar10', 'cifar-10-batches-py')
#在data_batch1~5+test_batch六个文件中做循环
for name in fnames + ['test_batch']:
#当前fname为当前操作数据集文件的全路径
fname = os.path.join(datapath, name)
#如果文件不存在,raise一个error
if not os.path.exists(fname):
raise IOError(fname + " was not found. You probably need to "
"download the CIFAR-10 dataset by using the "
"download script in "
"pylearn2/scripts/datasets/download_cifar10.sh "
"or manually from "
"http://www.cs.utoronto.ca/~kriz/cifar.html")
#将当前数据集文件快速缓存进datasets字典中
datasets[name] = cache.datasetCache.cache_file(fname)
#lenx数值就是50000
lenx = int(numpy.ceil((ntrain + nvalid) / 10000.) * 10000)
#设置一个全0矩阵x大小50000×3072,y大小50000×1的np.array
x = numpy.zeros((lenx, self.img_size), dtype=dtype)
y = numpy.zeros((lenx, 1), dtype=dtype)

# load train data
#下载训练集
nloaded = 0
#enumerate返回的是(引索值,当前迭代对象)
for i, fname in enumerate(fnames):
#将括号内信息存入log文件
_logger.info('loading file %s' % datasets[fname])
#从刚加载好的datasets字典中取出当前操作文件数据,存入data,python版本的cifar10本身就是一个字典,
# 所以当前data就是一个字典,字典中有batch_label,labels,data,filenames四种信息
data = serial.load(datasets[fname])
#一个数据集中有10000个图片信息,对应data为10000个3072的np.array,labels对应10000个一维标签,依次取出5个对应训练数据集文件,按照顺序依次存入x与y
x[i * 10000:(i + 1) * 10000, :] = data['data']
y[i * 10000:(i + 1) * 10000, 0] = data['labels']
#以下三行代码运行不到,在迭代完5个文件时候nloaded=50000,小于60000,此时循环就已经退出
nloaded += 10000
if nloaded >= ntrain + nvalid + ntest:
break

# load test data
#加载测试集合
#将括号内信息存入log文件
_logger.info('loading file %s' % datasets['test_batch'])
#加载'test_batch'测试集数据,存入data,前面data中信息已经清空
data = serial.load(datasets['test_batch'])
#重组数据
# process this data
#Xs为一个字典,‘train’关键字中存训练集的50000条图像数据,‘test’关键字中存测试集的10000条图像数据
#Ys为一个字典,‘train’关键字中存训练集的50000个标签,‘test’关键字中存测试集的10000个标签
Xs = {'train': x[0:ntrain],
'test': data['data'][0:ntest]}

Ys = {'train': y[0:ntrain],
'test': data['labels'][0:ntest]}
#which_set为调用CIFAR10类时候传如的参数,选择是[train、test]
#即X为对应[train or test]的图像数据
#y为对应[train or test]的标签

X = numpy.cast['float32'](Xs[which_set])
y = Ys[which_set]
#在该数据集中标签的存储为一个列表list,该行代码是要将label转化为与data一样的ndarray格式
if isinstance(y, list):
y = numpy.asarray(y).astype(dtype)
#如果测试数据集标签数不为10000,重新整理为(y.shape[0], 1)大小
if which_set == 'test':
assert y.shape[0] == 10000
y = y.reshape((y.shape[0], 1))

if center:
X -= 127.5
self.center = center

if rescale:
X /= 127.5
self.rescale = rescale

if toronto_prepro:
assert not center
assert not gcn
X = X / 255.
if which_set == 'test':
other = CIFAR10(which_set='train')
oX = other.X
oX /= 255.
X = X - oX.mean(axis=0)
else:
X = X - X.mean(axis=0)
self.toronto_prepro = toronto_prepro

self.gcn = gcn
if gcn is not None:
gcn = float(gcn)
X = global_contrast_normalize(X, scale=gcn)

if start is not None:
# This needs to come after the prepro so that it doesn't
# change the pixel means computed above for toronto_prepro
assert start >= 0
assert stop > start
assert stop <= X.shape[0]
X = X[start:stop, :]
y = y[start:stop, :]
assert X.shape[0] == y.shape[0]

if which_set == 'test':
assert X.shape[0] == 10000

view_converter = dense_design_matrix.DefaultViewConverter((32, 32, 3),
axes)

super(CIFAR10, self).__init__(X=X, y=y, view_converter=view_converter,
y_labels=self.n_classes)

assert not contains_nan(self.X)

if preprocessor:
preprocessor.apply(self)

def adjust_for_viewer(self, X):
"""
.. todo::

WRITEME
"""
# assumes no preprocessing. need to make preprocessors mark the
# new ranges
rval = X.copy()

# patch old pkl files
if not hasattr(self, 'center'):
self.center = False
if not hasattr(self, 'rescale'):
self.rescale = False
if not hasattr(self, 'gcn'):
self.gcn = False

if self.gcn is not None:
rval = X.copy()
for i in xrange(rval.shape[0]):
rval[i, :] /= numpy.abs(rval[i, :]).max()
return rval

if not self.center:
rval -= 127.5

if not self.rescale:
rval /= 127.5

rval = numpy.clip(rval, -1., 1.)

return rval

def __setstate__(self, state):
super(CIFAR10, self).__setstate__(state)
# Patch old pkls
if self.y is not None and self.y.ndim == 1:
self.y = self.y.reshape((self.y.shape[0], 1))
if 'y_labels' not in state:
self.y_labels = 10

def adjust_to_be_viewed_with(self, X, orig, per_example=False):
"""
.. todo::

WRITEME
"""
# if the scale is set based on the data, display X oring the
# scale determined by orig
# assumes no preprocessing. need to make preprocessors mark
# the new ranges
rval = X.copy()

# patch old pkl files
if not hasattr(self, 'center'):
self.center = False
if not hasattr(self, 'rescale'):
self.rescale = False
if not hasattr(self, 'gcn'):
self.gcn = False

if self.gcn is not None:
rval = X.copy()
if per_example:
for i in xrange(rval.shape[0]):
rval[i, :] /= numpy.abs(orig[i, :]).max()
else:
rval /= numpy.abs(orig).max()
rval = numpy.clip(rval, -1., 1.)
return rval

if not self.center:
rval -= 127.5

if not self.rescale:
rval /= 127.5

rval = numpy.clip(rval, -1., 1.)

return rval

def get_test_set(self):
"""
.. todo::

WRITEME
"""
return CIFAR10(which_set='test', center=self.center,
rescale=self.rescale, gcn=self.gcn,
toronto_prepro=self.toronto_prepro,
axes=self.axes)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: