pytorch学习笔记(六):自定义Datasets
来源:互联网 发布:图像拼接融合算法 编辑:程序博客网 时间:2024/06/14 05:06
什么是Datasets:
在输入流水线中,我们看到准备数据的代码是这么写的data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
。datasets.CIFAR10
就是一个Datasets
子类,data
是这个类的一个实例。
为什么要定义Datasets:
PyTorch
提供了一个工具函数torch.utils.data.DataLoader
。通过这个类,我们在准备mini-batch
的时候可以多线程并行处理,这样可以加快准备数据的速度。Datasets
就是构建这个类的实例的参数之一。
如何自定义Datasets
下面是一个自定义Datasets的框架
class CustomDataset(data.Dataset):#需要继承data.Dataset def __init__(self): # TODO # 1. Initialize file path or list of file names. pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #这里需要注意的是,第一步:read one data,是一个data pass def __len__(self): # You should change 0 to the total size of your dataset. return 0
下面看一下官方MNIST
的例子(代码被缩减,只留下了重要的部分):
class MNIST(data.Dataset): def __init__(self, root, train=True, transform=None, target_transform=None, download=False): self.root = root self.transform = transform self.target_transform = target_transform self.train = train # training set or test set if download: self.download() if not self._check_exists(): raise RuntimeError('Dataset not found.' + ' You can use download=True to download it') if self.train: self.train_data, self.train_labels = torch.load( os.path.join(root, self.processed_folder, self.training_file)) else: self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file)) def __getitem__(self, index): if self.train: img, target = self.train_data[index], self.train_labels[index] else: img, target = self.test_data[index], self.test_labels[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img.numpy(), mode='L') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): if self.train: return 60000 else: return 10000
2 0
- pytorch学习笔记(六):自定义Datasets
- pytorch学习笔记(六):自定义Datasets
- Pytorch学习笔记(六)
- pytorch学习笔记(1)--pytorch张量
- pytorch 学习笔记之自定义 Module
- Pytorch 学习笔记之自定义 Module
- Pytorch学习笔记(一)
- Pytorch学习笔记(二)
- Pytorch学习笔记(三)
- Pytorch学习笔记(四)
- Pytorch学习笔记(五)
- PyTorch代码学习-torchvision.datasets中folder.py
- pytorch学习笔记(九):PyTorch结构介绍
- pytorch学习笔记(九):PyTorch结构介绍
- Pytorch学习笔记(一):pytorch的安装-Ubuntu14.04
- pytorch学习笔记(十七):python 端扩展 pytorch
- pytorch学习笔记(十八):C 语言扩展 pytorch
- pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解
- (转)Android中属性动画和补间动画的区别
- android studio初始化
- 方法重载与方法重写的区别
- 完整轮播图实现过程
- 网课内容--解码BMP与创建纹理
- pytorch学习笔记(六):自定义Datasets
- &和&& |和||
- Block外给self加上weak,那不就释放了吗
- WRF-DA代码编译与安装(二)——WRF-DA模块的编译与安装
- 蓝桥 历届试题 地宫取宝
- css float详解
- 【剑指offer】两个链表的第一个公共节点
- Java知识点
- 关于微服的一些资料