PyTorch代码学习-dataloader
来源:互联网 发布:淘宝上怎么做网络推广 编辑:程序博客网 时间:2024/06/03 17:19
import torch
# torch.multiprocessing: wrapper around the native multiprocessing module.
import torch.multiprocessing as multiprocessing
# Samplers decide which dataset indices are drawn and how they are grouped
# into batches.
from .sampler import SequentialSampler, RandomSampler, BatchSampler
# collections supplies the Mapping / Sequence ABCs used for type dispatch in
# default_collate and pin_memory_batch.
import collections
import sys
# traceback is used to format exception information raised in workers.
import traceback
# threading drives the optional pinned-memory thread.
import threading
from torch._six import string_classes

# The queue module was renamed between Python 2 and Python 3.
if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

# The container ABCs live in collections.abc on Python 3 (the aliases in
# ``collections`` were removed in Python 3.10); Python 2 only has them in
# ``collections`` itself.
try:
    from collections import abc as _container_abcs
except ImportError:  # Python 2
    _container_abcs = collections

# Whether to use shared memory in default_collate. Switched on inside
# worker processes by _worker_loop.
_use_shared_memory = False


class ExceptionWrapper(object):
    """Wraps an exception plus traceback to communicate across threads."""

    def __init__(self, exc_info):
        """Capture the exception type and formatted traceback.

        Args:
            exc_info: the ``(type, value, traceback)`` triple returned by
                ``sys.exc_info()``.
        """
        self.exc_type = exc_info[0]
        # format_exception returns the traceback as a list of strings;
        # "".join glues them into one message
        # (e.g. ",".join("abcd") == "a,b,c,d").
        self.exc_msg = "".join(traceback.format_exception(*exc_info))


def _worker_loop(dataset, index_queue, data_queue, collate_fn):
    """Worker body: build batches for the index lists read from a queue.

    Reads ``(idx, batch_indices)`` tuples from ``index_queue``, collates the
    corresponding dataset items and puts ``(idx, samples)`` on
    ``data_queue``. If loading or collating raises, an ``ExceptionWrapper``
    is put in place of the samples. A ``None`` item is the shutdown
    sentinel; it is echoed to ``data_queue`` before the loop exits.
    """
    # Make default_collate (when used) assemble batches in shared memory.
    global _use_shared_memory
    _use_shared_memory = True

    torch.set_num_threads(1)
    while True:
        r = index_queue.get()
        if r is None:
            data_queue.put(None)
            break
        idx, batch_indices = r
        try:
            samples = collate_fn([dataset[i] for i in batch_indices])
        except Exception:
            # Ship the failure to the consumer instead of dying silently.
            data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))


def _pin_memory_loop(in_queue, out_queue, done_event):
    """Thread body: move batches from ``in_queue`` to ``out_queue``, pinned.

    Each ``(idx, batch)`` item is passed through ``pin_memory_batch``.
    Items already carrying an ``ExceptionWrapper`` are forwarded untouched,
    and a failure while pinning is forwarded as an ``ExceptionWrapper``.
    The loop ends on a ``None`` item, or when reading the queue fails after
    ``done_event`` has been set.
    """
    while True:
        try:
            r = in_queue.get()
        except Exception:
            # A read error after shutdown was requested is expected;
            # anything else is a real failure.
            if done_event.is_set():
                return
            raise
        if r is None:
            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)
            continue
        idx, batch = r
        try:
            batch = pin_memory_batch(batch)
        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
        else:
            out_queue.put((idx, batch))


# Maps a numpy dtype name to the tensor type used to collate numpy scalars.
numpy_type_map = {
    'float64': torch.DoubleTensor,
    'float32': torch.FloatTensor,
    'float16': torch.HalfTensor,
    'int64': torch.LongTensor,
    'int32': torch.IntTensor,
    'int16': torch.ShortTensor,
    'int8': torch.CharTensor,
    'uint8': torch.ByteTensor,
}


def default_collate(batch):
    """Puts each data field into a tensor with outer dimension batch size.

    Dispatches on the type of ``batch[0]``:

    * tensors are stacked along a new first dimension;
    * numpy arrays are converted with ``torch.from_numpy`` and stacked;
      zero-dimensional numpy scalars become a tensor of the matching type;
    * ints become a ``LongTensor`` and floats a ``DoubleTensor``;
    * strings are returned unchanged;
    * mappings are collated per key, sequences per position (recursively).

    Raises:
        TypeError: if ``batch[0]`` is of none of the supported types.
    """
    if torch.is_tensor(batch[0]):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        return torch.stack(batch, 0, out=out)
    elif type(batch[0]).__module__ == 'numpy':
        elem = batch[0]
        if type(elem).__name__ == 'ndarray':
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        # Any other numpy object falls through to the TypeError below.
    elif isinstance(batch[0], int):
        return torch.LongTensor(batch)
    elif isinstance(batch[0], float):
        return torch.DoubleTensor(batch)
    elif isinstance(batch[0], string_classes):
        return batch
    elif isinstance(batch[0], _container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    elif isinstance(batch[0], _container_abcs.Sequence):
        # Regroup [(a0, b0), (a1, b1), ...] into [(a0, a1, ...), (b0, b1, ...)].
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}"
                     .format(type(batch[0]))))


def pin_memory_batch(batch):
    """Return ``batch`` with every tensor in it replaced by a pinned copy.

    Mappings and sequences are rebuilt recursively (as ``dict`` and ``list``
    respectively); strings and any other object are returned unchanged.
    """
    if torch.is_tensor(batch):
        return batch.pin_memory()
    elif isinstance(batch, string_classes):
        # Checked before Sequence so strings are not split into characters.
        return batch
    elif isinstance(batch, _container_abcs.Mapping):
        return {k: pin_memory_batch(sample) for k, sample in batch.items()}
    elif isinstance(batch, _container_abcs.Sequence):
        return [pin_memory_batch(sample) for sample in batch]
    else:
        return batch


class DataLoaderIter(object):
    """Iterates once over the DataLoader's dataset, as specified by the sampler.

    With ``num_workers == 0`` batches are built in the calling process. With
    more workers, index lists are handed to worker processes through
    ``index_queue`` and finished batches come back through ``data_queue``;
    batches are returned to the caller in the order they were requested.
    """

    def __init__(self, loader):
        """Copy the loader's settings and, if requested, start the workers."""
        self.dataset = loader.dataset
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.index_queue = multiprocessing.SimpleQueue()
            self.data_queue = multiprocessing.SimpleQueue()
            self.batches_outstanding = 0
            self.shutdown = False
            # send_idx numbers the batches handed to workers; rcvd_idx is the
            # number of the next batch to give to the caller.
            self.send_idx = 0
            self.rcvd_idx = 0
            # Batches that arrived before their turn, keyed by batch number.
            self.reorder_dict = {}

            self.workers = [
                multiprocessing.Process(
                    target=_worker_loop,
                    args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn))
                for _ in range(self.num_workers)]

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            if self.pin_memory:
                # Insert a pinning thread between the workers and the
                # consumer: it reads the workers' queue and feeds a new one.
                in_data = self.data_queue
                self.data_queue = queue.Queue()
                self.pin_thread = threading.Thread(
                    target=_pin_memory_loop,
                    args=(in_data, self.data_queue, self.done_event))
                self.pin_thread.daemon = True
                self.pin_thread.start()

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

    def __len__(self):
        """Return the number of batches, as reported by the batch sampler."""
        return len(self.batch_sampler)

    def __next__(self):
        """Return the next batch.

        Raises:
            StopIteration: when the batch sampler is exhausted and no batch
                is outstanding.
        """
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch)

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self.data_queue.get()
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch
                continue
            return self._process_next_batch(batch)

    next = __next__  # Python 2 compatibility

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def _put_indices(self):
        """Send the next index list to the workers, if the sampler has one."""
        assert self.batches_outstanding < 2 * self.num_workers
        indices = next(self.sample_iter, None)
        if indices is None:
            return
        self.index_queue.put((self.send_idx, indices))
        self.batches_outstanding += 1
        self.send_idx += 1

    def _process_next_batch(self, batch):
        """Advance the receive counter, request another batch, return this one.

        Raises:
            Exception: if ``batch`` is an ``ExceptionWrapper``, an exception
                of the wrapped type is raised with the formatted worker
                traceback as its message.
        """
        self.rcvd_idx += 1
        self._put_indices()
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("DataLoaderIterator cannot be pickled")

    def _shutdown_workers(self):
        """Set ``done_event`` and queue one ``None`` sentinel per worker (once)."""
        if not self.shutdown:
            self.shutdown = True
            self.done_event.set()
            for _ in self.workers:
                self.index_queue.put(None)

    def __del__(self):
        """Ask the workers to stop when the iterator is destroyed."""
        if self.num_workers > 0:
            self._shutdown_workers()


class DataLoader(object):
    """
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy
            tensors into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete
            batch, if the dataset size is not divisible by the batch size. If
            False and the size of dataset is not divisible by the batch size,
            then the last batch will be smaller. (default: False)

    Raises:
        ValueError: if ``batch_sampler`` is combined with ``batch_size``,
            ``shuffle``, ``sampler`` or ``drop_last``; if ``sampler`` is
            combined with ``shuffle``; or if ``num_workers`` is negative.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler is mutually exclusive with '
                                 'batch_size, shuffle, sampler, and drop_last')

        if sampler is not None and shuffle:
            raise ValueError('sampler is mutually exclusive with shuffle')

        # A negative worker count would start no workers yet also skip the
        # same-process path in DataLoaderIter, so reject it up front.
        if self.num_workers < 0:
            raise ValueError('num_workers cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler

    def __iter__(self):
        """Return a fresh ``DataLoaderIter`` over this loader."""
        return DataLoaderIter(self)

    def __len__(self):
        """Return the number of batches, as reported by the batch sampler."""
        return len(self.batch_sampler)
阅读全文
0 0
- PyTorch代码学习-dataloader
- pytorch学习笔记(十四): DataLoader源码阅读
- PyTorch代码学习-ImageNET训练
- PyTorch学习
- pytorch 学习
- PyTorch学习—PyTorch是什么?
- pytorch学习笔记(十六):pytorch 写代码时应该注意
- Pytorch学习(十)---解读Neural Style代码
- PyTorch代码学习-torchvision.datasets中folder.py
- pytorch学习笔记(1)--pytorch张量
- PyTorch学习记录-1PyTorch安装
- pytorch 学习笔记(一)
- pytorch 学习笔记(一)
- pytorch 学习笔记(一)
- pytorch学习笔记
- PyTorch学习历程
- pytorch 学习笔记(一)
- 【pytorch】迁移学习
- HDU 5929 Basic Data Structure(模拟栈+规律)
- coursera算法笔记:QuickFind中union隐蔽错误
- 微信小程序用console.log打印json数据
- zabbix-sender主动发送数据给zabbix-server
- Qml翻转效果
- PyTorch代码学习-dataloader
- TCP/IP协议四层模型
- 腾讯应用宝签约宋仲基 打造移动生活娱乐第一站
- 要抢会计饭碗的FreshBooks,如何为创业狗们排忧解难?
- 腾讯“互联网+警务”方案为智慧城市生活构建注入新动力
- 史上最简单的拍摄“漂浮少女”教程 只需要vivo Xplay5旗舰版
- 写一个宏可以将一个数字的奇数位和偶数位交换
- 2018应届生招聘(面经)
- 第四篇 elasticsearch的基本分布式架构