PyTorch: Data Loading and Processing Tutorial 数据加载和处理教程

来源:互联网 发布:淘宝如何提高售后评分 编辑:程序博客网 时间:2024/05/21 03:18

将数据集封装成 Dataset 类

首先导入所需的包,特别是 torch.utils.data.Dataset。

from __future__ import print_function, division

import os
import warnings

import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
# NOTE(review): this silences *every* warning, including deprecation warnings
# from pandas/torchvision that would point at removed APIs; narrow the filter
# (e.g. by category or module) outside of a throwaway demo.
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

要把数据集封装成 Dataset 的子类,通常需要实现以下方法:
__init__:初始化并加载数据集(例如读取标注文件);
__len__:返回数据集中的样本数量,即 len(dataset) 的结果,它决定了遍历一遍数据集时的样本数;
__getitem__:按索引返回数据集中的一条样本,使 dataset[i] 可用;
此外,可以通过可选参数 transform=None 传入数据预处理函数(如 resize、crop、转换为 torch 张量等)。

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset.

    Each row of the annotation CSV holds an image file name in its first
    column, followed by the landmark coordinates, which are reshaped into
    rows of two values (one coordinate pair per landmark).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated samples (rows of the CSV)."""
        return len(self.landmarks_frame)

    def _load_landmarks(self, idx):
        """Return the landmarks of row ``idx`` as a float array of shape (N, 2)."""
        # ``.iloc`` / ``.to_numpy()`` replace ``.ix`` / ``.as_matrix()``,
        # which were removed in pandas 1.0.
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy().astype('float')
        return landmarks.reshape(-1, 2)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with keys 'image' (the array read from disk) and
            'landmarks' (float array of shape (N, 2)), passed through
            ``self.transform`` when one was given.
        """
        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self._load_landmarks(idx)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

传入的 transform 建议写成可调用类的形式:在 __init__ 中保存参数,这样每次调用时无需重复传入参数;在 __call__ 中实现具体的变换。例如:

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.

    Raises (on call):
        ValueError: if the requested crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                'crop size (%d, %d) is larger than image size (%d, %d)'
                % (new_h, new_w, h, w))

        # np.random.randint excludes its upper bound: "+ 1" makes the last
        # valid offset reachable and allows a crop as large as the image
        # (randint(0, 0) would raise).
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Landmark columns are (x, y), so shift by (left, top).
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # 2-D (single-channel) images have no channel axis; add one so the
        # transpose below works for them as well.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}

构建数据加载器DataLoader

from torch.utils.data import DataLoader

# NOTE(review): `transformed_dataset` is not defined in this excerpt; it is
# presumably a FaceLandmarksDataset built with a transform such as
# transforms.Compose([Rescale(...), RandomCrop(...), ToTensor()]) — confirm.
# Serve shuffled mini-batches of 4 samples using 4 loader workers.
dataloader = DataLoader(
    transformed_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

使用torchvision使处理更方便

import torch
from torchvision import transforms, datasets

# Augment with a random resized crop to 224 and a random horizontal flip,
# then convert to a tensor and normalize each channel with the given
# per-channel mean/std.
data_transform = transforms.Compose([
    # RandomSizedCrop is the deprecated alias of RandomResizedCrop (dropped in
    # later torchvision releases); RandomResizedCrop behaves the same.
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# One class per sub-directory of the root folder.
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

参考文献

http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class

阅读全文
0 0
原创粉丝点击