pytorch自定义Dataset并使用torchvision的Transform
来源:互联网 发布:avmoo2017最新域名 编辑:程序博客网 时间:2024/06/06 15:37
最近用了pytorch, 使用上比Tensorflow爽的多,尤其是在读取数据的部分,冗长而繁杂的api令人望而却步,而且由于Tensorflow不支持与numpy的无缝切换,导致难以使用现成的pandas等格式化数据读取工具,造成了很多不必要的麻烦
pytorch自定义读取数据和进行Transform的部分请见文档:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html
但是按照文档中所描述所完成的自定义Dataset只能够使用自定义的Transform步骤,而torchvision包中已经给我们提供了很多图像transform步骤的实现,为了使用这些已经实现的Transform步骤,我们可以使用如下方法定义Dataset:
class FaceLandmarkDataset(Dataset): def __len__(self) -> int: return len(self.landmarks_frame) def __init__(self, csv_file: str, root_dir: str, transform=None) -> None: super().__init__() self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __getitem__(self, index:int): img_name = self.landmarks_frame.ix[index, 0] img_path = os.path.join('./faces', img_name) with Image.open(img_path) as img: image = img.convert('RGB') landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float') landmarks = np.reshape(landmarks,newshape=(-1,2)) if self.transform is not None: image = self.transform(image) return image, landmarks
这几行的重点在__getitem__
函数里面,这个函数内部实现了读取硬盘上的图片文件,这里一定要注意,使用PIL的Image库进行读取,然后convert到RGB值
另外,读取到的numpy的ndarray结构的数据不需要显式转换成pytorch的Tensor,后续的DataLoader会自动替你转换(当然转换一下也没差啦,反正也就是torch.from_numpy
一下啦
全部代码如下:
from __future__ import print_function, divisionimport osimport torchimport pandas as pdfrom PIL import Imageimport numpy as npfrom torch.utils.data import Dataset, DataLoaderfrom torchvision import transformsclass FaceLandmarkDataset(Dataset): def __len__(self) -> int: return len(self.landmarks_frame) def __init__(self, csv_file: str, root_dir: str, transform=None) -> None: super().__init__() self.landmarks_frame = pd.read_csv(csv_file) self.root_dir = root_dir self.transform = transform def __getitem__(self, index:int): img_name = self.landmarks_frame.ix[index, 0] img_path = os.path.join('./faces', img_name) with Image.open(img_path) as img: image = img.convert('RGB') landmarks = self.landmarks_frame.as_matrix()[index, 1:].astype('float') landmarks = np.reshape(landmarks,newshape=(-1,2)) if self.transform is not None: image = self.transform(image) return image, landmarkstrans = transforms.Compose(transforms = [ transforms.RandomSizedCrop(size=128), transforms.ToTensor()])face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv', root_dir='faces', transform= trans)loader = DataLoader(dataset = face_dataset, batch_size=4,shuffle=True,num_workers=4)
获取得到的DataLoader是个iterator,可以直接循环调用
阅读全文
0 0
- pytorch自定义Dataset并使用torchvision的Transform
- PyTorch学习-torchvision transform
- pytorch-torchvision.models
- pytorch-torchvision transforms
- pytorch torchvision.datasets.CocoCaptions on Linux
- win10+cuda8+cudnn5.1+Anaconda3+pytorch+torchvision
- ubuntu14.04安装Pytorch 和 torchvision
- Hive自定义函数与transform的使用
- pytorch使用(一)处理并加载自己的数据
- 自定义DataSet,并写入数据
- PyTorch学习系列(二)——数据预处理torchvision.transforms
- PyTorch代码学习-torchvision.datasets中folder.py
- pytorch-custom dataset
- pytorch使用(二)自定义网络
- 【pytorch源码赏析】Dataset in pytorch
- PyTorch使用指定的GPU
- pytorch如何自定义自己的MyDatasets
- DataSet自定义添加table并添加数据
- 五、Git-管理修改
- Java 并发:Executors 和线程池
- Java虚拟机运行时数据区域
- JS中的闭包(Closure)
- DrawableLayout实现仿QQ侧滑菜单+HttURLConnection_XListView_DrawerLayout_ImageLoader
- pytorch自定义Dataset并使用torchvision的Transform
- CCF认证 201312-1 出现次数最多的数
- 学习笔记:阿里云ECS部署web项目的常见问题及解决方法
- 修改Oracle 11g中scott账户锁定和密码
- [尺取法]2017 ACM/ICPC Asia Regional Shenyang Online 1012
- cxf调用WebService时出现No operation was found with the name {http://impl.server.test.com/}helloWorld
- Shiro创建FilterChain过程详解
- 红黑树介绍
- HIS-门急诊模块之系统集成工作摘要