Source code for neodroidvision.data.neodroid_environments.classification.data

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Iterator, Tuple, Union

import neodroid
from PIL import Image
from draugr.multiprocessing_utilities import PooledQueueProcessor, PooledQueueTask
from draugr.torch_utilities import global_torch_device
from torch.utils.data import Dataset
from torchvision import transforms

__author__ = "Christian Heider Nielsen"

import torch

# Augmentation and tensor-conversion pipeline applied to every RGB frame sampled
# from the environment: random horizontal flip, random rotation (degrees=180),
# a center crop of size 224 (no resize is performed before the crop), and finally
# conversion of the PIL image to a tensor.
default_torch_transform = transforms.Compose(
    transforms=[
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=180),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
    ]
)

# Converts an image tensor back into a PIL image in "RGB" mode.
default_torch_retransform = transforms.Compose(
    transforms=[transforms.ToPILImage(mode="RGB")]
)

# Public API of this module.
__all__ = [
    "neodroid_env_classification_generator",
    "pooled_neodroid_env_classification_generator",
]


def neodroid_env_classification_generator(
    env, batch_size: int = 64
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Endlessly yield labelled image batches sampled from a Neodroid environment.

    Every sample comes from one ``env.update()`` call: the value of the ``"RGB"``
    sensor is opened with :func:`PIL.Image.open`, converted to RGB and passed through
    ``default_torch_transform``; the value of the ``"Class"`` sensor is cast to ``int``.

    :param env: Neodroid environment whose ``update()`` returns a state exposing
        ``sensor(name).value``.
    :param batch_size: number of samples per batch; must be at least 1.
    :return: an infinite generator of ``(images, labels)`` pairs, where ``images`` is
        the ``torch.stack`` of the transformed samples and ``labels`` is an int64
        tensor of class indices; both are placed on ``global_torch_device()``.
    :raises ValueError: on the first ``next()`` if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # Without this check torch.stack would fail on an empty list with an
        # unhelpful error message.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    while True:
        predictors = []
        class_responses = []
        while len(predictors) < batch_size:
            state = env.update()
            # The context manager releases whatever Image.open holds; convert()
            # returns a separate image object that remains usable afterwards.
            with Image.open(state.sensor("RGB").value) as image:
                rgb_image = image.convert("RGB")
            predictors.append(default_torch_transform(rgb_image))
            class_responses.append(int(state.sensor("Class").value))

        # Looked up per batch so each batch follows the current global device.
        device = global_torch_device()
        yield (
            torch.stack(predictors).to(device),
            torch.tensor(class_responses, dtype=torch.long, device=device),
        )
class _FetchConvert(PooledQueueTask):
    """
    Pool task that samples one labelled image batch from a Neodroid environment.

    Defined at module level rather than inside
    ``pooled_neodroid_env_classification_generator`` because :mod:`pickle` can only
    serialise instances of classes reachable through their qualified name. A
    function-local class cannot be pickled if the processor sends the task to worker
    processes.
    """

    def __init__(
        self,
        env,
        device: Union[str, torch.device] = "cpu",
        batch_size: int = 64,
        *args,
        **kwargs,
    ):
        """
        :param env: Neodroid environment sampled via ``env.update()``.
        :param device: device the returned tensors are moved to.
        :param batch_size: number of samples per batch.
        :param args: forwarded to :class:`PooledQueueTask`.
        :param kwargs: forwarded to :class:`PooledQueueTask`.
        """
        super().__init__(*args, **kwargs)
        self.env = env
        self.batch_size = batch_size
        self.device = device

    def call(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample ``self.batch_size`` (image, class) pairs from ``self.env``.

        Each sample opens the ``"RGB"`` sensor value with :func:`PIL.Image.open`,
        converts it to RGB and applies ``default_torch_transform``; the ``"Class"``
        sensor value is cast to ``int``. Any arguments are accepted and ignored.

        :return: ``(images, labels)``, where ``images`` is the ``torch.stack`` of the
            transformed samples and ``labels`` an int64 tensor, both on ``self.device``.
        """
        predictors = []
        class_responses = []
        while len(predictors) < self.batch_size:
            state = self.env.update()
            # The context manager releases whatever Image.open holds; convert()
            # returns a separate image object that remains usable afterwards.
            with Image.open(state.sensor("RGB").value) as image:
                rgb_image = image.convert("RGB")
            predictors.append(default_torch_transform(rgb_image))
            class_responses.append(int(state.sensor("Class").value))
        return (
            torch.stack(predictors).to(self.device),
            torch.tensor(class_responses, dtype=torch.long, device=self.device),
        )


def pooled_neodroid_env_classification_generator(
    env, device: Union[str, torch.device], batch_size: int = 64
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Yield labelled image batches produced by a ``PooledQueueProcessor``.

    Each batch is built by :class:`_FetchConvert`. The processor is created with
    ``fill_at_construction=True``, ``max_queue_size=16`` and ``n_proc=None``.

    NOTE(review): if the processor runs tasks in worker processes, ``env`` must itself
    be picklable (a live environment connection may not be) — confirm.

    :param env: Neodroid environment handed to every pool task.
    :param device: device the batch tensors are moved to.
    :param batch_size: number of samples per batch; must be at least 1.
    :return: generator of ``(images, labels)`` pairs taken from the processor.
    :raises ValueError: on the first ``next()`` if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        # Fail before starting the pool: every task would otherwise error in
        # torch.stack on an empty batch.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    task = _FetchConvert(env, device=device, batch_size=batch_size)
    processor = PooledQueueProcessor(
        task, fill_at_construction=True, max_queue_size=16, n_proc=None
    )
    # Iterate the processor directly: wrapping it in zip() would yield 1-tuples
    # instead of (images, labels) pairs.
    yield from processor
if __name__ == "__main__":

    def main():
        """
        Print the class labels of batches sampled from a live Neodroid environment.

        The generator already yields ready-made ``(images, labels)`` batches and has
        no length or indexing, so it is iterated directly instead of being wrapped in
        a ``DataLoader``. A ``DataLoader`` would treat it as a map-style dataset and
        fail on ``len()`` / ``dataset[idx]``.
        """
        neodroid_generator = neodroid_env_classification_generator(
            neodroid.connect(), batch_size=12
        )
        for _images, labels in neodroid_generator:
            print(labels)

    main()