Source code for neodroidvision.utilities.torch_utilities.batching.collation
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "heider"
__doc__ = r"""
Created on 5/5/22
"""
from typing import Iterable, Tuple
from draugr.torch_utilities import NamedTensorTuple
from torch.utils.data.dataloader import default_collate
__all__ = ["BatchCollator"]
class BatchCollator:
    """Collate a batch of samples into ``(images, targets, img_ids)``.

    Each sample in the batch is transposed into columns:

    * column 0 is collated as the images;
    * column 2 is collated as the image ids;
    * column 1 holds the per-sample targets.

    Any elements beyond index 2 in a sample are ignored. Presumably this
    object is passed as a ``DataLoader`` ``collate_fn`` (TODO confirm
    against callers).
    """

    def __init__(self, wrap: bool = True):
        """
        :param wrap: when True, the per-sample targets are collated key by key
            and wrapped in a ``NamedTensorTuple``. When False, targets are not
            collated and ``None`` is returned in their place.
        """
        self.wrap = wrap

    def __call__(self, batch: Iterable) -> Tuple:
        """Collate ``batch``.

        :param batch: iterable of samples. Each sample must have at least three
            elements; otherwise ``IndexError`` is raised (also for an empty
            batch).
        :return: ``(images, targets, img_ids)``. ``targets`` is ``None`` when
            ``self.wrap`` is False.
        """
        columns = [*zip(*batch)]
        images = default_collate(columns[0])
        img_ids = default_collate(columns[2])

        if not self.wrap:
            # Targets are intentionally discarded in unwrapped mode.
            return images, None, img_ids

        per_sample_targets = columns[1]
        # The keys of the first target define the output fields. Every other
        # target is indexed by those keys, so a missing key raises KeyError.
        collated = {}
        for field in per_sample_targets[0]:
            collated[field] = default_collate(
                [sample_target[field] for sample_target in per_sample_targets]
            )
        return images, NamedTensorTuple(**collated), img_ids