Source code for neodroidvision.data.detection.multi_dataset

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 22/03/2020
           """

from abc import abstractmethod
from pathlib import Path
from typing import Callable, Sequence, Tuple

from draugr.numpy_utilities import SplitEnum
from torch.utils.data import ConcatDataset

__all__ = ["MultiDataset"]

from draugr.torch_utilities import SupervisedDataset
from neodroidvision.detection.single_stage.ssd.bounding_boxes.ssd_transforms import (
    SSDTransform,
    SSDAnnotationTransform,
)


class MultiDataset(SupervisedDataset):
    """Supervised dataset composed of one or more named sub-datasets."""

    @property
    @abstractmethod
    def categories(self) -> Sequence:
        """The detection categories of this dataset."""
        raise NotImplementedError

    @property
    def response_shape(self) -> Tuple[int, ...]:
        """
        :return: shape of the response, ``(num_categories,)``
        :rtype: Tuple[int, ...]"""
        return (len(self.categories),)

    def __init__(
        self,
        *,
        cfg,
        dataset_type: Callable,
        data_root: Path,
        sub_datasets: Tuple,
        split: SplitEnum = SplitEnum.training,
    ):
        """
        :param cfg: configuration node providing ``input`` and ``model.box_head`` settings
        :param dataset_type: callable constructing each named sub-dataset
        :param data_root: root directory of the data
        :type data_root: Path
        :param sub_datasets: names of the sub-datasets to load
        :type sub_datasets: Tuple
        :param split: which data split to load
        :type split: SplitEnum"""
        super().__init__()
        assert len(sub_datasets) > 0, "No data found!"

        img_transform = SSDTransform(
            image_size=cfg.input.image_size,
            pixel_mean=cfg.input.pixel_mean,
            split=split,
        )

        if split == SplitEnum.training:
            annotation_transform = SSDAnnotationTransform(
                image_size=cfg.input.image_size,
                priors_cfg=cfg.model.box_head.priors,
                center_variance=cfg.model.box_head.center_variance,
                size_variance=cfg.model.box_head.size_variance,
                iou_threshold=cfg.model.box_head.iou_threshold,
            )
        else:
            annotation_transform = None

        datasets = []
        for dataset_name in sub_datasets:
            datasets.append(
                dataset_type(
                    data_root=data_root,
                    dataset_name=dataset_name,
                    split=split,
                    img_transform=img_transform,
                    annotation_transform=annotation_transform,
                )
            )

        if split != SplitEnum.training:  # for testing, return a list of datasets
            self.sub_datasets = datasets
        else:
            dataset = datasets[0]
            if len(datasets) > 1:
                dataset = ConcatDataset(datasets)
            self.sub_datasets = [dataset]
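# Minimal usage sketch. Because ``categories`` is abstract, ``MultiDataset``
# must be used through a concrete subclass. The names below
# (``MyDetectionDataset``, ``SingleDataset``, ``cfg``) are hypothetical and
# not part of this module: ``SingleDataset`` stands in for any per-name
# dataset constructor accepting the keyword arguments passed in ``__init__``
# above, and ``cfg`` must expose the ``cfg.input.*`` and
# ``cfg.model.box_head.*`` attributes that ``__init__`` reads.
#
# class MyDetectionDataset(MultiDataset):  # hypothetical subclass
#     @property
#     def categories(self) -> Sequence:
#         return ("background", "person", "car")  # illustrative label set
#
# dataset = MyDetectionDataset(
#     cfg=cfg,  # assumed configuration object, see above
#     dataset_type=SingleDataset,  # hypothetical per-name constructor
#     data_root=Path("/data/detection"),
#     sub_datasets=("subset_a", "subset_b"),
#     split=SplitEnum.training,  # training: sub-datasets are concatenated
# )
# print(dataset.response_shape)  # -> (3,)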