Source code for neodroidvision.regression.metric.contrastive.pair_ranking

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 27/07/2020
           """

__all__ = ["PairRankingSiamese"]

from typing import Tuple, Union

import torch
from draugr.torch_utilities import conv2d_hw_shape, pad2d_hw_shape
from numpy import product
from torch import nn


[docs]class PairRankingSiamese(nn.Module): """description"""
[docs] def __init__( self, in_size: Union[int, Tuple[int, int]] = (105, 105), output_size: int = 1, input_channels: int = 1, ): super().__init__() flat_lin_size = 32 * product( conv2d_hw_shape( pad2d_hw_shape( conv2d_hw_shape( pad2d_hw_shape( conv2d_hw_shape(pad2d_hw_shape(in_size, 1), 3), 1 ), 3, ), 1, ), 3, ) ) self.concat_merge = True if self.concat_merge: flat_lin_size *= 2 self.mapping = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(input_channels, 64, 3), nn.ReLU(inplace=True), nn.BatchNorm2d(64), nn.Dropout2d(p=0.2), nn.ReflectionPad2d(1), nn.Conv2d(64, 64, 3), nn.ReLU(inplace=True), nn.BatchNorm2d(64), nn.Dropout2d(p=0.2), nn.ReflectionPad2d(1), nn.Conv2d(64, 32, 3), nn.ReLU(inplace=True), nn.BatchNorm2d(32), nn.Dropout2d(p=0.2), nn.Flatten(), ) self.head = nn.Sequential( nn.Linear(flat_lin_size, 512), nn.ReLU(), nn.Linear(512, output_size), nn.Sigmoid(), )
[docs] def forward(self, anchor: torch.Tensor, other: torch.Tensor) -> torch.Tensor: """ Args: anchor: other: Returns: """ if self.concat_merge: dist = torch.cat((self.mapping(anchor), self.mapping(other)), 1) else: dist = torch.abs(self.mapping(anchor) - self.mapping(other)) return self.head(dist)