# Source code for neodroidvision.utilities.torch_utilities.distributing.distributed
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Module metadata: the author handle, and a module docstring that only
# records the creation date.
__author__ = "heider"
__doc__ = r"""
Created on 5/5/22
"""
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Code is copy-pasted exactly as in torch.utils.data.distributed.
# FIXME remove this once c10d fixes the bug it has
import math
from typing import Sized
import torch
from torch import distributed
from torch.utils.data.sampler import Sampler
# Public API of this module: only DistributedSampler is exported by
# ``from ... import *``.
__all__ = ["DistributedSampler"]
# from torchvision.datasets.samplers import DistributedSampler
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    Every replica yields exactly ``ceil(len(dataset) / num_replicas)``
    indices. When the dataset size is not divisible by ``num_replicas`` the
    index list is padded by repeating indices from its start, so a few
    samples may be seen by more than one replica.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training. Defaults to the world size reported by
            ``torch.distributed``.
        rank (optional): Rank of the current process within num_replicas.
            Defaults to the rank reported by ``torch.distributed``.
        shuffle (optional): If ``True`` (default), indices are shuffled with
            a generator seeded by the current epoch; otherwise they are
            produced in ascending order.
    """

    def __init__(
        self, dataset: Sized, num_replicas: int = None, rank=None, shuffle: bool = True
    ):
        """
        :param dataset: Sized dataset to draw indices from.
        :param num_replicas: Number of participating processes; queried from
            ``torch.distributed`` when ``None``.
        :param rank: Rank of this process; queried from ``torch.distributed``
            when ``None``.
        :param shuffle: Whether to shuffle the indices each epoch.
        :raises RuntimeError: If a default has to be queried but the
            distributed package is not available.
        :raises ValueError: If ``rank`` is outside ``[0, num_replicas - 1]``.
        """
        if num_replicas is None:
            if not distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = distributed.get_world_size()
        if rank is None:
            if not distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = distributed.get_rank()
        # Fail early with a clear message; an out-of-range rank would otherwise
        # only surface later as an opaque assertion failure inside __iter__.
        if rank < 0 or rank >= num_replicas:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval"
                f" [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.shuffle = shuffle
        # Number of indices each replica yields, and the padded overall length.
        self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        """Return an iterator over the indices assigned to this replica."""
        dataset_size = len(self.dataset)
        if self.shuffle:
            # deterministically shuffle based on epoch
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(dataset_size, generator=g).tolist()
        else:
            indices = list(range(dataset_size))

        # Add extra samples to make the list evenly divisible. The padding can
        # be longer than the index list itself (dataset smaller than the
        # number of missing slots), so repeat the list as often as needed
        # before truncating instead of taking a single prefix slice.
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            repeats = math.ceil(padding_size / len(indices))
            indices += (indices * repeats)[:padding_size]
        assert len(indices) == self.total_size

        # Subsample: this replica takes its contiguous chunk.
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        """Return the number of indices this replica yields per epoch."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the shuffle in :meth:`__iter__`.

        :param epoch: Epoch number; used as the generator seed when
            ``shuffle`` is enabled.
        :type epoch: int
        """
        self.epoch = epoch