# Source code for neodroidvision.data.classification.nlet.triplet_dataset
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 30/06/2020
"""
import random
from pathlib import Path
from typing import Tuple
import numpy
import torch
__all__ = ["TripletDataset"]
from draugr.torch_utilities import global_pin_memory
from torch.utils.data import DataLoader
from neodroidvision.data.classification.nlet import PairDataset
class TripletDataset(PairDataset):
    """Dataset that yields image triplets: anchor, positive and negative.

    Each item consists of an image from a randomly drawn category, a second,
    different image from that same category, and a third image from a
    different, randomly drawn category.
    """

    # TODO: Extract image specificity of class to a subclass and move this super pair class to a
    # general torch lib.

    def __getitem__(self, idx1: int) -> Tuple:
        """Return one (anchor, positive, negative) triplet.

        The anchor category is drawn at random; ``idx1`` is used as the index
        of the anchor *within* that category.

        :param idx1: index of the anchor sample inside the randomly chosen
            category (assumed to be valid for every category -- TODO confirm
            against ``PairDataset.__len__``).
        :type idx1: int
        :return: ``(anchor, positive, negative)``; when
            ``self.return_categories`` is truthy the three matching category
            names ``(anchor_cat, anchor_cat, negative_cat)`` are appended.
        :rtype: Tuple
        :raises ValueError: if the anchor category has fewer than two samples
            or the dataset has fewer than two categories, since no valid
            triplet exists in either case.
        """
        dataset = self._dataset
        sizes = dataset.category_sizes

        t1 = random.choice(dataset.category_names)
        positive_pool_size = sizes[t1]
        if positive_pool_size < 2:
            # Without this guard the rejection loop below could never end.
            raise ValueError(
                f"Category {t1!r} needs at least two samples to form an "
                f"anchor/positive pair, got {positive_pool_size}"
            )

        # randrange excludes its upper bound; the previous randint(0, size)
        # could return ``size`` itself, one past the last index (assuming
        # category_sizes holds element counts).
        while True:
            idx2 = random.randrange(positive_pool_size)
            if idx1 != idx2:
                break

        # Choose the negative category directly among the remaining ones so
        # that a single-category dataset fails loudly instead of spinning.
        other_categories = [c for c in dataset.category_names if c != t1]
        if not other_categories:
            raise ValueError(
                "At least two categories are required to draw a negative sample"
            )
        t2 = random.choice(other_categories)
        idx3 = random.randrange(sizes[t2])

        anchor = dataset.sample(t1, idx1)[0]
        positive = dataset.sample(t1, idx2)[0]
        negative = dataset.sample(t2, idx3)[0]

        if self.return_categories:
            # Labels follow the image order: anchor and positive share t1,
            # the negative belongs to t2 (was wrongly emitted as t1, t2, t1).
            return anchor, positive, negative, t1, t1, t2
        return anchor, positive, negative

    def sample(self, horizontal_merge: bool = False) -> None:
        """Plot up to three shuffled batches of nine triplets each.

        The three images of every triplet are concatenated into one image
        before being handed to ``PairDataset.plot_images``.

        :param horizontal_merge: if True concatenate along axis 2
            (``numpy.dstack``), otherwise along axis 1 (``numpy.hstack``) of
            the transposed batches.
        :type horizontal_merge: bool
        :return: None
        """
        loader = DataLoader(
            self,
            batch_size=9,
            shuffle=True,
            num_workers=0,
            pin_memory=global_pin_memory(0),
        )
        # zip with range(3) caps the preview at three batches and, unlike
        # repeated next() calls, ends quietly if the loader has fewer.
        for _, (images1, images2, images3, *labels) in zip(range(3), loader):
            # Move axis 1 to the end; presumably channels-first ->
            # channels-last for plotting -- TODO confirm image layout.
            X1 = numpy.transpose(images1.numpy(), [0, 2, 3, 1])
            X2 = numpy.transpose(images2.numpy(), [0, 2, 3, 1])
            X3 = numpy.transpose(images3.numpy(), [0, 2, 3, 1])
            if horizontal_merge:
                X = numpy.dstack((X1, X2, X3))
            else:
                X = numpy.hstack((X1, X2, X3))
            # labels holds one sequence per triplet slot; zip(*labels)
            # regroups them into one label tuple per sample in the batch.
            PairDataset.plot_images(X, list(zip(*labels)))
if __name__ == "__main__":
    # Manual smoke run: load the MNIST PNG folder from the user's home
    # directory, report the dataset shapes and plot a few triplet batches.
    mnist_root = Path.home() / "Data" / "mnist_png"
    triplet_dataset = TripletDataset(mnist_root, return_categories=True)
    print(triplet_dataset.predictor_shape)
    print(triplet_dataset.response_shape)
    triplet_dataset.sample()