# Source code for neodroidvision.regression.metric.contrastive.nlet_conv_net

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Module metadata. The module docstring is assigned explicitly via __doc__
# and currently only records the creation date.
__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 30/06/2020
           """

import math
from typing import List

import torch
from draugr.torch_utilities import conv2d_hw_shape, pad2d_hw_shape
from numpy import product
from torch import nn

# Public API of this module: only the n-let network class is exported.
__all__ = ["NLetConvNet"]


class NLetConvNet(nn.Module):
    """Shared-weight convolutional network for n-let (pair, triplet, ...) inputs.

    Every tensor passed to :meth:`forward` goes through the same
    ``self.convolutions`` stack, so all members of an n-let share weights.
    """

    def __init__(self, in_size=None, output_size: int = 2):
        """
        :param in_size: Spatial size of the single-channel input images, as
            accepted by draugr's ``pad2d_hw_shape`` / ``conv2d_hw_shape``
            (presumably ``(height, width)`` -- confirm against callers).
            Required despite the ``None`` default.
        :param output_size: Size of the vector produced for each input.
        :raises ValueError: If ``in_size`` is not given.
        """
        super().__init__()
        if in_size is None:
            # The first Linear layer's input width depends on the input
            # resolution, so the network cannot be built without it.
            raise ValueError(
                "in_size must be given to size the first linear layer"
            )

        # Channel progression of the three conv stages: 1 -> 4 -> 8 -> 8.
        channels = (1, 4, 8, 8)
        pad, kernel = 1, 3

        layers = []
        spatial = in_size
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [
                nn.ReflectionPad2d(pad),
                nn.Conv2d(c_in, c_out, kernel_size=kernel),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
                nn.Dropout2d(p=0.2),
            ]
            # Track the spatial size through exactly the same pad -> conv step,
            # so the shape math cannot drift from the layer definitions.
            spatial = conv2d_hw_shape(pad2d_hw_shape(spatial, pad), kernel)

        # int(): nn.Linear expects a plain Python int for in_features.
        flat_lin_size = channels[-1] * int(math.prod(spatial))

        # Module order (and thus Sequential indices / state_dict keys) matches
        # the original hand-written layout.
        self.convolutions = nn.Sequential(
            *layers,
            nn.Flatten(),
            nn.Linear(flat_lin_size, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, output_size),
        )

    def forward(self, *n_let: torch.Tensor) -> List[torch.Tensor]:
        """Run each member of the n-let through the shared network.

        :param n_let: Input tensors, each processed independently by
            ``self.convolutions`` (assumed ``(batch, 1, *in_size)`` -- TODO
            confirm; the first conv expects 1 input channel).
        :return: One output tensor per input, in the same order.
        :rtype: List[torch.Tensor]
        """
        return [self.convolutions(x) for x in n_let]