Writing a custom PyTorch Dataset that uses torchvision transforms

Source: Internet  Editor: 程序博客网  Date: 2024/06/06 15:37

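I have been using PyTorch lately, and it is far more pleasant to work with than TensorFlow, especially when it comes to loading data. TensorFlow's data-loading API is long-winded and complicated enough to scare people off, and because TensorFlow does not interoperate seamlessly with numpy, it is hard to reuse existing tools such as pandas for reading structured data, which causes a lot of unnecessary trouble.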

For how to write custom data loading and transforms in PyTorch, see the official tutorial:
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html

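However, a custom Dataset written the way that tutorial describes only works with hand-written transform classes, while the torchvision package already ships implementations of many image transforms. To reuse those ready-made transforms, the Dataset can be defined like this: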

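import os

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class FaceLandmarkDataset(Dataset):
    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __getitem__(self, index: int):
        # Column 0 holds the image file name; the remaining columns are landmark coordinates.
        img_name = self.landmarks_frame.iloc[index, 0]
        # Use root_dir instead of a hard-coded './faces'.
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force three channels so the torchvision transforms apply.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # .iloc and .to_numpy() replace .ix and .as_matrix(), which newer pandas removed.
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)  # one (x, y) pair per row
        if self.transform is not None:
            image = self.transform(image)
        return image, landmarks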

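The important part of this code is the __getitem__ method, which reads the image file from disk. Pay attention here: read the file with PIL's Image module and then convert it to RGB. The torchvision transforms are written to take a PIL Image as input, and converting to RGB gives every sample the same three channels.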

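Also, data that comes back as a numpy ndarray (the landmarks here) does not need to be converted to a PyTorch Tensor explicitly; the DataLoader converts it for you when it assembles a batch. (Converting it yourself does no harm either, since it is just a call to torch.from_numpy.)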
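The automatic conversion can be checked with a small sketch. ToyLandmarkDataset below is a hypothetical dataset made up for illustration; it returns plain ndarrays, and the 68x2 landmark shape is an assumption, not something read from the real CSV file.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ToyLandmarkDataset(Dataset):
    """Hypothetical dataset that returns plain NumPy arrays."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int):
        # Stand-ins for an image and its landmarks; both are ndarrays.
        image = np.full((3, 4, 4), index, dtype=np.float32)
        landmarks = np.zeros((68, 2), dtype=np.float64)
        return image, landmarks


loader = DataLoader(ToyLandmarkDataset(), batch_size=4, shuffle=False)
images, landmarks = next(iter(loader))

# The default collate step stacks the ndarrays into batched tensors.
print(type(images), images.shape)
print(type(landmarks), landmarks.dtype)
```

Note that the dtype is carried over, so landmarks built with astype('float') arrive as float64 tensors; calling .float() on them before feeding a float32 model is a common follow-up.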

The full code is as follows:

import os

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class FaceLandmarkDataset(Dataset):
    def __len__(self) -> int:
        return len(self.landmarks_frame)

    def __init__(self, csv_file: str, root_dir: str, transform=None) -> None:
        super().__init__()
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __getitem__(self, index: int):
        # Column 0 holds the image file name; the remaining columns are landmark coordinates.
        img_name = self.landmarks_frame.iloc[index, 0]
        # Use root_dir instead of a hard-coded './faces'.
        img_path = os.path.join(self.root_dir, img_name)
        # Read with PIL and force three channels so the torchvision transforms apply.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # .iloc and .to_numpy() replace .ix and .as_matrix(), which newer pandas removed.
        landmarks = self.landmarks_frame.iloc[index, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)  # one (x, y) pair per row
        if self.transform is not None:
            image = self.transform(image)
        return image, landmarks


# RandomResizedCrop is the current name of the old RandomSizedCrop transform.
trans = transforms.Compose([
    transforms.RandomResizedCrop(size=128),
    transforms.ToTensor(),
])

face_dataset = FaceLandmarkDataset(csv_file='faces/face_landmarks.csv',
                                   root_dir='faces', transform=trans)
# With num_workers > 0 on Windows or macOS, run this under `if __name__ == '__main__':`.
loader = DataLoader(dataset=face_dataset, batch_size=4, shuffle=True, num_workers=4)

The resulting DataLoader is an iterable, so you can loop over it directly with a for loop.
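A loop over the loader might look like the sketch below. To keep it runnable without the faces data, it assumes a TensorDataset filled with zeros as a hypothetical stand-in for the face dataset; the tensor shapes are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 10 "images" of 3x8x8 and 10 landmark sets of 68x2.
images = torch.zeros(10, 3, 8, 8)
landmarks = torch.zeros(10, 68, 2)
loader = DataLoader(TensorDataset(images, landmarks), batch_size=4, shuffle=True)

batch_shapes = []
for epoch in range(2):
    # Each for loop calls iter(loader) and starts a fresh pass over the data.
    for image_batch, landmark_batch in loader:
        batch_shapes.append((tuple(image_batch.shape), tuple(landmark_batch.shape)))

print(len(loader))       # number of batches per epoch
print(batch_shapes[:3])  # the last batch of an epoch holds the leftover samples
```

Because the loader can be iterated again for every epoch, there is no need to rebuild it between passes.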
