PyTorch代码学习-dataloader

来源:互联网 发布:淘宝上怎么做网络推广 编辑:程序博客网 时间:2024/06/03 17:19
import torch
# torch.multiprocessing is a wrapper around Python's native multiprocessing module.
import torch.multiprocessing as multiprocessing
# Samplers that produce the dataset indices (or batches of indices) to load.
from .sampler import SequentialSampler, RandomSampler, BatchSampler
# collections supplies the container ABCs used to recognise mappings and
# sequences in default_collate / pin_memory_batch.
import collections
import sys
# traceback is used to format an exception's stack trace into a string.
import traceback
# threading provides the Event used for shutdown and the pin-memory thread.
import threading

try:
    from torch._six import string_classes
except ImportError:
    # torch._six was removed in newer PyTorch releases; on Python 3 the
    # string types it exposed are str and bytes.
    string_classes = (str, bytes)

try:
    # Python 3.3+: the abstract container classes live in collections.abc;
    # the aliases directly on `collections` were removed in Python 3.10.
    from collections import abc as container_abcs
except ImportError:
    # Python 2: the ABCs are attributes of `collections` itself.
    container_abcs = collections

# Select the queue module name for the running Python major version.
if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue


# Whether to use shared memory in default_collate. Set to True inside worker
# processes by _worker_loop.
_use_shared_memory = False


class ExceptionWrapper(object):
    "Wraps an exception plus traceback to communicate across threads"

    def __init__(self, exc_info):
        # exc_info is the (type, value, traceback) triple from sys.exc_info().
        self.exc_type = exc_info[0]
        # format_exception returns the trace as a list of lines; join them into
        # a single string so it can be queued and re-raised in another context.
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    """Main loop of a worker process.

    Repeatedly takes ``(idx, batch_indices)`` from ``index_queue``, loads the
    samples ``dataset[i]`` for those indices, merges them with ``collate_fn``
    and puts ``(idx, batch)`` on ``data_queue``. If loading or collating
    raises, ``(idx, ExceptionWrapper)`` is sent instead so the consumer can
    re-raise it. A ``None`` on ``index_queue`` is echoed to ``data_queue`` and
    ends the loop.
    """
    global _use_shared_memory
    # Let default_collate build its output directly in shared memory.
    _use_shared_memory = True

    # One intra-op thread per worker (presumably to avoid oversubscribing
    # cores when several workers run at once).
    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue, done_event):
    """Background-thread loop that pins batches produced by the workers.

    Reads ``(idx, batch)`` items from ``in_queue``, pins ``batch`` with
    ``pin_memory_batch`` and forwards ``(idx, pinned_batch)`` to
    ``out_queue``. Already-wrapped worker exceptions are forwarded unchanged,
    and a failure while pinning is forwarded as an ``ExceptionWrapper``.
    Stops on a ``None`` item; a failing ``get`` is swallowed only once
    ``done_event`` is set (i.e. during shutdown), otherwise it is re-raised.
    """
    while True:
        try:
            r = in_queue.get()
        except Exception:
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


# Maps a numpy dtype name to the torch tensor type used for numpy scalars.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    Args:
        batch: a non-empty list of samples, all of the same kind.

    Returns:
        - tensors: stacked along a new dimension 0;
        - numpy arrays: converted with ``torch.from_numpy`` and stacked;
        - numpy scalars: a tensor of the type given by ``numpy_type_map``;
        - ``int`` / ``float``: a ``LongTensor`` / ``DoubleTensor``;
        - strings: the list unchanged;
        - mappings: a dict collating each key recursively;
        - sequences: a list collating each position recursively.

    Raises:
        TypeError: if the element type is not handled above.
    """
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], container_abcs.Sequence):
        # Regroup [(a1, b1), (a2, b2), ...] into [(a1, a2, ...), (b1, b2, ...)].
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))


def pin_memory_batch(batch):
    """Recursively pin every tensor inside ``batch``.

    Tensors are replaced by ``tensor.pin_memory()``; strings and other
    non-container values are returned unchanged; mappings are rebuilt as
    dicts and sequences as lists, with their elements processed recursively.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        return batch
    elif isinstance(batch, container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


class DataLoaderIter(object):
    "Iterates once over the DataLoader's dataset, as specified by the sampler"

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        # Each item yielded here is one batch worth of dataset indices.
        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            # Index batches sent to workers whose results are not yet received.
            self.batches_outstanding = 0
            self.shutdown = False
            # Sequence number of the next batch to send / to return to the caller.
            self.send_idx = 0
            self.rcvd_idx = 0
            # Batches that arrived ahead of their turn, keyed by sequence number.
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                # Put a pinning thread between the workers and the consumer;
                # the iterator then reads from an in-process queue.Queue.
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop: up to two batches in flight per worker
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        return len(self.batch_sampler)

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        return self

    def _put_indices(self):
        """Send the next index batch to the workers, if the sampler has one."""
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        """Advance to the next sequence number, refill the prefetch pipeline,
        and return ``batch`` (re-raising it if it wraps a worker exception)."""
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        """Signal shutdown once: set done_event and send one None per worker."""
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        # __init__ may have raised before the worker attributes were created;
        # in that case there is nothing to shut down.
        if getattr(self, 'num_workers', 0) > 0 and hasattr(self, 'workers'):
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If False and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        return DataLoaderIter(self)

    def __len__(self):
        return len(self.batch_sampler)
原创粉丝点击