PyTorch代码学习-torchvision.datasets中folder.py

来源:互联网 发布:怎样看待网络流行语 编辑:程序博客网 时间:2024/06/07 01:08

PyTorch代码学习-torchvision.datasets中folder.py

文章说明:本人学习pytorch/vision/torchvision/datasets/folder.py理解(待续)

理解:

ImageFolder :转化为torch可识别的dataset格式,可被dataloader包装文件夹格式:Root/dog/imgclass ImageFolder(data.Dataset): # 继承data.Dataset        def __init__(self):             # 初始化属性和参数            self.name = name # 可在整个类使用            计算classes            计算self.imgs # (图片路径,图片类别)    def  __getitem__(self, index):        返回可索引的数据集格式        返回(图片格式,图片类别)    def __len__(self):        返回数据集的大小    重点函数:        1for root, _, fnames in sorted(os.walk(d)):        # os.walk:遍历目录下所有内容,产生三元组        # (dirpath, dirnames, filenames)【文件夹路径, 文件夹名字, 文件名】        2)注意:图片路径 => 图片格式        3)图片类别的文件名(str)=> 类别名称

代码:

import torch.utils.data as data#PIL: Python Image Library缩写,图像处理模块#     Image,ImageFont,ImageDraw,ImageFilterfrom PIL import Image    import osimport os.path# 图片扩展(图片格式)IMG_EXTENSIONS = [    '.jpg', '.JPG', '.jpeg', '.JPEG',    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',]# 判断是不是图片文件def is_image_file(filename):    # 只要文件以IMG_EXTENSIONS结尾,就是图片    # 注意any的使用    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)# 结果:classes:['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']# classes_to_idx:{'1': 1, '0': 0, '3': 3, '2': 2, '5': 5, '4': 4, '7': 7, '6': 6, '9': 9, '8': 8}def find_classes(dir):    '''     返回dir下的类别名,classes:所有的类别,class_to_idx:将文件中str的类别名转化为int类别     classes为目录下所有文件夹名字的集合    '''    # os.listdir:以列表的形式显示当前目录下的所有文件名和目录名,但不会区分文件和目录。    # os.path.isdir:判定对象是否是目录,是则返回True,否则返回False    # os.path.join:连接目录和文件名    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]    # sort:排序    classes.sort()    # 将文件名中得到的类别转化为数字class_to_idx['3'] = 3    class_to_idx = {classes[i]: i for i in range(len(classes))}    return classes, class_to_idx    # class_to_idx :{'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}# 如果文件是图片文件,则保留它的路径,和索引至images(path,class_to_idx)def make_dataset(dir, class_to_idx):    # 返回(图片的路径,图片的类别)    # 打开文件夹,一个个索引    images = []    # os.path.expanduser(path):把path中包含的"~"和"~user"转换成用户目录    dir = os.path.expanduser(dir)    for target in sorted(os.listdir(dir)):        d = os.path.join(dir, target)        if not os.path.isdir(d):            continue        # os.walk:遍历目录下所有内容,产生三元组        # (dirpath, dirnames, filenames)【文件夹路径, 文件夹名字, 文件名】        for root, _, fnames in sorted(os.walk(d)):            for fname in sorted(fnames):                if is_image_file(fname):                    path = os.path.join(root, fname)   # 图片的路径                    item = (path, class_to_idx[target])  # (图片的路径,图片类别)                    images.append(item)    return images# 打开路径下的图片,并转化为RGB模式def pil_loader(path):    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)    # with as : 安全方面,可替换:try,finally    # 'r':以读方式打开文件,可读取文件信息    # 'b':以二进制模式打开文件,而不是文本    with open(path, 'rb') as f:        with Image.open(f) as img:            # convert:,用于图像不同模式图像之间的转换,这里转换为‘RGB’            return img.convert('RGB')def accimage_loader(path):    # accimge:高性能图像加载和增强程序模拟的程序。    import accimage    try:        return accimage.Image(path)    except IOError:        # Potentially a decoding problem, fall back to PIL.Image        return pil_loader(path)def default_loader(path):    # get_image_backend:获取加载图像的包的名称    from torchvision import get_image_backend    if get_image_backend() == 'accimage':        return accimage_loader(path)    else:        return pil_loader(path)class ImageFolder(data.Dataset):    """A generic data loader where the images are arranged in this way: ::        root/dog/xxx.png        root/dog/xxy.png        root/dog/xxz.png        root/cat/123.png        root/cat/nsdf3.png        root/cat/asd932_.png    Args:        root (string): Root directory path.        transform (callable, optional): A function/transform that  takes in an PIL image            and returns a transformed version. E.g, ``transforms.RandomCrop``        target_transform (callable, optional): A function/transform that takes in the            target and transforms it.        loader (callable, optional): A function to load an image given its path.     Attributes:        classes (list): List of the class names.        class_to_idx (dict): Dict with items (class_name, class_index).        imgs (list): List of (image path, class_index) tuples    """    # 初始化,继承参数    def __init__(self, root, transform=None, target_transform=None,                 loader=default_loader):        # TODO        # 1. Initialize file path or list of file names.        # 找到root的文件和索引        classes, class_to_idx = find_classes(root)        # 保存路径下图片文件路径和索引至imgs        imgs = make_dataset(root, class_to_idx)        if len(imgs) == 0:            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))        self.root = root        self.imgs = imgs        self.classes = classes        self.class_to_idx = class_to_idx        self.transform = transform        self.target_transform = target_transform        self.loader = loader    def __getitem__(self, index):        """        Args:            index (int): Index        Returns:            tuple: (image, target) where target is class_index of the target class.        """        # TODO        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).        # 2. Preprocess the data (e.g. torchvision.Transform).        # 3. Return a data pair (e.g. image and label).        #这里需要注意的是,第一步:read one data,是一个data        path, target = self.imgs[index]         # 这里返回的是图片路径,而需要的是图片格式        img = self.loader(path) # 将图片路径加载成所需图片格式        if self.transform is not None:            img = self.transform(img)        if self.target_transform is not None:            target = self.target_transform(target)        return img, target    def __len__(self):        # return the total size of your dataset.        return len(self.imgs)
原创粉丝点击