# Source code for neodroidvision.classification.loss_functions.ranking.pairwise_ranking_loss

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 30/06/2020
           """

import torch
from torch.nn import functional

__all__ = ["PairwiseRankingLoss"]


class PairwiseRankingLoss(torch.nn.Module):
    """Contrastive loss function.

    Neighbours (same category) are pulled together and non-neighbours are pushed
    apart. From http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin: float = 1.0):
        """
        :param margin: Distance beyond which a dissimilar pair stops contributing
            to the loss.
        :type margin: float
        """
        super().__init__()
        self._margin = margin

    def forward(
        self, anchor: torch.Tensor, other: torch.Tensor, is_diff: torch.Tensor
    ) -> torch.Tensor:
        """Compute the mean contrastive loss over a batch of pairs.

        If ``is_diff`` is 0, the two examples are of the same category. The loss
        term is their squared distance, so gradients point in the direction that
        pulls them together. If ``is_diff`` is 1, the loss term is the squared
        residual ``max(margin - distance, 0)``, which spreads the pair apart in
        the latent space until it is at least ``margin`` apart.

        Reduction is mean.

        :param anchor: First embeddings of each pair; compared row-wise with
            ``other`` via ``functional.pairwise_distance``.
        :type anchor: torch.Tensor
        :param other: Second embeddings of each pair.
        :type other: torch.Tensor
        :param is_diff: Per-pair label, 0 (same) / 1 (different); bool labels are
            accepted too. A label tensor with one entry per pair is reshaped to
            the distance shape, so e.g. (N, 1) labels work for (N,) distances.
        :type is_diff: torch.Tensor
        :return: Scalar mean loss.
        :rtype: torch.Tensor
        """
        distance = functional.pairwise_distance(anchor, other)

        # Cast to the distance dtype/device: `1 - bool_tensor` raises in torch,
        # and labels living on another device would fail the arithmetic below.
        is_diff = torch.as_tensor(is_diff, dtype=distance.dtype, device=distance.device)
        if is_diff.numel() == distance.numel():
            # Labels shaped e.g. (N, 1) would otherwise broadcast against the
            # (N,) distances into an (N, N) matrix and silently corrupt the mean.
            is_diff = is_diff.reshape(distance.shape)

        similar_term = (1 - is_diff) * distance**2
        # If the distance is larger than the margin (desirable), clip to 0 loss.
        dissimilar_term = is_diff * torch.clamp(self._margin - distance, min=0.0) ** 2
        return torch.mean(similar_term + dissimilar_term)