#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 23/03/2020
"""
import typing
import torch
from draugr.writers import Writer
from torch import distributed
from neodroidvision.utilities.torch_utilities.distributing.distributing_utilities import (
global_world_size,
)

__all__ = ["write_metrics_recursive", "reduce_loss_dict"]


def write_metrics_recursive(
    eval_result: typing.Mapping, prefix: str, summary_writer: Writer, global_step: int
) -> None:
    """
    Recursively write a (possibly nested) mapping of metrics to a summary writer.
    Nested mappings are flattened into slash-separated tags rooted at prefix.

    :param eval_result: mapping from metric names to scalar values or nested mappings
    :param prefix: tag prefix under which the metrics are written
    :param summary_writer: writer receiving the scalar values
    :param global_step: global step to associate with the written values
    """
    for key in eval_result:
        value = eval_result[key]
        tag = f"{prefix}/{key}"
        if isinstance(value, typing.Mapping):
            write_metrics_recursive(value, tag, summary_writer, global_step)
        else:
            summary_writer.scalar(tag, value, step_i=global_step)
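
# Usage sketch (illustrative only): nested mappings are flattened into
# slash-separated tags, e.g. a result of {"bbox": {"AP": 0.5}} written under the
# prefix "coco_eval" becomes the scalar "coco_eval/bbox/AP". The name `writer`
# below is a hypothetical placeholder for any draugr Writer instance:
#
#     write_metrics_recursive(
#         {"bbox": {"AP": 0.5, "AP50": 0.7}, "loss": 0.1},
#         prefix="coco_eval",
#         summary_writer=writer,
#         global_step=1000,
#     )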


def reduce_loss_dict(loss_dict: dict) -> dict:
    """
    Reduce the loss dictionary from all processes so that the process with rank
    0 has the averaged results.

    :param loss_dict: mapping from loss names to scalar loss tensors
    :return: a dict with the same keys as loss_dict, after reduction; on rank 0
        the values are averaged over all processes
    """
    world_size = global_world_size()
    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        loss_names = []
        all_losses = []
        for k in sorted(loss_dict.keys()):  # sort keys for a deterministic order across processes
            loss_names.append(k)
            all_losses.append(loss_dict[k])
        all_losses = torch.stack(all_losses, dim=0)
        distributed.reduce(all_losses, dst=0)
        if distributed.get_rank() == 0:
            # only the main process receives the accumulated sum, so only divide there
            all_losses /= world_size
        reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
    return reduced_losses
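
# Usage sketch (illustrative only), assuming a typical distributed training step
# where the model returns a dict of scalar loss tensors; `model`, `images` and
# `targets` are hypothetical placeholders, not part of this module:
#
#     loss_dict = model(images, targets)  # e.g. {"cls_loss": ..., "reg_loss": ...}
#     sum(loss for loss in loss_dict.values()).backward()
#
#     loss_dict_reduced = reduce_loss_dict(loss_dict)
#     # log only on the main process, since rank 0 holds the averaged values
#     if global_world_size() < 2 or distributed.get_rank() == 0:
#         total = sum(loss.item() for loss in loss_dict_reduced.values())
#         print(f"total loss (averaged over processes): {total:.4f}")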