Source code for neodroidvision.segmentation.evaluation.focal_loss
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
import torch
from torch import nn
from torch.autograd import Variable
__all__ = ["FocalLoss"]
[docs]class FocalLoss(nn.Module):
r"""
This criterion is a implementation of Focal Loss, which is proposed in
Focal Loss for Dense Object Detection.
Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
The loss_functions are averaged across observations for each mini batch.
Args:
alpha(1D Tensor, Variable) : the scalar factor for this criterion
gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
putting more focus on hard, misclassified examples
size_average(bool): size_average(bool): By default, the loss_functions are averaged over
observations for
each mini batch.
However, if the field size_average is set to False, the loss_functions are
instead summed for each mini batch."""
[docs] def __init__(
self, class_num, alpha=None, gamma: float = 2.0, size_average: bool = True
):
super(FocalLoss, self).__init__()
if alpha is None:
self.alpha = Variable(torch.ones(class_num, 1))
else:
if isinstance(alpha, Variable):
self.alpha = alpha
else:
self.alpha = Variable(alpha)
self.gamma = gamma
self.class_num = class_num
self.size_average = size_average
[docs] def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Args:
inputs:
targets:
Returns:
"""
N = inputs.size(0)
C = inputs.size(1)
P = torch.softmax(
inputs, 0
) # TODO: use log_softmax? Check dim maybe it should be 1
class_mask = inputs.data.new(N, C).fill_(0)
class_mask = Variable(class_mask)
ids = targets.reshape(-1, 1)
class_mask.scatter_(1, ids.data, 1.0)
if inputs.is_cuda and not self.alpha.is_cuda:
self.alpha = self.alpha.cuda()
alpha = self.alpha[ids.data.view(-1)]
probs = (P * class_mask).sum(1).reshape(-1, 1)
log_p = probs.log()
batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p
if self.size_average:
loss = batch_loss.mean()
else:
loss = batch_loss.sum()
return loss
if __name__ == "__main__":
alpha = torch.rand(21, 1)
focal_loss_func = FocalLoss(class_num=5, gamma=0)
cross_entropy_func = nn.CrossEntropyLoss()
N = 4
C = 5
inputs = torch.rand(N, C)
targets = torch.LongTensor(N).random_(C)
inputs_fl = Variable(inputs.clone(), requires_grad=True)
targets_fl = Variable(targets.clone())
inputs_ce = Variable(inputs.clone(), requires_grad=True)
targets_ce = Variable(targets.clone())
print("----inputs----")
print(inputs)
print("---target-----")
print(targets)
fl_loss = focal_loss_func(inputs_fl, targets_fl)
ce_loss = cross_entropy_func(inputs_ce, targets_ce)
print(f"ce = {ce_loss.item()}, fl ={fl_loss.item()}")
fl_loss.backward()
ce_loss.backward()
print(inputs_ce.grad.data)