Source code for neodroidvision.data.classification.deprec.s_vgg_face2

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 25/03/2020
           """

import csv
from pathlib import Path
from typing import Dict, Tuple

import torch
from PIL import Image
from draugr.numpy_utilities import SplitEnum
from draugr.visualisation import progress_bar
from matplotlib import pyplot
from torch.utils import data
from torchvision import transforms

__all__ = ["sVggFace2"]

from draugr.torch_utilities import SupervisedDataset

import pandas

N_IDENTITY = 9131  # total number of identities in VGG Face2
N_IDENTITY_PRETRAIN = 8631  # number of identities used to train the released (Caffe) pre-trained models
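# The remaining N_IDENTITY - N_IDENTITY_PRETRAIN = 500 identities (Flag == 0 in
# identity_meta.csv, see get_id_label_map below) are mapped to the label range
# [N_IDENTITY_PRETRAIN, N_IDENTITY).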


class sVggFace2(SupervisedDataset):
    """VGGFace2 identity classification dataset."""

    @property
    def response_shape(self) -> Tuple[int, ...]:
        """

        :return: shape of the target (label) component
        :rtype: Tuple[int, ...]"""
        return (0,)

    @property
    def predictor_shape(self) -> Tuple[int, ...]:
        """

        :return: shape of the input image tensor (channels, height, width)
        :rtype: Tuple[int, ...]"""
        return self._resize_shape

    # mean = numpy.array([0.485, 0.456, 0.406])
    # std = numpy.array([0.229, 0.224, 0.225])

    inverse_transform = transforms.Compose(
        [
            # transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
            transforms.ToPILImage()
        ]
    )
    @staticmethod
    def get_id_label_map(meta_file):
        """

        :param meta_file: path to ``identity_meta.csv``
        :type meta_file: Path
        :return: mapping from ``Class_ID`` to integer class label
        :rtype: dict"""
        identity_list = meta_file
        df = pandas.read_csv(
            identity_list,
            sep=r",\s+",  # regex separator, requires the python parsing engine
            quoting=csv.QUOTE_ALL,
            encoding="utf-8",
            engine="python",
        )
        df["class"] = -1
        df.loc[df["Flag"] == 1, "class"] = range(N_IDENTITY_PRETRAIN)
        df.loc[df["Flag"] == 0, "class"] = range(N_IDENTITY_PRETRAIN, N_IDENTITY)
        key = df["Class_ID"].values
        val = df["class"].values
        id_label_dict = dict(zip(key, val))
        return id_label_dict
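    # Expected ``identity_meta.csv`` shape: only the ``Class_ID`` and ``Flag`` columns are
    # actually read by the code above; the remaining columns and the example row are
    # illustrative assumptions about the upstream metadata file:
    #
    #   Class_ID, Name, Sample_Num, Flag, Gender
    #   n000002, "Some Name", 424, 1, f
    #
    # Rows with Flag == 1 receive labels 0 .. N_IDENTITY_PRETRAIN - 1, rows with Flag == 0
    # receive labels N_IDENTITY_PRETRAIN .. N_IDENTITY - 1, so the returned dict looks like
    #   {"n000002": 0, ...}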
    @property
    def split_names(self) -> Dict[SplitEnum, str]:
        """

        :return: mapping from split enum to the split's directory/file-name prefix
        :rtype: Dict[SplitEnum, str]"""
        return {
            SplitEnum.training: "train",
            SplitEnum.validation: "validation",
            SplitEnum.testing: "test",
        }
    def __init__(
        self,
        dataset_path: Path,
        split: SplitEnum = SplitEnum.training,
        resize_s: int = 256,
        raw_images: bool = False,
    ):
        """

        :param dataset_path: dataset root directory
        :param split: which split to load (training, validation or testing)
        :param resize_s: target image size
        :type resize_s: int or tuple(w,h)
        :param raw_images: if True, return untransformed PIL images"""
        super().__init__()
        assert dataset_path.exists(), f"root: {dataset_path} not found."
        assert resize_s > 2, "resize_s should be > 2"

        self._split = split
        split_name = self.split_names[split]
        self._resize_shape = (3, resize_s, resize_s)
        self._dataset_path = dataset_path / "data" / split_name
        image_list_file_path = dataset_path / "meta" / f"{split_name}_list.txt"
        assert (
            image_list_file_path.exists()
        ), f"image_list_file: {image_list_file_path} not found."
        self._image_list_file_path = image_list_file_path
        meta_id_path = dataset_path / "meta" / "identity_meta.csv"
        assert meta_id_path.exists(), f"meta id path {meta_id_path} does not exist"
        self._id_label_dict = self.get_id_label_map(meta_id_path)
        self._raw_images = raw_images

        self.train_trans = transforms.Compose(
            [
                transforms.RandomResizedCrop(resize_s),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(180),
                transforms.ToTensor(),
                # transforms.Normalize(self.mean, self.std)
            ]
        )
        self.val_trans = transforms.Compose(
            [
                transforms.Resize(resize_s),
                # transforms.CenterCrop(resize_s),
                transforms.ToTensor(),
                # transforms.Normalize(self.mean, self.std)
            ]
        )

        self._img_info = []
        with open(str(self._image_list_file_path), "r") as f:
            for i, img_file in enumerate(f):
                img_file = img_file.strip()  # e.g. n004332/0317_01.jpg
                class_id = img_file.split("/")[0]  # e.g. n004332
                label = self._id_label_dict[class_id]
                self._img_info.append(
                    {"class_id": class_id, "img": img_file, "label": label}
                )
                if i % 1000 == 0:
                    print(f"Processing: {i} images for {split_name} split")
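    # Expected on-disk layout, inferred from the path handling in ``__init__``:
    #
    #   <dataset_path>/data/<split>/n004332/0317_01.jpg    images grouped per identity
    #   <dataset_path>/meta/<split>_list.txt               one relative image path per line
    #   <dataset_path>/meta/identity_meta.csv              identity metadata with a Flag column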
    def __len__(self):
        return len(self._img_info)

    def __getitem__(self, index):
        info = self._img_info[index]
        img_file = info["img"]
        img = Image.open(str(self._dataset_path / img_file))
        if not self._raw_images:
            if self._split == SplitEnum.training:
                img = self.train_trans(img)
            else:
                img = self.val_trans(img)
        label = info["label"]
        class_id = info["class_id"]
        return img, label, img_file, class_id
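# Each dataset item is a 4-tuple:
#   (transformed image tensor (or raw PIL image), integer identity label,
#    relative image path, class id string)
# Minimal access sketch (the dataset path below is a placeholder):
#   img, label, img_file, class_id = sVggFace2(Path("/path/to/vggface2"))[0]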
if __name__ == "__main__": def main(p=Path(Path.home() / "Data" / "Datatsets" / "vggface2")): """ :param p: :type p: """ batch_size = 32 dt = sVggFace2( p, split=SplitEnum.testing, # raw_images=True ) test_loader = torch.utils.data.DataLoader( dt, batch_size=batch_size, shuffle=False ) # test_loader = dt for batch_idx, (imgs, label, img_files, class_ids) in progress_bar( enumerate(test_loader), total=len(test_loader), description="Bro", ncols=80, ): pyplot.imshow(dt.inverse_transform(imgs[0])) # pyplot.imshow(imgs) pyplot.show() break main()