# Source code for neodroidvision.classification.loss_functions.confidence_loss

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
           
A collection of loss functions for pytorch and torchvision
"""

import copy

import torch
from torch import nn
from torch.nn import functional


class ConfidenceLoss(nn.Module):
    """Softmax cross-entropy loss with an auxiliary learned confidence output.

    An almost copy paste of
    https://github.com/aivclab/confidence_classification_loss_pytorch

    Implements "Learning Confidence for Out-of-Distribution Detection in
    Neural Networks" (https://arxiv.org/pdf/1802.04865.pdf): the network
    emits K class logits plus one extra "confidence" logit; low confidence
    lets the prediction be pulled toward the target during training, at the
    price of a confidence penalty weighted by lambda.

    The nn.Module internal state "training" determines when lambda is
    updated — you must call .eval() and .train() yourself.
    """

    def __init__(self, hint_budget=0.3, lmbda=0.1):
        """
        Parameters
        ---------------------
        hint_budget : float
            Refer to the paper.
        lambda : float
            Refer to the paper.
        """
        super().__init__()
        self.hint_budget = hint_budget
        # Remember the starting value so reset() can restore it later.
        self._initial_lmbda = copy.copy(lmbda)
        self.lmbda = lmbda

    def reset(self):
        """Set lambda back to its initial value."""
        self.lmbda = copy.copy(self._initial_lmbda)

    def _update_lmbda(self, conf_loss):
        """Nudge lambda toward the hint budget.

        No-op unless the module is in training mode — call .eval() to
        freeze lambda.
        """
        if not self.training:
            return
        if conf_loss.item() < self.hint_budget:
            # Confidence penalty is cheap enough: relax lambda.
            self.lmbda = self.lmbda / 1.01
        else:
            # Confidence penalty exceeds the budget: tighten lambda.
            self.lmbda = self.lmbda / 0.99

    @classmethod
    def predict(cls, input):
        """Compute prediction pseudo probabilities and confidence from logits.

        Parameters
        ----------
        input : torch.tensor (BxK) float where K=<number_of_classes> + 1
            Classification logits + confidence logit (last position).

        Returns
        -------
        (prediction, conf)
            Softmax probabilities over the classes and a sigmoid confidence
            level in (0, 1).

        Example
        -------
        (prediction, conf) = ConfidenceLoss.predict(logit_tensor)
        """
        class_logits = input[..., :-1]
        confidence_logit = input[..., -1]
        probabilities = torch.softmax(class_logits, dim=-1)
        confidence = torch.sigmoid(confidence_logit)
        return probabilities, confidence

    def forward(self, input, target):
        """Compute the combined classification + confidence loss.

        Parameters
        ----------
        input : torch.tensor (BxK) float where K=<number_of_classes> + 1
            Classification logits + confidence logit.
        target : torch.tensor (Bx1) long
            Target class indices.

        Returns
        -------
        Scalar loss value (also updates lambda when in training mode).
        """
        probabilities, confidence = self.predict(input)
        one_hot_target = functional.one_hot(target, probabilities.shape[-1])

        # Clamp away exact 0/1 so the logs below stay finite.
        eps = 1e-12
        probabilities = torch.clamp(probabilities, eps, 1 - eps)
        confidence = torch.clamp(confidence, eps, 1 - eps)

        # Randomly set half of predictions to 100% confidence ("hints" are
        # only given to the other half, as in the paper).
        b = torch.empty_like(confidence).uniform_(0, 1).round()
        hinted_confidence = confidence * b + 1 - b

        # Interpolate between the prediction and the ground truth according
        # to the (possibly hinted) confidence.
        weight = hinted_confidence[:, None]
        interpolated = probabilities * weight + one_hot_target * (1 - weight)

        classification_loss = functional.nll_loss(interpolated.log(), target)
        confidence_loss = -torch.log(hinted_confidence).mean()

        total_loss = classification_loss + self.lmbda * confidence_loss
        self._update_lmbda(confidence_loss)
        return total_loss