neodroidvision.data.classification.mnist.MNISTDataset
- class neodroidvision.data.classification.mnist.MNISTDataset(data_dir: Path, split: SplitEnum = SplitEnum.training)
Bases: SupervisedDataset
MNIST handwritten-digit classification dataset.
Methods
__init__(data_dir[, split])
get_test_loader(data_dir, batch_size, *[, ...]): Test data loader.
get_train_valid_loader(data_dir, *, ...[, ...]): Train and validation data loaders.
plot_images(images, label)
sample()
Attributes
inverse_transform
trans
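The source gives no description of `trans` or `inverse_transform`. If, as the names suggest, `trans` normalizes input pixels and `inverse_transform` undoes that normalization (for example, to display images), the pair behaves like the minimal sketch below. This is an assumption, not the library's code: the helpers `normalize` and `denormalize` are hypothetical, and the mean 0.1307 and standard deviation 0.3081 are the commonly quoted MNIST pixel statistics, not values confirmed for this class.

```python
from typing import List

# ASSUMPTION: commonly quoted MNIST pixel statistics (pixels scaled to [0, 1]).
# These are not confirmed to be the values used by MNISTDataset.trans.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixels: List[float], mean: float = MNIST_MEAN, std: float = MNIST_STD) -> List[float]:
    """Hypothetical stand-in for a forward transform: (x - mean) / std."""
    return [(p - mean) / std for p in pixels]


def denormalize(values: List[float], mean: float = MNIST_MEAN, std: float = MNIST_STD) -> List[float]:
    """Hypothetical stand-in for an inverse transform: x * std + mean."""
    return [v * std + mean for v in values]


if __name__ == "__main__":
    raw = [0.0, 0.5, 1.0]
    restored = denormalize(normalize(raw))
    # Round-tripping recovers the original pixels up to floating-point error.
    print(all(abs(a - b) < 1e-9 for a, b in zip(raw, restored)))
```

The design point is only that an inverse transform must apply the inverse operations in reverse order, so that `denormalize(normalize(x))` returns `x`.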
- static get_test_loader(data_dir: Path, batch_size: int, *, num_workers: int = 0, pin_memory: bool = False, using_cuda: bool = True) -> DataLoader
Test data loader.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
- Parameters
data_dir – path to the dataset directory.
batch_size – how many samples per batch to load.
num_workers – number of subprocesses to use when loading the dataset.
pin_memory – whether to copy tensors into CUDA pinned memory. Set it to True if using a GPU.
using_cuda –
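The CUDA advice above amounts to overriding two loader settings. The minimal sketch below assumes that is all `using_cuda` controls, which the source does not document; `resolve_loader_kwargs` is a hypothetical helper, not part of neodroidvision. Its result could be passed as keyword arguments to `torch.utils.data.DataLoader`, which does accept `num_workers` and `pin_memory`.

```python
from typing import Any, Dict


def resolve_loader_kwargs(
    num_workers: int = 0, pin_memory: bool = False, using_cuda: bool = True
) -> Dict[str, Any]:
    """Hypothetical helper: apply the docstring's CUDA recommendation.

    ASSUMPTION: when using_cuda is True, the recommended values
    (num_workers=1, pin_memory=True) replace the caller's values;
    otherwise the caller's values are kept unchanged.
    """
    if using_cuda:
        return {"num_workers": 1, "pin_memory": True}
    return {"num_workers": num_workers, "pin_memory": pin_memory}


if __name__ == "__main__":
    print(resolve_loader_kwargs(num_workers=4, using_cuda=True))   # {'num_workers': 1, 'pin_memory': True}
    print(resolve_loader_kwargs(num_workers=4, using_cuda=False))  # {'num_workers': 4, 'pin_memory': False}
```

In a PyTorch script this might be used as `DataLoader(dataset, batch_size=64, **resolve_loader_kwargs(using_cuda=torch.cuda.is_available()))`.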
- static get_train_valid_loader(data_dir: Path, *, batch_size: int, random_seed: int, valid_size: float = 0.1, shuffle: bool = True, num_workers: int = 0, pin_memory: bool = False, using_cuda: bool = True) -> Tuple[DataLoader, DataLoader]
Train and validation data loaders.
If using CUDA, num_workers should be set to 1 and pin_memory to True.
- Parameters
data_dir – path to the dataset directory.
batch_size – how many samples per batch to load.
random_seed – fixed seed for reproducibility.
valid_size – fraction of the training set used for the validation set. Should be a float in the range [0, 1]. In the paper, this number is set to 0.1.
shuffle – whether to shuffle the train/validation indices.
num_workers – number of subprocesses to use when loading the dataset.
pin_memory – whether to copy tensors into CUDA pinned memory. Set it to True if using a GPU.
using_cuda –
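The `valid_size`, `random_seed`, and `shuffle` parameters describe a common index-splitting pattern: shuffle the training-set indices with a fixed seed, then hold out the first `floor(valid_size * N)` of them for validation. The pure-Python sketch below shows that pattern; `split_indices` is a hypothetical helper and is not claimed to be this method's implementation. In PyTorch, the two index lists would typically be wrapped in `torch.utils.data.SubsetRandomSampler` and passed as the `sampler` argument of two `DataLoader`s.

```python
import math
import random
from typing import List, Tuple


def split_indices(
    num_samples: int, valid_size: float = 0.1, random_seed: int = 0, shuffle: bool = True
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: split range(num_samples) into train/valid index lists."""
    if not 0.0 <= valid_size <= 1.0:
        raise ValueError("valid_size should be in the range [0, 1]")
    indices = list(range(num_samples))
    # Number of samples held out for validation, e.g. floor(0.1 * 60000) = 6000.
    split = math.floor(valid_size * num_samples)
    if shuffle:
        # A dedicated Random instance makes the split reproducible for a given seed
        # without touching the global random state.
        random.Random(random_seed).shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    return train_idx, valid_idx


if __name__ == "__main__":
    # 60000 is the size of the MNIST training set.
    train_idx, valid_idx = split_indices(60000, valid_size=0.1, random_seed=42)
    print(len(train_idx), len(valid_idx))  # 54000 6000
```

Splitting indices rather than copying data keeps both subsets backed by the same underlying dataset, and the two lists are disjoint by construction.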