#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 27/06/2020
"""
import csv
from pathlib import Path
from typing import Dict, Tuple
import torch
from PIL import Image
from draugr.numpy_utilities import SplitEnum
from matplotlib import pyplot
from torch.utils import data
from torchvision import transforms
__all__ = ["VggFace2"]
from draugr.torch_utilities import SupervisedDataset
class VggFace2(SupervisedDataset):
    """
    Visual Geometry Group Face 2 (VGGFace2) dataset.

    Department of Engineering Science, University of Oxford
    """

    @property
    def response_shape(self) -> Tuple[int, ...]:
        """
        :return: shape of the target (label) for a single sample
        :rtype: Tuple[int, ...]
        """
        return (0,)

    @property
    def predictor_shape(self) -> Tuple[int, ...]:
        """
        :return: shape of a single input image as (width, height, channels)
        :rtype: Tuple[int, ...]
        """
        return self._resize_shape
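
    # The commented-out statistics below are the standard ImageNet mean/std; they are
    # only needed if the Normalize transforms (here and in __init__) are re-enabled,
    # in which case the inverse Normalize below should be re-enabled as well.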
    # mean = numpy.array([0.485, 0.456, 0.406])
    # std = numpy.array([0.229, 0.224, 0.225])
    inverse_transform = transforms.Compose(
        [
            # transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()),
            transforms.ToPILImage()
        ]
    )

    @staticmethod
    def get_id_label_map(meta_file: Path) -> Dict[str, int]:
        """
        Build a mapping from VGGFace2 class ids (e.g. "n004332") to integer labels.

        :param meta_file: path to identity_meta.csv
        :type meta_file: Path
        :return: dict mapping Class_ID to a contiguous integer label
        :rtype: Dict[str, int]
        """
        import pandas

        N_IDENTITY = 9131  # total number of identities in VGG Face2
        N_IDENTITY_PRETRAIN = 8631  # number of identities in the training split (used to train the released Caffe model)
        identity_list = meta_file
        # identity_meta.csv is expected to expose at least the "Class_ID" and "Flag"
        # columns (Flag == 1 marks training identities); the fields are separated by a
        # comma plus whitespace, hence the regex separator and the python engine.
        df = pandas.read_csv(
            identity_list,
            sep=r",\s+",
            quoting=csv.QUOTE_ALL,
            encoding="utf-8",
            engine="python",
        )
        df["class"] = -1
        # assign contiguous labels: training identities first, then the remaining ones
        df.loc[df["Flag"] == 1, "class"] = range(N_IDENTITY_PRETRAIN)
        df.loc[df["Flag"] == 0, "class"] = range(N_IDENTITY_PRETRAIN, N_IDENTITY)
        key = df["Class_ID"].values
        val = df["class"].values
        id_label_dict = dict(zip(key, val))
        return id_label_dict

    @property
    def split_names(self) -> Dict[SplitEnum, str]:
        """
        :return: mapping from split enum to the split's directory / file-name prefix
        :rtype: Dict[SplitEnum, str]
        """
        return {
            SplitEnum.training: "train",
            SplitEnum.validation: "validation",
            SplitEnum.testing: "test",
        }
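
    # Expected on-disk layout (an assumption inferred from the path handling in
    # __init__ below):
    #   <dataset_path>/<split>/            images, e.g. train/n004332/0317_01.jpg
    #   <dataset_path>/<split>_list.txt    one relative image path per line
    #   <dataset_path>/identity_meta.csv   or <dataset_path>/../meta/identity_meta.csv
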
    def __init__(
        self,
        dataset_path: Path,
        split: SplitEnum = SplitEnum.training,
        *,
        resize_s: int = 256,
        raw_images: bool = False,
        verbose: bool = False,
    ):
        """
        :param dataset_path: dataset root directory
        :param split: which split to load (training, validation or testing)
        :param resize_s: target image size, an int (square) or a (width, height) tuple
        :type resize_s: int or tuple(w, h)
        :param raw_images: if True, return untransformed PIL images from __getitem__
        :param verbose: if True, print progress while indexing the split
        """
        super().__init__()
        assert dataset_path.exists(), f"root: {dataset_path} not found."
        split_name = self.split_names[split]
        if isinstance(resize_s, int):
            assert resize_s > 2, "resize_s should be >2"
            resize_s = (resize_s, resize_s)
        self._resize_shape = (*resize_s, 3)  # (width, height, channels)
        self._dataset_path = dataset_path / split_name
        image_list_file_path = dataset_path / f"{split_name}_list.txt"
        assert (
            image_list_file_path.exists()
        ), f"image_list_file: {image_list_file_path} not found."
        self._image_list_file_path = image_list_file_path
        meta_id_path = dataset_path / "identity_meta.csv"
        if not meta_id_path.exists():  # fall back to the sibling "meta" directory
            meta_id_path = dataset_path.parent / "meta" / meta_id_path.name
        assert meta_id_path.exists(), f"meta id path {meta_id_path} does not exist"
        self._split = split  # keep the enum so __getitem__ can pick the right transform
        self._id_label_dict = self.get_id_label_map(meta_id_path)
        self._return_raw_images = raw_images
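
        # Two transform pipelines: random resized crop + horizontal flip for training,
        # deterministic resize + centre crop for evaluation. Normalize is left disabled
        # so that inverse_transform above stays a plain tensor-to-PIL mapping.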
        self.train_trans = transforms.Compose(
            [
                transforms.RandomResizedCrop(self._resize_shape[:2]),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                # transforms.Normalize(self.mean, self.std)
            ]
        )
        self.val_trans = transforms.Compose(
            [
                transforms.Resize(self._resize_shape[:2]),
                transforms.CenterCrop(self._resize_shape[:2]),
                transforms.ToTensor(),
                # transforms.Normalize(self.mean, self.std)
            ]
        )
        self._img_info = []
        with open(str(self._image_list_file_path), "r") as f:
            for i, img_file in enumerate(f):
                img_file = img_file.strip()  # e.g. n004332/0317_01.jpg
                class_id = img_file.split("/")[0]  # e.g. n004332
                label = self._id_label_dict[class_id]
                self._img_info.append(
                    {"class_id": class_id, "img": img_file, "label": label}
                )
                if verbose and i % 1000 == 0:
                    print(f"Processing: {i} images for {split_name} split")

    def __len__(self) -> int:
        return len(self._img_info)

    def __getitem__(self, index):
        info = self._img_info[index]
        img_file = info["img"]
        img = Image.open(str(self._dataset_path / img_file))
        if not self._return_raw_images:
            if self._split == SplitEnum.training:
                img = self.train_trans(img)  # random augmentation for training
            else:
                img = self.val_trans(img)  # deterministic resize + centre crop
        label = info["label"]
        class_id = info["class_id"]
        return img, label, img_file, class_id
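
# Minimal usage sketch (the dataset path is an assumption; adjust to your layout):
#   ds = VggFace2(Path("/data/VGG-Face2/data"), split=SplitEnum.training, resize_s=224)
#   img, label, img_file, class_id = ds[0]  # transformed image tensor plus metadata
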
if __name__ == "__main__":

    def main():
        """Visualise a few samples from the VGGFace2 test split."""
        batch_size = 32
        dt = VggFace2(
            Path.home() / "Data" / "VGG-Face2" / "data",
            split=SplitEnum.testing,
            # raw_images=True
        )
        test_loader = torch.utils.data.DataLoader(
            dt, batch_size=batch_size, shuffle=False
        )
        from draugr.visualisation import progress_bar

        # test_loader = dt
        for batch_idx, (imgs, label, img_files, class_ids) in progress_bar(
            enumerate(test_loader),
            total=len(test_loader),
            description=f"{test_loader.dataset}",
            ncols=80,
        ):
            pyplot.imshow(dt.inverse_transform(imgs[0]))
            pyplot.title(f"{label[0], class_ids[0]}")
            # pyplot.imshow(imgs)
            pyplot.show()
            break

    main()