"""Cross-entropy losses for 2D segmentation.

Module: neodroidvision.segmentation.evaluation.experimental.cross_entropy_2d
"""

import torch
from torch.nn.functional import cross_entropy, interpolate

# Public API of this module: the three cross-entropy loss helpers defined below.
__all__ = [
    "cross_entropy2d",
    "multi_scale_cross_entropy2d",
    "bootstrapped_cross_entropy2d",
]


def cross_entropy2d(input, target, weight=None, size_average=True, ignore_index=250):
    """Pixel-wise cross-entropy between 2D logits and a label map.

    If the spatial size of ``input`` differs from that of ``target`` in either
    dimension, the logits are bilinearly resized to the label size first.

    Args:
      input: logits of shape ``(N, C, H, W)``.
      target: class-index labels of shape ``(N, Ht, Wt)``.
      weight: optional per-class weight tensor forwarded to
        ``torch.nn.functional.cross_entropy``.
      size_average: legacy reduction flag. ``None`` or a truthy value gives the
        (weighted) mean over non-ignored pixels; any other falsy value gives
        the sum.
      ignore_index: label value excluded from the loss (default ``250``).

    Returns:
      A scalar loss tensor.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    # Resize the logits (not the labels) when *either* spatial dimension
    # differs; requiring both to differ would leave a shape mismatch below.
    if h != ht or w != wt:
        input = interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # ``size_average`` is deprecated in PyTorch; translate it to ``reduction``
    # with the same legacy rule (None or truthy -> mean, otherwise sum).
    reduction = "mean" if size_average is None or size_average else "sum"

    return cross_entropy(
        input.permute(0, 2, 3, 1).reshape(-1, c),  # (N*H*W, C) rows, one per pixel
        target.reshape(-1),
        weight=weight,
        reduction=reduction,
        ignore_index=ignore_index,
    )


def multi_scale_cross_entropy2d(
    input, target, weight=None, size_average=True, scale_weight=None
):
    """Weighted sum of ``cross_entropy2d`` over several prediction scales.

    Args:
      input: either a single logits tensor, which is forwarded unchanged to
        ``cross_entropy2d``, or a tuple/list of logits tensors, one per scale.
      target: label tensor shared by every scale.
      weight: optional per-class weights forwarded to ``cross_entropy2d``.
      size_average: forwarded to ``cross_entropy2d``.
      scale_weight: optional indexable of per-scale multipliers. When ``None``,
        the i-th scale gets ``0.4 ** i`` (i.e. 1.0, 0.4, 0.16, ...).

    Returns:
      ``sum_i scale_weight[i] * cross_entropy2d(input[i], target)``, or the
      single-tensor loss; ``0.0`` for an empty sequence.
    """
    if not isinstance(input, (tuple, list)):
        return cross_entropy2d(
            input=input, target=target, weight=weight, size_average=size_average
        )

    # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
    if scale_weight is None:
        # ``input`` is a sequence here and has no ``.device``; place the
        # weights on the target's device, where the losses are computed.
        scale = 0.4
        exponents = torch.arange(len(input), dtype=torch.float32, device=target.device)
        scale_weight = torch.pow(scale, exponents)

    loss = 0.0
    for i, inp in enumerate(input):
        loss = loss + scale_weight[i] * cross_entropy2d(
            input=inp, target=target, weight=weight, size_average=size_average
        )
    return loss


def bootstrapped_cross_entropy2d(input, target, K: int, weight=None, size_average=True):
    """Bootstrapped cross-entropy: mean loss of the K hardest pixels per image.

    For each image, the per-pixel cross-entropy is computed. Pixels labelled
    ``250`` are ignored and contribute a loss of 0. The ``K`` largest values
    are summed and divided by ``K``, and the per-image results are averaged
    over the batch. Bootstrapping is done per image, not over the whole batch.

    Args:
      input: logits of shape ``(N, C, H, W)``.
      target: class labels with the batch dimension first and ``H * W``
        labels per image (e.g. shape ``(N, H, W)``).
      K: number of highest-loss pixels kept per image; ``topk`` raises if it
        exceeds ``H * W``.
      weight: optional per-class weights forwarded to ``cross_entropy``.
      size_average: kept for signature compatibility; it has no effect (the
        top-K losses are always averaged).

    Returns:
      A scalar loss tensor.

    Raises:
      ValueError: if the batch is empty.
    """
    n, c, h, w = input.size()
    if n == 0:
        raise ValueError("bootstrapped_cross_entropy2d requires a non-empty batch")

    # One vectorised pass over the whole batch instead of a loop over images.
    # Rows are in (N, H, W) order, so each image's pixels form one row below.
    pixel_loss = cross_entropy(
        input.permute(0, 2, 3, 1).reshape(-1, c),
        target.reshape(-1),
        weight=weight,
        reduction="none",
        ignore_index=250,
    ).reshape(n, h * w)

    # Keep the K largest losses of each image independently.
    topk_loss, _ = pixel_loss.topk(K, dim=1)
    return (topk_loss.sum(dim=1) / K).mean()