pytorch使用(一)处理并加载自己的数据

来源:互联网 发布:三砖淘宝店铺 编辑:程序博客网 时间:2024/06/01 10:16

pytorch使用:目录


pytorch使用(一)数据处理

个人认为,数据处理或许是在完成一篇论文中最耗费时间的,特别是大多情况下,需要在很多个库上做实验。

pytorch官方支持很多库,使用torchvision来完成数据的处理,点这里可以看到支持的库并不是很多。在这里,我将结合一个实例说明如何使用pytorch来处理自己的数据,任务是一个分析双臂运动的,检测6个关节点的运动。输入是连续三帧的检测结果以及计算的光流,也就是$3*6+2*2=22$张heatmap,输出是中间帧的检测结果,也就是6张heatmap。

把原始数据处理为模型使用的数据需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分别可以理解为数据处理格式的定义、数据处理和数据加载。

1. 数据预处理torchvision.transforms

pytorch使用torchvision.transforms实现数据的预处理,包括中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等,建议整体浏览一下这一部分的官方手册,非常有用,数据处理很方便。

先转换为张量,然后正则化:

import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(),                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])#img = transform(img)

2. 数据读取,构建Dataset子类

参考:http://blog.csdn.net/victoriaw/article/details/72356453

如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。我们的将数据列表放在train.txttest.txt中,将不同类型的数据的路径放在path.txt中,所以在类的init函数中有path_file和 list_file两个变量

在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是lengetitem:
- len返回数据集的大小
- getitem实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。

末尾有自己写的一个Dataset子类的定义文件。

3. 数据加载

torch.utils.data.DataLoader()函数,合成数据并且提供迭代访问。主要由两部分组成:
- dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:

- batch-size。样本每个batch的大小,默认为1。- shuffle。是否打乱数据,默认为False。- num_workers。数据分为几个线程处理默认为0。- sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False

使用:

import torchfrom datagen import MyDatasettrainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False)testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)

以下是定义class MyDataset文件datagen.py, 其中有__init__(self, path_file, list_file,numJoints,type)__getitem__(self, idx)__len__(self)三个函数,__getitem__返回一个(22,256,256)的输入和一个(6,256,256)的标签。
'''Load data'''import numpy as npfrom PIL import Image#import cv2import torchimport torch.utils.data as dataimport torchvision.transforms as transformsclass MyDataset(data.Dataset):    def __init__(self, path_file, list_file,numJoints,type):        '''        Args:          path_file: (str) heatmap and optical file location          list_file: (str) path to index file.          numJoints: (int) number of joints          type: (boolean) use pose flow(true) or optical flow(false)        '''        self.numJoints = numJoints        # read heatmap and optical path        with open(path_file) as f:            paths = f.readlines()        for path in paths:            splited = path.strip().split()            if splited[0]=='resPath':                self.resPath = splited[1]            elif splited[0]=='gtPath':                self.gtPath = splited[1]            elif splited[0]=='opticalFlowPath':                self.opticalFlowPath = splited[1]            elif splited[0]=='poseFlowPath':                self.poseFlowPath = splited[1]        if type:            self.flowPath = self.poseFlowPath        else:            self.flowPath = self.opticalFlowPath        #read list        with open(list_file) as f:            self.list = f.readlines()            self.num_samples = len(self.list)def __getitem__(self, idx):    '''    load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output    :param idx: (int) image index    :return:        input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image        output: the ground truth    '''    input = []    output = []    # load heatmaps of 3 image    for im in range(3):        for map in range(6):            curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp'            heatmap = Image.open(curResPath)            heatmap.load()            heatmap = np.asarray(heatmap, dtype='float') / 255            input.append(heatmap)    # load 2 flow    for flow in range(2):        curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg'        flowX = Image.open(curFlowXPath)        flowX.load()        flowX = np.asarray(flowX, dtype='float')        curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg'        flowY = Image.open(curFlowYPath)        flowY.load()        flowY = np.asarray(flowY, dtype='float')        input.append(flowX)        input.append(flowY)    # load groundtruth    for map in range(6):        curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp'        heatmap = Image.open(curResPath)        heatmap.load()        heatmap = np.asarray(heatmap, dtype='float') / 255        output.append(heatmap)    input = torch.Tensor(input)    output = torch.Tensor(output)    return input,outputdef __len__(self):    return self.num_samples
阅读全文
0 0
原创粉丝点击