PyTorch: Data Loading and Processing Tutorial
Source: Internet · Editor: 程序博客网 · Posted: 2024/05/21 03:18
Wrapping the dataset in a Dataset class
First, import the required packages, most importantly torch.utils.data.Dataset:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
To turn your data into a Dataset subclass, you must implement __init__ to load the data, __len__ to return the number of samples (this bounds iteration in for loops), and __getitem__ to index a single sample from the dataset. The optional argument transform=None is used to pass in preprocessing functions (resize, crop, conversion to a torch tensor, and so on).
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # .iloc / .to_numpy() replace the deprecated .ix / .as_matrix()
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
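The Dataset contract above boils down to __len__ and __getitem__. A minimal dependency-free sketch with a hypothetical in-memory dataset shows how indexing then works (a torch.utils.data.Dataset subclass behaves the same way, since those two methods are all the DataLoader needs):

```python
class ToySampleDataset:
    """Illustrates the map-style Dataset protocol: __len__ + __getitem__."""

    def __init__(self, values, transform=None):
        self.values = values          # stands in for rows read from the CSV
        self.transform = transform    # optional preprocessing callable

    def __len__(self):
        # Number of samples; bounds iteration over the dataset
        return len(self.values)

    def __getitem__(self, idx):
        # Build one sample dict, mirroring the face-landmarks example
        sample = {'image': self.values[idx], 'landmarks': [idx, idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample


ds = ToySampleDataset([10, 20, 30])
print(len(ds))          # 3
print(ds[1]['image'])   # 20
```

Indexing with ds[1] calls __getitem__(1) under the hood, which is exactly what a DataLoader does when it fetches samples.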
The transform passed in should be written as a class rather than a plain function: __init__ receives the parameters once (so they are not reloaded on every call), and __call__ applies the transform to a sample. For example:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (int or tuple): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = np.random.randint(0, h - new_h)
        left = np.random.randint(0, w - new_w)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
Building the DataLoader
from torch.utils.data import DataLoader

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)
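Under the hood, a DataLoader with shuffle=True permutes the sample indices each epoch and yields batch_size samples at a time (the last batch may be smaller). A rough pure-Python sketch of that iteration logic, assuming a map-style dataset; the function name and seed parameter are inventions for this sketch:

```python
import random

def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield lists of samples, mimicking DataLoader's batching order."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)   # new permutation per epoch
    for start in range(0, len(indices), batch_size):
        # Gather one batch of samples by index
        yield [dataset[i] for i in indices[start:start + batch_size]]


batches = list(iterate_batches(list(range(10)), batch_size=4, shuffle=False))
print([len(b) for b in batches])   # [4, 4, 2]
```

The real DataLoader additionally collates each batch into stacked tensors and can prefetch with num_workers worker processes, but the index/batch bookkeeping is essentially this.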
Using torchvision to simplify processing
import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),   # formerly RandomSizedCrop (deprecated)
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)
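ImageFolder assumes one subdirectory per class under root (e.g. hymenoptera_data/train/ants/xxx.jpg) and assigns class indices by sorting the subdirectory names. A stdlib-only sketch of that mapping, using a temporary directory with hypothetical file names instead of the real dataset:

```python
import os
import tempfile

def folder_to_samples(root):
    """Mimic ImageFolder's (path, class_index) listing: classes are the
    sorted subdirectory names; each file under a class dir is one sample."""
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    samples = []
    for c in classes:
        for fname in sorted(os.listdir(os.path.join(root, c))):
            samples.append((os.path.join(root, c, fname), class_to_idx[c]))
    return samples, class_to_idx


# Build a toy root/class/file layout to run the sketch against
root = tempfile.mkdtemp()
for cls, fname in [('ants', 'a1.jpg'), ('bees', 'b1.jpg'), ('bees', 'b2.jpg')]:
    os.makedirs(os.path.join(root, cls), exist_ok=True)
    open(os.path.join(root, cls, fname), 'w').close()

samples, class_to_idx = folder_to_samples(root)
print(class_to_idx)                  # {'ants': 0, 'bees': 1}
print([lbl for _, lbl in samples])   # [0, 1, 1]
```

This is why renaming or adding a class folder silently shifts the integer labels: the index comes purely from the sorted folder names.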
References
http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class