使用pytorch准备自己的数据
来源:互联网 发布:电子相册的软件 编辑:程序博客网 时间:2024/06/08 03:06
前言
对于著名的数据集比如mnist,像Tensorflow、pytorch这样的流行框架已把它们集成到相关模块中,使用时一至几行简单的代码就能调用。但往往我们需要在自己的数据集上完成一些操作,这篇博客就旨在以单标签图像分类为例,浅谈一下如何使用pytorch准备自己的数据,如有错误,敬请斧正。
我所做的是一个室外图像的天气分类任务,类别只有sunny和cloudy两类。在这个例子中我们不需要提供额外的txt或其他形式的文件来将图片和标签对应起来,但需要将数据集按以下结构组织起来。
训练集和验证集(当然还可以有测试集)需要分开,每个split下面各个类别的图片也要分开,并且文件夹的名字最好就是类别名称。(关于使用pytorch进行回归或多标签分类任务本人还未研究过,这里就暂时不作介绍了)
下面就开始讲代码了,首先把全部代码贴出来,然后再细致解释一下。
import torchimport torchvisionfrom torchvision import datasets, transformsimport matplotlib.pyplot as plt import numpy as npimport os# Data augmentation and normalization for training # Just normalization for validationdata_transforms = { 'train': transforms.Compose([ transforms.RandomSizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]),}data_dir = '/mount/temp/WZG/pytorch/Data/'train_sets = datasets.ImageFolder(os.path.join(data_dir, 'train'), data_transforms['train'])train_loader = torch.utils.data.DataLoader(train_sets, batch_size=10, shuffle=True, num_workers=4)train_size = len(train_sets)train_classes = train_sets.classesval_sets = datasets.ImageFolder(os.path.join(data_dir, 'val'), data_transforms['val'])val_loader = torch.utils.data.DataLoader(val_sets, batch_size=10, shuffle=False, num_workers=4)val_size = len(val_sets)# Visualize a few imagesdef imshow(inp, title=None): """Imshow for Tensor.""" inp = inp.numpy().transpose((1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) inp = std * inp + mean plt.imshow(inp) if title is not None: plt.title(title)inputs, classes = next(iter(train_loader))# Make a grid from batchout = torchvision.utils.make_grid(inputs, nrow=5)imshow(out, title=[train_classes[x] for x in classes])
使用pytorch准备自己的数据时,主要用到的就是torch.utils.data模块以及torchvision的datasets和transforms模块,所以代码开始部分,首先把相关模块导入。
在代码的正文部分,首先定义对图像的一些变换以达到数据增强的目的,函数基本都是见名知意的,如果读者了解深度学习,那么对这些变换一定不会陌生。除了数据增强,还有一步操作是必要的,那就是把图像转换成pytorch需要的tensor格式。pytorch里面的tensor和Numpy的ndarray很像(但绝不等价),pytorch的官网在做介绍时,很多时候会和Numpy进行联系和对比,而tensor和ndarray也可以通过调用相关函数进行相互转换。数据转换成tensor后,数值范围会被自动压缩到0~1之间。这份代码中之后还使用transforms.Normalize()函数对数据进行了归一化,该函数包含两个list类型的参数,第一个参数为RGB三个通道各自的均值,第二个参数为相应的方差。另外,代码中的transforms.Compose()函数的作用是把所有这些变换组合到一起。这个例子中,我们把对训练集和验证集的变换操作写到了一个字典里,不过也完全可以将它们分开来写。
接下来,我们使用datasets.ImageFolder()函数来创建dataset对象,该函数的第一个参数是一个路径(比如这个例子中的训练集的路径),第二个参数是对这个路径下的图片要进行的变换操作(我们刚刚定义的那些变换)。再然后就是使用torch.utils.data.DataLoader()函数来定义数据的加载方式了,例子中对该函数给了4个参数,第一个是刚刚创建的dataset对象,第二个是batch的大小(即一个batch包含的样本数量),第三个参数是一个布尔值,代表是否进行shuffle,训练的话一般都会设为True。第四个参数num_workers表示开启多少个子进程进行数据的读取(并行读取),默认是0,即只使用主进程读数据。其余的更多参数请查阅官网doc。
train_size = len(train_sets)train_classes = train_sets.classes
接下来的这两句是为了得到这个数据集的大小和所有的类别名称。得到的结果如下图所示:
可以看到我们创建的dataset对象的classes属性就对应着相关类别文件夹的名称。
为了测试数据是否能正确加载,我们定义一个imshow()函数来展示数据。imshow()函数中,首先对数据的维度进行transpose操作,这是因为tensor中,图片的shape是先通道再宽/高,如下:
而显示图像需要先宽/高再通道。到这里读者应该有个疑问,那就是为什么batch的维度是4,而transpose函数里的shape却是3维的。这其实跟我们的显示形式有关,我后面会马上讲到。transpose之后就是归一化的逆操作了,最后使用plt.imshow()函数进行显示即可。
实际加载数据时,使用inputs, classes = next(iter(train_loader))一行代码就可以得到一个batch的数据,该函数会进行非重复采样,直至数据集被完整遍历一次。得到的inputs即一个batch的图像数据,而这里的classes是inputs各个样本对应的整数标签,自动从0开始,并且和类别名称的索引也是对应的。在这个例子中,0就对应cloudy,1对应sunny。再次把之前的结果贴一下,就更好理解了。
为了展示方便,我们使用下面这行代码对数据进行了一下处理:
out = torchvision.utils.make_grid(inputs, nrow=5)
这里的make_grid()函数就是把一个batch的数据重新排列成格的形式,例子中它的第一个参数即我们刚刚加载的batch,第二个参数代表一行放几个样本。处理之后,数据就变成了3维的,如下图所示:
这就是为什么自定义的imshow()函数中的参数也是3维的。如果细心一点,会发现原数据和处理之后的数据有对不上的地方。原数据的shape为(10,3,224,224),batch size为10,处理时,我设置的是一行显示5张图片,也就是总共2行5列。那么处理后数据应该是(3,448,1120)才对,可从结果来看,处理后数据的高和宽都变大了。这主要是make_grid()函数本身搞的鬼,可能是为了显示时把各幅图片区分开,该函数会在图片之间以及整个grid的边缘自动加上线宽为2个像素的黑线,这是我查看了数据具体数值后发现的,图片之间的部分会有宽度为2,值全为0的间隔,那么从纵向来看,2*3(2行图片,中间有一条间隔,加grid上下两条边)=6=454-448。横向按此方法计算也对得上。
最后显示出的结果如下图所示:
从结果来看,我们加载的数据应该是正确的,并且从title可以看出,数据也确实经过了shuffle。
- 使用pytorch准备自己的数据
- pytorch: 准备、训练和测试自己的图片数据
- pytorch使用(一)处理并加载自己的数据
- pytorch学习1:如何加载自己的训练数据
- 文章标题 faster rcnn-pytorch版训练自己的数据
- caffe准备自己的数据集
- pytorch如何自定义自己的MyDatasets
- PyTorch使用指定的GPU
- PyTorch(三)——使用训练好的模型测试自己图片
- Docker: 使用jupyter notebook基础镜像搭建自己的 pytorch 开发环境
- pytorch 使用
- 使用pytorch进行图像的顺序读取。
- pytorch pruning训练自己的数据库(流程+BUG调试)
- TensorFlow——训练自己的数据——CIFAR10(一)数据准备
- caffe训练自己的数据集——1. 数据准备
- 准备开始自己的博客
- 准备记录自己的生活
- 使用Listener准备application作用域数据的小问题
- TensorFlow Docker一览
- 用CSS进行网页布局
- USB之SE0、SE1
- 【Redis学习】:redis持久化
- c印记(十四):跨平台线程封装
- 使用pytorch准备自己的数据
- 简单实现一个自定义view的ProgressBar
- 关于Maven构建项目或者update项目时jdk变为1.5解决方案,亲测有效。
- Invoke BeginInvoke EndInvoked的使用 简单的线程同步
- 成功人的五个步骤
- 找到焦点onfocus和失去焦点onblur、以及onchange
- Python生成PASCAL VOC格式的xml标注文件
- 找到一本不错的Linux电子书,附《Linux就该这么学》章节目录。
- 找到一本不错的Linux电子书,附《Linux就该这么学》章节目录