Source code for neodroidvision.utilities.torch_utilities.distributing.distributing_utilities

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 01/03/2020
           """

import logging
import os
import sys
from pathlib import Path
from typing import Any, List

import torch
import torch.utils.data
from torch import distributed

from neodroidvision.utilities.torch_utilities.distributing.serialisation import (
    deserialise_byte_tensor,
    to_byte_tensor,
)

__all__ = [
    "all_gather_cuda",
    "reduce_dict",
    "setup_for_distributed",
    "is_distribution_ready",
    "is_main_process",
    "init_distributed_mode",
    "save_on_master",
    "global_distribution_rank",
    "global_world_size",
    "set_benchmark_device_dist",
    "synchronise_torch_barrier",
    "setup_distributed_logger",
]


def is_distribution_ready() -> bool:
    """Tell whether ``torch.distributed`` is usable in this process.

    :return: ``True`` only when the distributed package is available *and* a
        process group has already been initialised; ``False`` otherwise."""
    # Short-circuits: ``is_initialized`` is only queried when the package is available.
    return distributed.is_available() and distributed.is_initialized()
def global_world_size() -> int:
    """Number of processes taking part in the current distributed run.

    :return: ``distributed.get_world_size()`` when a process group is ready,
        otherwise ``1`` (plain single-process execution)."""
    return distributed.get_world_size() if is_distribution_ready() else 1
def global_distribution_rank() -> int:
    """Rank of the calling process within the distributed run.

    :return: ``distributed.get_rank()`` when a process group is ready,
        otherwise ``0`` (plain single-process execution)."""
    return distributed.get_rank() if is_distribution_ready() else 0
def is_main_process() -> bool:
    """Check whether the calling process is the main (rank 0) process.

    :return: ``True`` when :func:`global_distribution_rank` reports rank ``0``."""
    current_rank = global_distribution_rank()
    return current_rank == 0
def save_on_master(*args, **kwargs) -> None:
    """Call ``torch.save`` from the main process only.

    :param args: positional arguments forwarded unchanged to ``torch.save``
    :param kwargs: keyword arguments forwarded unchanged to ``torch.save``
    :return: ``None``; on non-main processes nothing is written."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def setup_for_distributed(is_master: bool) -> None:
    """Replace the builtin ``print`` with a version that is silent off-master.

    After this call, ``print(...)`` only produces output when ``is_master`` is
    true or when the caller passes ``force=True``.  The replacement is installed
    on the ``builtins`` module, so it affects the whole interpreter.

    Args:
        is_master: whether the calling process is the master process."""
    import builtins

    # Capture whatever ``print`` currently is; the wrapper delegates to it.
    original_print = builtins.print

    def print(*args, force=False, **kwargs):
        """Forward to the captured ``print`` when master or when forced.

        Args:
            *args: positional arguments for the wrapped ``print``
            force: emit the message even on a non-master process
            **kwargs: remaining keyword arguments for the wrapped ``print``"""
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    builtins.print = print
def all_gather_cuda(data: Any) -> List[Any]:
    """Run ``all_gather`` on arbitrary picklable data (not necessarily tensors).

    The object is converted to a CUDA byte tensor, the per-rank byte counts are
    exchanged, every rank pads its buffer up to the largest count (because
    ``torch.distributed.all_gather`` needs identically shaped tensors), the
    padded buffers are gathered, and the result is handed to
    ``deserialise_byte_tensor`` together with the true sizes.

    Args:
        data: any picklable object

    Returns:
        list[data]: list of data gathered from each rank.  When the world size
        is 1 this is simply ``[data]`` and no CUDA work is done."""
    world_size = global_world_size()
    if world_size == 1:
        return [data]

    # NOTE(review): ``to_byte_tensor`` is assumed to return a 1-D uint8 CUDA
    # tensor (it is concatenated with uint8 padding along dim 0) - confirm.
    tensor = to_byte_tensor(data, device="cuda")

    # Keep the local size as a plain int too, so the padding shape and the
    # "needs padding" test are not derived from a CUDA tensor.
    local_numel = tensor.numel()

    # obtain Tensor size of each rank
    local_size = torch.tensor([local_numel], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    distributed.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda") for _ in size_list
    ]
    if local_numel != max_size:
        # The padding content is irrelevant; ``size_list`` records how many
        # leading bytes of each gathered buffer are real payload.
        padding = torch.empty(
            size=(max_size - local_numel,), dtype=torch.uint8, device="cuda"
        )
        tensor = torch.cat((tensor, padding), dim=0)
    distributed.all_gather(tensor_list, tensor)

    return deserialise_byte_tensor(size_list, tensor_list)
def reduce_dict(input_dict: dict, average: bool = True) -> dict:
    """Reduce the values of a dictionary across all processes.

    Every process ends up with the same result: the element-wise sum (or, when
    ``average`` is true, the sum divided by the world size) of each value over
    all processes.

    Args:
        input_dict (dict): all the values will be reduced.  The values are
            stacked with ``torch.stack``, so they must be tensors of one shape.
        average (bool): whether to do average or sum

    Returns:
        A dict with the same fields as ``input_dict``, after reduction.  The
        input object itself is returned unchanged when it is empty or when
        fewer than two processes take part."""
    # ``torch.stack`` rejects an empty sequence, so an empty dict is returned
    # as-is instead of crashing in the multi-process case.
    if not input_dict:
        return input_dict

    world_size = global_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        distributed.all_reduce(values)
        if average:
            # NOTE(review): in-place true division presumably needs floating
            # point values - confirm callers never pass integer tensors.
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
def init_distributed_mode(args: Any) -> None:
    """Configure ``args`` from the environment and start the process group.

    Two launch styles are recognised:

    * ``RANK`` and ``WORLD_SIZE`` both present in the environment:
      ``args.rank``, ``args.world_size`` and ``args.gpu`` are read from
      ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` present: ``args.rank`` is read from it and ``args.gpu``
      is the rank modulo the number of visible CUDA devices.  ``world_size``
      is *not* set here, so ``args`` must already carry it.

    If neither applies, ``args.distributed`` is set to ``False`` and the
    function returns.  Otherwise ``args.distributed`` is set to ``True``, the
    CUDA device is selected, an ``nccl`` process group is initialised from
    ``args.dist_url``, all processes are synchronised with a barrier, and
    printing is silenced on every rank except 0.

    Args:
        args: mutable namespace-like object; must provide ``dist_url`` (and
            ``world_size`` for the SLURM case) when distributed mode is used.

    Returns:
        None"""
    env = os.environ
    has_launcher_vars = "RANK" in env and "WORLD_SIZE" in env

    if not has_launcher_vars and "SLURM_PROCID" not in env:
        print("Not using distributed mode")
        args.distributed = False
        return

    if has_launcher_vars:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    else:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)

    process_group_options = {
        "backend": args.dist_backend,
        "init_method": args.dist_url,
        "world_size": args.world_size,
        "rank": args.rank,
    }
    torch.distributed.init_process_group(**process_group_options)
    torch.distributed.barrier()

    # From here on only rank 0 prints, unless ``force=True`` is passed.
    setup_for_distributed(args.rank == 0)
def synchronise_torch_barrier() -> None:
    """Block until every process of the distributed run reaches this point.

    Does nothing when no process group is ready or when the world size is 1,
    so it is safe to call during ordinary single-process training."""
    if is_distribution_ready() and distributed.get_world_size() != 1:
        distributed.barrier()
def setup_distributed_logger(
    name: str, distributed_rank: int, save_dir: Path = None
) -> logging.Logger:
    """Create (or fetch) a logger that only emits on the master process.

    The logger level is always set to ``DEBUG``.  For ranks above 0 the logger
    is returned without any handler attached.  For rank 0 a ``sys.stdout``
    stream handler is attached and, when ``save_dir`` is given, a file handler
    writing to ``<save_dir>/log.txt``.

    Calling this function again for the same ``name`` does not attach a second
    copy of a handler that is already present, so records are not duplicated.

    :param name: name passed to ``logging.getLogger``
    :param distributed_rank: rank of the calling process; handlers are only
        attached when it is not greater than 0
    :param save_dir: optional directory (``Path`` or ``str``) for ``log.txt``;
        the directory must already exist
    :return: the configured logger"""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    # ``FileHandler`` subclasses ``StreamHandler``, so exclude it explicitly
    # when looking for an existing stdout handler.
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in logger.handlers
    )
    if not has_stdout_handler:
        stream_handler = logging.StreamHandler(stream=sys.stdout)
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if save_dir:
        # ``FileHandler`` stores the absolute path in ``baseFilename``;
        # normalise the same way so an existing handler can be recognised.
        log_path = os.path.abspath(str(Path(save_dir) / "log.txt"))
        has_file_handler = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == log_path
            for handler in logger.handlers
        )
        if not has_file_handler:
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
def set_benchmark_device_dist(distributed: bool, local_rank: int) -> None:
    """Enable the cuDNN auto-tuner and optionally start distributed mode.

    :param distributed: when true, select the CUDA device ``local_rank``,
        initialise an ``nccl`` process group from the environment (``env://``)
        and wait on a barrier
    :param local_rank: index of the CUDA device for this process
    :return: ``None``"""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # This flag allows you to enable the inbuilt cudnn auto-tuner to
        # find the best algorithm to use for your hardware.
        torch.backends.cudnn.benchmark = True

    if not distributed:
        return

    # The ``distributed`` parameter shadows the module-level import of the
    # same name, hence the fully qualified ``torch.distributed`` below.
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    synchronise_torch_barrier()