Source code for neodroidvision.data.classification.imagenet.imagenet_2012

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 27/06/2020
           """

from pathlib import Path
from typing import Dict, Tuple

import numpy
import torch
import torchvision
from draugr.numpy_utilities import SplitEnum
from draugr.torch_utilities import SupervisedDataset
from draugr.visualisation import progress_bar
from matplotlib import pyplot
from torch.utils import data
from torchvision import transforms

from neodroidvision.data.classification.imagenet.imagenet_2012_id import categories_id
from neodroidvision.data.classification.imagenet.imagenet_2012_names import (
    categories_names,
)

__all__ = ["ImageNet2012"]


class ImageNet2012(SupervisedDataset):
    """ImageNet 2012 classification dataset read through a torchvision ``ImageFolder``.

    The images of the requested split are expected in a sub-directory of
    ``dataset_path`` named after the split (see :attr:`split_names`).
    Samples are returned as ``(image, category_index)`` pairs, where the image
    has been converted to a tensor and normalised with :attr:`mean` and
    :attr:`std`."""

    # Per-channel statistics used by ``base_transform`` for normalisation.
    mean = numpy.array([0.485, 0.456, 0.406])
    std = numpy.array([0.229, 0.224, 0.225])

    # Lookup tables imported from the sibling ``imagenet_2012_names`` /
    # ``imagenet_2012_id`` modules.
    category_names = categories_names
    category_id = categories_id

    # Undoes ``base_transform``: normalising with (-mean/std, 1/std) maps
    # (x - mean) / std back to x, after which the tensor becomes a PIL image.
    inverse_base_transform = transforms.Compose(
        [
            transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
            transforms.ToPILImage(),
        ]
    )

    # Shared tail of both the training and the evaluation pipelines.
    base_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    @property
    def response_shape(self) -> Tuple[int, ...]:
        """
        :return: the shape of the target space, ``(1000,)``
        :rtype: Tuple[int, ...]"""
        return (1000,)

    @property
    def predictor_shape(self) -> Tuple[int, ...]:
        """
        :return: ``(crop_size, crop_size)`` as given to the constructor
        :rtype: Tuple[int, ...]"""
        # NOTE(review): no channel dimension is reported here - confirm that
        # consumers expect only the spatial size.
        return self._crop_size, self._crop_size

    @property
    def split_names(self) -> Dict[SplitEnum, str]:
        """
        :return: mapping from split to the name of its sub-directory
        :rtype: Dict[SplitEnum, str]"""
        return {
            SplitEnum.training: "train",
            SplitEnum.validation: "val",
            SplitEnum.testing: "test",
        }

    def __init__(
        self,
        dataset_path: Path,
        split: SplitEnum = SplitEnum.training,
        resize_s: int = 256,
        crop_size: int = 224,
    ):
        """
        :param dataset_path: dataset directory (a ``str`` is converted to ``Path``)
        :type dataset_path: Path
        :param split: train, valid, test
        :type split: SplitEnum
        :param resize_s: size handed to ``transforms.Resize`` in the evaluation
          pipeline, must be > 2
        :type resize_s: int
        :param crop_size: side length of the crop produced by both pipelines,
          must be > 2
        :type crop_size: int"""
        super().__init__()

        if isinstance(dataset_path, str):
            dataset_path = Path(dataset_path)

        assert dataset_path.exists(), f"root: {dataset_path} not found."
        assert resize_s > 2, "resize_s should be >2"
        assert crop_size > 2, "crop_size should be >2"

        self._crop_size = crop_size
        self._split = split
        self._dataset_path = dataset_path / self.split_names[split]

        # Augmenting pipeline: random crop and flip before the base transform.
        self.train_trans = transforms.Compose(
            [
                transforms.RandomResizedCrop(crop_size),
                transforms.RandomHorizontalFlip(),
                self.base_transform,
            ]
        )
        # Deterministic pipeline: resize then centre crop.
        self.val_trans = transforms.Compose(
            [
                transforms.Resize(resize_s),
                transforms.CenterCrop(crop_size),
                self.base_transform,
            ]
        )

        # Previously ``val_trans`` was used for every split, which left
        # ``train_trans`` unused; only the training split is augmented.
        if split == SplitEnum.training:
            split_transform = self.train_trans
        else:
            split_transform = self.val_trans

        self._image_folder = torchvision.datasets.ImageFolder(
            str(self._dataset_path), split_transform
        )

    def __len__(self) -> int:
        """
        :return: number of samples in the underlying image folder
        :rtype: int"""
        return len(self._image_folder)

    def __getitem__(self, index) -> Tuple[torch.Tensor, int]:
        """
        return image and category

        :param index: position of the sample in the underlying image folder
        :type index: int
        :return: the transformed image and its category index
        :rtype: Tuple[torch.Tensor, int]"""
        return self._image_folder[index]
if __name__ == "__main__":

    def main():
        """Display the first image of one shuffled validation batch, titled with its category name."""
        dataset_root = Path.home() / "Data" / "Datasets" / "ILSVRC2012"
        dataset = ImageNet2012(dataset_root, split=SplitEnum.validation)

        loader = data.DataLoader(
            dataset,
            batch_size=32,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
        )

        batches = progress_bar(
            enumerate(loader),
            total=len(loader),
            description="Bro",
            ncols=80,
        )
        for _, (images, labels) in batches:
            # Undo the normalisation so the image is viewable.
            first_image = dataset.inverse_base_transform(images[0])
            pyplot.imshow(first_image)

            first_label = labels[0].item()
            pyplot.title(dataset.category_names[first_label])

            pyplot.show()
            break  # a single batch suffices for the visual check

    main()