theano-xnor-net code annotations 9: pylearn2/cifar10.py

""".. todo::    WRITEME"""import osimport loggingimport numpyfrom theano.compat.six.moves import xrangefrom pylearn2.datasets import cache, dense_design_matrixfrom pylearn2.expr.preprocessing import global_contrast_normalizefrom pylearn2.utils import contains_nanfrom pylearn2.utils import serialfrom pylearn2.utils import string_utils_logger = logging.getLogger(__name__)class CIFAR10(dense_design_matrix.DenseDesignMatrix):    """    .. todo::        WRITEME    Parameters    ----------    which_set : str        One of 'train', 'test'    center : WRITEME    rescale : WRITEME    gcn : float, optional        Multiplicative constant to use for global contrast normalization.        No global contrast normalization is applied, if None    start : WRITEME    stop : WRITEME    axes : WRITEME    toronto_prepro : WRITEME    preprocessor : WRITEME    """    def __init__(self, which_set, center=False, rescale=False, gcn=None,                 start=None, stop=None, axes=('b', 0, 1, 'c'),                 toronto_prepro = False, preprocessor = None):        # note: there is no such thing as the cifar10 validation set;        # pylearn1 defined one but really it should be user-configurable        # (as it is here)        self.axes = axes        # we define here:        dtype = 'uint8'        ntrain = 50000        nvalid = 0  # artefact, we won't use it        ntest = 10000        # we also expose the following details:        self.img_shape = (3, 32, 32)        #self.img_size存的是图片元素个数,3*32*32=3072        self.img_size = numpy.prod(self.img_shape)        #类别为10,0-9对应标签为label_names        self.n_classes = 10        self.label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',                            'dog', 'frog', 'horse', 'ship', 'truck']        # prepare loading        #fnames为一个列表,存的是data_batch1~5        fnames = ['data_batch_%i' % i for i in range(1, 6)]        #datasets为一个空字典        datasets = {}        #${PYLEARN2_DATA_PATH}已经在配置pylearn之后存在个人系统目录.bashrc中了,为/home/ubuntu/pylearn2-data/data        # datapath里存/home/ubuntu/pylearn2-data/data/cifar10/cifar-10-batches-py/        datapath = os.path.join(            string_utils.preprocess('${PYLEARN2_DATA_PATH}'),            'cifar10', 'cifar-10-batches-py')        #在data_batch1~5+test_batch六个文件中做循环        for name in fnames + ['test_batch']:            #当前fname为当前操作数据集文件的全路径            fname = os.path.join(datapath, name)            #如果文件不存在,raise一个error            if not os.path.exists(fname):                raise IOError(fname + " was not found. You probably need to "                              "download the CIFAR-10 dataset by using the "                              "download script in "                              "pylearn2/scripts/datasets/download_cifar10.sh "                              "or manually from "                              "http://www.cs.utoronto.ca/~kriz/cifar.html")            #将当前数据集文件快速缓存进datasets字典中            datasets[name] = cache.datasetCache.cache_file(fname)        #lenx数值就是50000        lenx = int(numpy.ceil((ntrain + nvalid) / 10000.) 
* 10000)        #设置一个全0矩阵x大小50000×3072,y大小50000×1的np.array        x = numpy.zeros((lenx, self.img_size), dtype=dtype)        y = numpy.zeros((lenx, 1), dtype=dtype)        # load train data        #下载训练集        nloaded = 0        #enumerate返回的是(引索值,当前迭代对象)        for i, fname in enumerate(fnames):            #将括号内信息存入log文件            _logger.info('loading file %s' % datasets[fname])            #从刚加载好的datasets字典中取出当前操作文件数据,存入data,python版本的cifar10本身就是一个字典,            # 所以当前data就是一个字典,字典中有batch_label,labels,data,filenames四种信息            data = serial.load(datasets[fname])            #一个数据集中有10000个图片信息,对应data为10000个3072的np.array,labels对应10000个一维标签,依次取出5个对应训练数据集文件,按照顺序依次存入x与y            x[i * 10000:(i + 1) * 10000, :] = data['data']            y[i * 10000:(i + 1) * 10000, 0] = data['labels']            #以下三行代码运行不到,在迭代完5个文件时候nloaded=50000,小于60000,此时循环就已经退出            nloaded += 10000            if nloaded >= ntrain + nvalid + ntest:                break        # load test data        #加载测试集合        #将括号内信息存入log文件        _logger.info('loading file %s' % datasets['test_batch'])        #加载'test_batch'测试集数据,存入data,前面data中信息已经清空        data = serial.load(datasets['test_batch'])        #重组数据        # process this data        #Xs为一个字典,‘train’关键字中存训练集的50000条图像数据,‘test’关键字中存测试集的10000条图像数据        #Ys为一个字典,‘train’关键字中存训练集的50000个标签,‘test’关键字中存测试集的10000个标签        Xs = {'train': x[0:ntrain],              'test': data['data'][0:ntest]}        Ys = {'train': y[0:ntrain],              'test': data['labels'][0:ntest]}        #which_set为调用CIFAR10类时候传如的参数,选择是[train、test]        #即X为对应[train or test]的图像数据        #y为对应[train or test]的标签        X = numpy.cast['float32'](Xs[which_set])        y = Ys[which_set]        #在该数据集中标签的存储为一个列表list,该行代码是要将label转化为与data一样的ndarray格式        if isinstance(y, list):            y = numpy.asarray(y).astype(dtype)        #如果测试数据集标签数不为10000,重新整理为(y.shape[0], 1)大小        if which_set == 'test':            assert y.shape[0] == 10000            y = y.reshape((y.shape[0], 1))        if center:            X -= 127.5        self.center = center        if rescale:            X /= 127.5        self.rescale = rescale        if toronto_prepro:            assert not center            assert not gcn            X = X / 255.            if which_set == 'test':                other = CIFAR10(which_set='train')                oX = other.X                oX /= 255.                X = X - oX.mean(axis=0)            else:                X = X - X.mean(axis=0)        self.toronto_prepro = toronto_prepro        self.gcn = gcn        if gcn is not None:            gcn = float(gcn)            X = global_contrast_normalize(X, scale=gcn)        if start is not None:            # This needs to come after the prepro so that it doesn't            # change the pixel means computed above for toronto_prepro            assert start >= 0            assert stop > start            assert stop <= X.shape[0]            X = X[start:stop, :]            y = y[start:stop, :]            assert X.shape[0] == y.shape[0]        if which_set == 'test':            assert X.shape[0] == 10000        view_converter = dense_design_matrix.DefaultViewConverter((32, 32, 3),                                                                  axes)        super(CIFAR10, self).__init__(X=X, y=y, view_converter=view_converter,                                      y_labels=self.n_classes)        assert not contains_nan(self.X)        if preprocessor:            preprocessor.apply(self)    def adjust_for_viewer(self, X):        """        .. 
todo::            WRITEME        """        # assumes no preprocessing. need to make preprocessors mark the        # new ranges        rval = X.copy()        # patch old pkl files        if not hasattr(self, 'center'):            self.center = False        if not hasattr(self, 'rescale'):            self.rescale = False        if not hasattr(self, 'gcn'):            self.gcn = False        if self.gcn is not None:            rval = X.copy()            for i in xrange(rval.shape[0]):                rval[i, :] /= numpy.abs(rval[i, :]).max()            return rval        if not self.center:            rval -= 127.5        if not self.rescale:            rval /= 127.5        rval = numpy.clip(rval, -1., 1.)        return rval    def __setstate__(self, state):        super(CIFAR10, self).__setstate__(state)        # Patch old pkls        if self.y is not None and self.y.ndim == 1:            self.y = self.y.reshape((self.y.shape[0], 1))        if 'y_labels' not in state:            self.y_labels = 10    def adjust_to_be_viewed_with(self, X, orig, per_example=False):        """        .. todo::            WRITEME        """        # if the scale is set based on the data, display X oring the        # scale determined by orig        # assumes no preprocessing. need to make preprocessors mark        # the new ranges        rval = X.copy()        # patch old pkl files        if not hasattr(self, 'center'):            self.center = False        if not hasattr(self, 'rescale'):            self.rescale = False        if not hasattr(self, 'gcn'):            self.gcn = False        if self.gcn is not None:            rval = X.copy()            if per_example:                for i in xrange(rval.shape[0]):                    rval[i, :] /= numpy.abs(orig[i, :]).max()            else:                rval /= numpy.abs(orig).max()            rval = numpy.clip(rval, -1., 1.)            return rval        if not self.center:            rval -= 127.5        if not self.rescale:            rval /= 127.5        rval = numpy.clip(rval, -1., 1.)        return rval    def get_test_set(self):        """        .. todo::            WRITEME        """        return CIFAR10(which_set='test', center=self.center,                       rescale=self.rescale, gcn=self.gcn,                       toronto_prepro=self.toronto_prepro,                       axes=self.axes)
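For orientation, a minimal usage sketch of the class annotated above. It assumes the standard install path pylearn2.datasets.cifar10, that ${PYLEARN2_DATA_PATH} is set, and that the CIFAR-10 batches have already been fetched with download_cifar10.sh; the gcn value of 55. and the 45000/5000 train/validation split are illustrative choices, not something this file prescribes.

from pylearn2.datasets.cifar10 import CIFAR10

# first 45000 training images, with global contrast normalization
# (gcn=55. is an illustrative value)
train = CIFAR10(which_set='train', gcn=55., start=0, stop=45000,
                axes=('b', 0, 1, 'c'))

# last 5000 training images, reusing the same preprocessing, as a validation split
valid = CIFAR10(which_set='train', gcn=55., start=45000, stop=50000)

# the full 10000-image test set, built from an existing instance so that the
# center/rescale/gcn/toronto_prepro settings carry over
test = train.get_test_set()

print(train.X.shape)  # (45000, 3072)
print(train.y.shape)  # (45000, 1)
print(test.X.shape)   # (10000, 3072)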