PyTorch Getting Started (6) --- Loading and Preprocessing Data (Basics) --- Just to Better Understand the Workflow
Source: Internet  Editor: 程序博客网  Time: 2024/05/21 03:58
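This is taken directly from the PyTorch Tutorials; let's walk through it.
Required packages:
1. scikit-image: image I/O and transforms
2. pandas: reading the CSV file
Data:
faces
Format of the CSV data:
Each face has 68 landmark points in total.
image_name,part_0_x,part_0_y,part_1_x,part_1_y,part_2_x, ... ,part_67_x,part_67_y
0805personali01.jpg,27,83,27,98, ... ,84,134
1084239450_e76e00b7e7.jpg,70,236,71,257, ... ,128,312
Or, as shown in the figure: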
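To make the row layout concrete, here is a minimal sketch that parses a tiny in-memory CSV with the same column pattern and reshapes one row into (x, y) pairs. The file names `a.jpg` and `b.jpg` and the two-landmark rows are made up for illustration; the real file has 68 landmarks per image.

```python
from io import StringIO

import pandas as pd

# Hypothetical miniature CSV: same layout as face_landmarks.csv,
# but with only 2 landmarks per image instead of 68.
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y
a.jpg,27,83,27,98
b.jpg,70,236,71,257
"""

frame = pd.read_csv(StringIO(csv_text))

row = 1
img_name = frame.iloc[row, 0]                                # first column: file name
landmarks = frame.iloc[row, 1:].to_numpy().astype('float')   # remaining columns: coordinates
landmarks = landmarks.reshape(-1, 2)                         # one (x, y) pair per row

print(img_name)
print(landmarks.shape)
```

The `reshape(-1, 2)` step is what turns the flat x, y, x, y, ... sequence into one row per landmark.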
Part 1: Reading an image with plain functions (the simplest approach)
# -*- coding: utf-8 -*-
"""
Data Loading and Processing Tutorial
====================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_
"""
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
# Positional indexing: column 0 is the file name, the rest are coordinates
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy().astype('float')
landmarks = landmarks.reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))


# Define show_landmarks.
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated


plt.figure()
# Read the image with io.imread
show_landmarks(io.imread(os.path.join('faces/', img_name)), landmarks)
plt.show()
Part 2: Subclassing Dataset
# A custom dataset must subclass Dataset.
# It generally needs at least __init__, __len__ and __getitem__.
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)  # a nice pattern: apply the optional transform here

        return sample


face_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                    root_dir='faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    # For a dict, use ['key'] to get the value.
    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)  # pass the dict entries as keyword arguments

    if i == 3:
        plt.show()
        break
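The `show_landmarks(**sample)` call works because the keys of the sample dict match the parameter names of the function. A minimal sketch of that mechanism, using a hypothetical `describe` function and plain lists in place of real images:

```python
def describe(image, landmarks):
    """Hypothetical stand-in for show_landmarks: returns a summary string."""
    return 'image with {} rows, {} landmarks'.format(len(image), len(landmarks))


# The dict keys must match the parameter names of describe().
sample = {'image': [[0, 0], [0, 0], [0, 0]], 'landmarks': [(1, 2), (3, 4)]}

# These two calls are equivalent.
explicit = describe(image=sample['image'], landmarks=sample['landmarks'])
unpacked = describe(**sample)

print(unpacked)
```

If the dict had an extra key that the function does not accept, the unpacked call would raise a TypeError, so the sample keys and the function signature have to be kept in sync.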
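Part 3: Basic pattern for writing image transforms (subclassing object)
The main ones are:
1. Rescaling
2. Random cropping
3. Converting the image data from NumPy arrays to tensors.
Other notes:
1. Here each class is made callable by implementing __call__() in the class.
For example:
tsfm = Transform(init_params)
transformed_sample = tsfm(call_params)
In other words, the instance of the class is called with arguments: instance(arguments).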
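The callable-class pattern can be sketched without any image library. The two classes below, `ScaleLandmarks` and `ShiftLandmarks`, are hypothetical and only mimic what rescaling and cropping do to landmark coordinates; they are not the transforms defined in this post.

```python
class ScaleLandmarks(object):
    """Hypothetical transform: scale (x, y) points, as a resize would."""

    def __init__(self, scale_x, scale_y):
        # Parameters are fixed once, when the transform is constructed.
        assert isinstance(scale_x, (int, float))
        assert isinstance(scale_y, (int, float))
        self.scale_x = scale_x
        self.scale_y = scale_y

    def __call__(self, sample):
        # Runs every time the instance is called like a function.
        points = [(x * self.scale_x, y * self.scale_y)
                  for x, y in sample['landmarks']]
        return {'landmarks': points}


class ShiftLandmarks(object):
    """Hypothetical transform: shift points by a crop origin (left, top)."""

    def __init__(self, left, top):
        self.left = left
        self.top = top

    def __call__(self, sample):
        points = [(x - self.left, y - self.top)
                  for x, y in sample['landmarks']]
        return {'landmarks': points}


sample = {'landmarks': [(60.0, 30.0), (300.0, 150.0)]}

tsfm = ScaleLandmarks(0.5, 0.5)   # __init__ runs here
scaled = tsfm(sample)             # __call__ runs here

# Because every transform is just a callable, several of them
# can be applied one after another with a plain loop.
result = sample
for step in (ScaleLandmarks(0.5, 0.5), ShiftLandmarks(10, 5)):
    result = step(result)

print(scaled['landmarks'])
print(result['landmarks'])
```

Keeping configuration in `__init__` and per-sample work in `__call__` is what lets the same configured instance be reused for every sample in a dataset.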
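Key points:
1. In isinstance(output_size, (int, tuple)), the second argument can be a tuple, which lets you check against several types at once.
2. NumPy and PyTorch use different axis orders for images, so the axes must be swapped.
        # excerpt from ToTensor.__call__
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}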
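Both points can be checked with NumPy alone. The sketch below uses a made-up 4 x 6 RGB array rather than a real image, and marks one pixel so the effect of the axis swap is visible.

```python
import numpy as np

# isinstance accepts a tuple of types and is True if any of them matches.
accepts_int = isinstance(256, (int, tuple))
accepts_tuple = isinstance((224, 224), (int, tuple))
rejects_float = not isinstance(256.0, (int, tuple))

# Hypothetical image: height 4, width 6, 3 color channels (H x W x C).
image_hwc = np.zeros((4, 6, 3))
image_hwc[1, 2, 0] = 5.0   # row 1, column 2, channel 0

# Move the channel axis to the front: C x H x W, the layout torch expects.
image_chw = image_hwc.transpose((2, 0, 1))

print(image_hwc.shape)
print(image_chw.shape)
print(image_chw[0, 1, 2])  # same pixel, now indexed as channel, row, column
```

Note that `transpose` only reorders the axes; the pixel values themselves are unchanged.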
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))  # a handy type check
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # + 1 because the upper bound of randint is exclusive; this also
        # keeps the call valid when the crop size equals the image size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
Part 4: Composing image transforms with torchvision
To apply several transforms to an image, use torchvision.transforms.Compose.
"""
Notes:
1. FaceLandmarksDataset is the class we defined in Part 2 (it subclasses Dataset).
2. transforms.Compose chains several operations together.
3. Each operation is a class with a __call__ method whose argument is the sample.
   We only need to construct it as MyTransforms(init_params).
"""
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()  # ToTensor takes no constructor arguments
                                           ]))

# A rather ugly way to iterate over the dataset
for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
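Iterating over the dataset with a plain for loop like this has several problems:
1. The data is not processed in batches
2. The data is not shuffled
3. The data is not loaded in parallel with multiprocessing workers.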
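The first two items in the list above can be illustrated in pure Python. The helper `iterate_batches` below is a hypothetical sketch of shuffling and batching sample indices; it is not how any library implements it, and it leaves out parallel loading entirely.

```python
import random


def iterate_batches(num_samples, batch_size, shuffle, seed=0):
    """Yield lists of sample indices, optionally shuffled, batch_size at a time."""
    indices = list(range(num_samples))
    if shuffle:
        # A seeded generator keeps this sketch reproducible.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_batches(num_samples=10, batch_size=4, shuffle=True))
for batch in batches:
    print(batch)
```

With 10 samples and a batch size of 4, the last batch holds only the 2 leftover indices; every index still appears exactly once per pass.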
Part 5: Using DataLoader
torch.utils.data.DataLoader
is an iterable that provides all of the features above.
Usage:
dataloader = DataLoader(custom dataset with its transforms applied, batch_size=4, shuffle=True, num_workers=4)
# Define the dataloader
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2  # make_grid pads images by 2 pixels by default

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Offset each image's landmarks by its position in the grid
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')


for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
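To see what a batch looks like without the faces data, here is a small sketch with a hypothetical in-memory `ToyDataset` that returns dict samples of the same structure. The tensor sizes are made up, and `num_workers=0` is an assumption chosen so the example runs in the main process on any platform.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical in-memory dataset that returns dict samples."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {'image': torch.zeros(3, 8, 8),
                'landmarks': torch.full((68, 2), float(idx))}


# num_workers=0 loads data in the main process.
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)

shapes = []
for batch in loader:
    # Each dict entry is stacked along a new leading batch dimension.
    shapes.append((tuple(batch['image'].shape),
                   tuple(batch['landmarks'].shape)))
    print(shapes[-1])
```

The loader keeps the dict structure of each sample and stacks the values per key, which is why `sample_batched['image']` in the code above is a 4-dimensional tensor.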
Summary
Putting it all together:
# -*- coding: utf-8 -*-
############# Step 1: Preparation ##################
"""
Data Loading and Processing Tutorial
====================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_
"""
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

landmarks_frame = pd.read_csv('faces/face_landmarks.csv')

n = 65
# Positional indexing: column 0 is the file name, the rest are coordinates
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:].to_numpy().astype('float')
landmarks = landmarks.reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))


# Define show_landmarks.
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated


plt.figure()
# Read the image with io.imread
show_landmarks(io.imread(os.path.join('faces/', img_name)), landmarks)
plt.show()


################## Step 2: Custom Dataset class ####################
class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy().astype('float')
        landmarks = landmarks.reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample


############# Step 3: Custom transform classes #####################
# Note: every custom transform class implements __call__ so that
# transforms.Compose can call it later.
# Subclassing object is all they need.
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # + 1 because the upper bound of randint is exclusive; this also
        # keeps the call valid when the crop size equals the image size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}


############ Step 4: Dataset instance (custom Dataset class + composed transforms) #############
transformed_dataset = FaceLandmarksDataset(csv_file='faces/face_landmarks.csv',
                                           root_dir='faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

############# Step 5: Parallel loading with DataLoader ###################
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)


# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2  # make_grid pads images by 2 pixels by default

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        # Offset each image's landmarks by its position in the grid
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')


# With DataLoader, the data now arrives one batch at a time.
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break
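As you can see, this approach requires writing all the transform classes yourself, which is tedious, but it shows the whole workflow very clearly. The next post covers the approach used in real projects.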