pytorch使用(一)处理并加载自己的数据
来源:互联网 发布:三砖淘宝店铺 编辑:程序博客网 时间:2024/06/01 10:16
pytorch使用:目录
pytorch使用(一)数据处理
个人认为,数据处理或许是在完成一篇论文中最耗费时间的,特别是大多情况下,需要在很多个库上做实验。
pytorch官方支持很多库,使用torchvision来完成数据的处理,点这里可以看到支持的库并不是很多。在这里,我将结合一个实例说明如何使用pytorch来处理自己的数据,任务是一个分析双臂运动的,检测6个关节点的运动。输入是连续三帧的检测结果以及计算的光流,也就是$3*6+2*2=22$
张heatmap,输出是中间帧的检测结果,也就是6张heatmap。
把原始数据处理为模型使用的数据需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分别可以理解为数据处理格式的定义、数据处理和数据加载。
1. 数据预处理torchvision.transforms
pytorch使用torchvision.transforms实现数据的预处理,包括中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等,建议整体浏览一下这一部分的官方手册,非常有用,数据处理很方便。
先转换为张量,然后正则化:
import torchvision.transforms as transformstransform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])#img = transform(img)
2. 数据读取,构建Dataset子类
参考:http://blog.csdn.net/victoriaw/article/details/72356453
如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。我们的将数据列表放在train.txt
和test.txt
中,将不同类型的数据的路径放在path.txt
中,所以在类的init函数中有path_file和 list_file两个变量
在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是len和getitem:
- len返回数据集的大小
- getitem实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。
末尾有自己写的一个Dataset子类的定义文件。
3. 数据加载
torch.utils.data.DataLoader()
函数,合成数据并且提供迭代访问。主要由两部分组成:
- dataset(Dataset)。输入加载的数据,就是上面的MyDataset
的实现。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:
- batch-size。样本每个batch的大小,默认为1。- shuffle。是否打乱数据,默认为False。- num_workers。数据分为几个线程处理默认为0。- sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False
使用:
import torchfrom datagen import MyDatasettrainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False)trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False)testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
以下是定义class MyDataset
文件datagen.py
, 其中有__init__(self, path_file, list_file,numJoints,type)
、__getitem__(self, idx)
、__len__(self)
三个函数,__getitem__
返回一个(22,256,256)的输入和一个(6,256,256)的标签。
'''Load data'''import numpy as npfrom PIL import Image#import cv2import torchimport torch.utils.data as dataimport torchvision.transforms as transformsclass MyDataset(data.Dataset): def __init__(self, path_file, list_file,numJoints,type): ''' Args: path_file: (str) heatmap and optical file location list_file: (str) path to index file. numJoints: (int) number of joints type: (boolean) use pose flow(true) or optical flow(false) ''' self.numJoints = numJoints # read heatmap and optical path with open(path_file) as f: paths = f.readlines() for path in paths: splited = path.strip().split() if splited[0]=='resPath': self.resPath = splited[1] elif splited[0]=='gtPath': self.gtPath = splited[1] elif splited[0]=='opticalFlowPath': self.opticalFlowPath = splited[1] elif splited[0]=='poseFlowPath': self.poseFlowPath = splited[1] if type: self.flowPath = self.poseFlowPath else: self.flowPath = self.opticalFlowPath #read list with open(list_file) as f: self.list = f.readlines() self.num_samples = len(self.list)def __getitem__(self, idx): ''' load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output :param idx: (int) image index :return: input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image output: the ground truth ''' input = [] output = [] # load heatmaps of 3 image for im in range(3): for map in range(6): curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp' heatmap = Image.open(curResPath) heatmap.load() heatmap = np.asarray(heatmap, dtype='float') / 255 input.append(heatmap) # load 2 flow for flow in range(2): curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg' flowX = Image.open(curFlowXPath) flowX.load() flowX = np.asarray(flowX, dtype='float') curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg' flowY = Image.open(curFlowYPath) flowY.load() flowY = np.asarray(flowY, dtype='float') input.append(flowX) input.append(flowY) # load groundtruth for map in range(6): curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp' heatmap = Image.open(curResPath) heatmap.load() heatmap = np.asarray(heatmap, dtype='float') / 255 output.append(heatmap) input = torch.Tensor(input) output = torch.Tensor(output) return input,outputdef __len__(self): return self.num_samples
- pytorch使用(一)处理并加载自己的数据
- 使用pytorch准备自己的数据
- pytorch学习1:如何加载自己的训练数据
- PyTorch学习系列(一)——加载数据并生成batch数据
- Pytorch学习笔记(二)自己加载单通道图片用作数据集训练
- Pytorch学习笔记(一):pytorch的安装-Ubuntu14.04
- pytorch 模型的加载
- pytorch: 准备、训练和测试自己的图片数据
- 文章标题 faster rcnn-pytorch版训练自己的数据
- Pytorch入门学习(七)---- 数据加载和预处理的通用方法
- tensorflow处理自己的图像数据(不使用队列)
- PyTorch读取Cifar数据集并显示图片(转载)
- PyTorch(四)——视频数据的处理
- PyTorch关于RNN序列数据的pack_pad处理
- pytorch自定义Dataset并使用torchvision的Transform
- PyTorch从零开始(一):
- Pytorch小记(一)
- PyTorch: Data Loading and Processing Tutorial 数据加载和处理教程
- 用JNDI实现数据库连接池
- 如何理解人工智能、机器学习和深度学习
- ArrayList和LinkedList的区别
- 数据结构实验之数组一:矩阵转置
- 《疯狂的程序员》经典语录
- pytorch使用(一)处理并加载自己的数据
- Lintcode176 Route Between Two Nodes in Graph solution 题解
- 给大家分享一下避免MySQL替换逻辑SQL的坑爹操作(链接)
- 【哈尔滨理工大学第七届程序设计竞赛初赛(高年级组)】 A B C D F G H I
- 初识软件工程
- 使用spring实现读写分离(mysql主从复制)五:一主多从的实现
- LeetCode 100.Same Tree
- 证书的格式以及证书的知识点
- 单机版hadoop搭建