Source code for neodroidvision.data.detection.coco.coco

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 22/03/2020
           """

from pathlib import Path
from typing import Tuple

import numpy
from PIL import Image
from draugr.numpy_utilities import SplitEnum
from draugr.opencv_utilities import xywh_to_minmax

from neodroidvision.data.detection.object_detection_dataset import (
    ObjectDetectionDataset,
)

__all__ = ["COCODataset"]

from draugr.torch_utilities import NamedTensorTuple


[docs]class COCODataset(ObjectDetectionDataset): """ COCO 2014 """ @property def response_shape(self) -> Tuple[int, ...]: """ Args: self: Returns: """ return (len(COCODataset.categories),) categories = ( "__background__", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", ) image_dirs = { "coco_2014_valminusminival": "images/val2014", "coco_2014_minival": "images/val2014", "coco_2014_train": "images/train2014", "coco_2014_val": "images/val2014", } splits = { "coco_2014_valminusminival": "annotations/instances_valminusminival2014.json", "coco_2014_minival": "annotations/instances_minival2014.json", "coco_2014_train": "annotations/instances_train2014.json", "coco_2014_val": "annotations/instances_val2014.json", } @property def predictor_shape(self) -> Tuple[int, ...]: """description""" return -1, -1, 3
[docs] def __init__( self, data_root: Path, dataset_name: str, split: SplitEnum, img_transform: callable = None, annotation_transform: callable = None, ): """ :param data_root: :type data_root: :param dataset_name: :type dataset_name: :param split: :type split: :param img_transform: :type img_transform: :param annotation_transform: :type annotation_transform: :param remove_empty: :type remove_empty:""" super().__init__( data_root, dataset_name, split, img_transform, annotation_transform ) from pycocotools.coco import COCO self._coco_source = COCO(str(data_root / self.splits[dataset_name])) self._image_dir = data_root / self.image_dirs[dataset_name] self._img_transforms = img_transform self._annotation_transforms = annotation_transform self._remove_empty = split == SplitEnum.training if self._remove_empty: self._ids = list( self._coco_source.imgToAnns.keys() ) # when training, images without annotations are removed. else: self._ids = list( self._coco_source.imgs.keys() ) # when testing, all images used. self._coco_id_to_contiguous_id = { coco_id: i + 1 for i, coco_id in enumerate(sorted(self._coco_source.getCatIds())) } self._contiguous_id_to_coco_id = { v: k for k, v in self._coco_id_to_contiguous_id.items() }
def __getitem__(self, index): image_id = self._ids[index] boxes, labels = self._get_annotation(image_id) image = self._read_image(image_id) if self._img_transforms: image, boxes, labels = self._img_transforms(image, boxes, labels) if self._annotation_transforms: boxes, labels = self._annotation_transforms(boxes, labels) return image, NamedTensorTuple(boxes=boxes, labels=labels), index
[docs] def get_annotation(self, index): """ :param index: :type index: :return: :rtype:""" image_id = self._ids[index] return image_id, self._get_annotation(image_id)
def __len__(self): return len(self._ids) def _get_annotation(self, image_id): ann = [ obj for obj in self._coco_source.loadAnns( self._coco_source.getAnnIds(imgIds=image_id) ) if obj["iscrowd"] == 0 ] # filter crowd annotations boxes = numpy.array( [xywh_to_minmax(obj["bbox"]) for obj in ann], numpy.float32 ).reshape((-1, 4)) labels = numpy.array( [self._coco_id_to_contiguous_id[obj["category_id"]] for obj in ann], numpy.int64, ).reshape((-1,)) keep = (boxes[:, 3] > boxes[:, 1]) & ( boxes[:, 2] > boxes[:, 0] ) # remove invalid boxes return boxes[keep], labels[keep]
[docs] def get_img_info(self, index): """ :param index: :type index: :return: :rtype:""" return self._coco_source.imgs[self._ids[index]]
def _read_image(self, image_id): return numpy.array( Image.open( self._image_dir / self._coco_source.loadImgs(image_id)[0]["file_name"] ).convert("RGB") )