PyTorch Learning: Data Visualization

Source: Internet. Editor: 程序博客网. Time: 2024/06/07 18:49

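Visualizing dataset images with plt.imshow() and PIL.Image.show()
1. plt.imshow()
Format: an RGB image as a NumPy array of shape rows*cols*channels, with float values between 0 and 1 (integer arrays with values 0-255 are also accepted)
Example: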

# -*- coding:utf-8 -*-
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, models, transforms

plt.ion()   # interactive mode

# Data augmentation and normalization for training
# Just normalization for validation
# Transfer learning: convert inputs to the format the pretrained model expects
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),   # replaces the deprecated RandomSizedCrop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])]),
    'val': transforms.Compose([
        transforms.Resize(256),              # replaces the deprecated Scale
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])]),
}

data_dir = 'hymenoptera_data'
# Note how ImageFolder is used for the train and val splits
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# With batch_size = 4, four images are drawn
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=1,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
use_gpu = torch.cuda.is_available()

# Visualize a few images
# plt.imshow takes an RGB image as rows*cols*channels with values in 0-1 (NumPy array)
# PIL.Image.show() displays an RGB image with values in 0-255
def imshow(inp, title=None):
    """Imshow for Tensor."""
    # Reorder to the RGB image layout rows*cols*channels,
    # e.g. (3, 228, 906) -> (228, 906, 3)
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization applied by transforms.Normalize
    inp = std * inp + mean
    # Clip: values below 0 become 0, values above 1 become 1
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
# inputs: FloatTensor of size batch_size x 3 x 224 x 224 (4x3x224x224 when batch_size = 4)
inputs, classes = next(iter(dataloaders['train']))
# Make a grid from batch
# out: FloatTensor of size 3x228x906 when batch_size = 4 (3x228x1132 when batch_size = 5)
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
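
The conversion done inside imshow() above can also be tried without the dataset. Below is a minimal sketch, assuming only NumPy is available: the helper name to_displayable and the random input are made up for illustration, and the mean/std values are the ones passed to transforms.Normalize in the example.

```python
import numpy as np

# Per-channel statistics, matching the transforms.Normalize call in the example.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def to_displayable(chw):
    """Turn a normalized channels-first array into rows*cols*channels in [0, 1].

    Hypothetical helper for illustration; it mirrors the steps in imshow().
    """
    hwc = np.transpose(chw, (1, 2, 0))   # (C, H, W) -> (H, W, C)
    hwc = STD * hwc + MEAN               # undo Normalize: x * std + mean
    return np.clip(hwc, 0.0, 1.0)        # keep floats inside [0, 1]


# Synthetic stand-in for one normalized image (not real dataset content).
rng = np.random.default_rng(0)
fake = rng.normal(size=(3, 224, 224))
img = to_displayable(fake)
print(img.shape, img.dtype)
```

The resulting array can be passed straight to plt.imshow(img). For a torch tensor, calling .numpy() first (as imshow() does) gives the array this helper expects.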

2. PIL.Image.show()
Format: RGB mode, rows*cols*channels, values 0-255, as a PIL image object
Example:

from PIL import Image
import numpy as np

img = Image.open("img.jpg")
img.show()
# img = np.array(img)
# print(img)  # prints a NumPy array with values in 0-255
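
To move between the two formats described in this post, one possible sketch is shown below. It assumes NumPy and Pillow are installed and builds a small synthetic image, so no img.jpg is needed. The variable names are illustrative only.

```python
import numpy as np
from PIL import Image

# Synthetic 4x6 pure-red RGB image (rows=4, cols=6, channels=3), values 0-255.
pixels = np.zeros((4, 6, 3), dtype=np.uint8)
pixels[..., 0] = 255
pil_img = Image.fromarray(pixels)        # uint8 H x W x 3 is read as RGB mode

# PIL image -> NumPy array: uint8, rows*cols*channels, values 0-255.
as_uint8 = np.asarray(pil_img)

# Scale to floats in [0, 1], the float range plt.imshow() expects.
as_float = as_uint8.astype(np.float32) / 255.0

# And back: floats in [0, 1] -> uint8 0-255 -> PIL image.
back = Image.fromarray((as_float * 255).round().astype(np.uint8))

print(pil_img.mode, pil_img.size, as_uint8.shape, as_float.dtype)
```

Note that pil_img.size is reported as (width, height), while the NumPy shape is (rows, cols, channels).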