Source code for neodroidvision.data.classification.dict_image_folder
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 28/07/2020
"""
# Public API of this module: the two image-folder dataset classes defined below.
__all__ = ["SplitDictImageFolder", "DictImageFolder"]
from pathlib import Path
from typing import Callable, Optional

from draugr.numpy_utilities import SplitEnum
from draugr.torch_utilities import (
    DictDatasetFolder,
    SplitDictDatasetFolder,
)
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader
from torchvision.transforms import transforms
class SplitDictImageFolder(SplitDictDatasetFolder):
    """A generic split-aware image data loader for folders arranged like: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    How ``split`` maps onto the directory layout is decided by draugr's
    ``SplitDictDatasetFolder`` base class (not visible here — confirm there).

    Args:
        root (Path): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version, e.g.
            ``transforms.RandomCrop``. Defaults to ``transforms.ToTensor()``.
            Pass ``None`` to disable.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        loader (callable, optional): A function to load an image given its
            path. Defaults to torchvision's ``default_loader``.
        split (SplitEnum, optional): Which dataset split to load. Defaults to
            ``SplitEnum.training``.

    torchvision's ``IMG_EXTENSIONS`` is forwarded to the base class as the
    accepted file extensions.

    Attributes:
        imgs: Alias of the base class's ``_data_categories`` collection
            (its exact structure is defined by draugr — confirm there).
    """

    def __init__(
        self,
        root: Path,
        # NOTE: this default is evaluated once at class-definition time, so all
        # instances share one ToTensor object; it is presumed stateless, making
        # sharing safe. Kept as-is so an explicit ``None`` still means "no transform".
        transform: Optional[Callable] = transforms.ToTensor(),
        target_transform: Optional[Callable] = None,
        loader: Callable = default_loader,
        split: SplitEnum = SplitEnum.training,
    ) -> None:
        super().__init__(
            root,
            loader,
            extensions=IMG_EXTENSIONS,
            transform=transform,
            target_transform=target_transform,
            split=split,
        )
        # torchvision-style alias so callers used to ``ImageFolder.imgs`` keep working.
        self.imgs = self._data_categories
class DictImageFolder(DictDatasetFolder):
    """A generic image data loader for folders arranged like: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (Path): Root directory path.
        transform (callable, optional): A function/transform that takes in a PIL
            image and returns a transformed version, e.g.
            ``transforms.RandomCrop``. Defaults to ``transforms.ToTensor()``.
            Pass ``None`` to disable.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        loader (callable, optional): A function to load an image given its
            path. Defaults to torchvision's ``default_loader``.
        is_valid_file (callable, optional): A function that takes the path of
            an image file and checks whether it is a valid file (e.g. to skip
            corrupt files). When given, it replaces the ``IMG_EXTENSIONS``
            extension filter.

    Attributes:
        imgs: Alias of the base class's ``_data`` collection (its exact
            structure is defined by draugr — confirm there).
    """

    def __init__(
        self,
        root: Path,
        # NOTE: this default is evaluated once at class-definition time, so all
        # instances share one ToTensor object; it is presumed stateless, making
        # sharing safe. Kept as-is so an explicit ``None`` still means "no transform".
        transform: Optional[Callable] = transforms.ToTensor(),
        target_transform: Optional[Callable] = None,
        loader: Callable = default_loader,
        is_valid_file: Optional[Callable] = None,
    ) -> None:
        # Only one filter is forwarded: the extension list when no custom
        # validator is supplied, otherwise None. Presumably the base class (like
        # torchvision's DatasetFolder) rejects both being set — confirm in draugr.
        extensions = IMG_EXTENSIONS if is_valid_file is None else None
        super().__init__(
            root,
            loader,
            extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # torchvision-style alias so callers used to ``ImageFolder.imgs`` keep working.
        self.imgs = self._data