Source code for neodroidvision.classification.architectures.torus.conv
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 05-03-2021
"""
__all__ = ["TorusConv2d"]
import torch
class TorusConv2d(torch.nn.Module):
    """2D convolution over a toroidal (wrap-around) grid.

    The two trailing spatial dimensions of the input are padded by wrapping
    them around on themselves (the last columns are prepended and the first
    columns appended, then the same for rows). An unpadded
    ``torch.nn.Conv2d`` is applied to the result, optionally followed by
    ``torch.nn.BatchNorm2d``.

    With odd kernel sizes (and the default stride of 1) the output has the
    same spatial size as the input.
    """

    def __init__(self, input_dim: int, output_dim: int, kernel_size, bn: bool):
        """
        Args:
            input_dim: Number of input channels.
            output_dim: Number of output channels.
            kernel_size: Kernel size as a ``(height, width)`` pair. A single
                ``int`` is accepted and used for both dimensions.
            bn: If true, apply batch normalisation after the convolution.
        """
        super().__init__()
        # Conv2d accepts a bare int, so accept one here too instead of
        # failing on the indexing below.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        # Number of rows / columns wrapped around on each side of the input.
        self.edge_size = (kernel_size[0] // 2, kernel_size[1] // 2)
        self.conv = torch.nn.Conv2d(input_dim, output_dim, kernel_size=kernel_size)
        self.bn = torch.nn.BatchNorm2d(output_dim) if bn else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Wrap-pad ``x`` and apply the convolution (and batch norm, if any).

        Args:
            x: Input tensor indexed as ``[batch, channel, row, column]``.

        Returns:
            The convolved (and optionally batch-normalised) tensor.
        """
        pad_rows, pad_cols = self.edge_size
        h = x
        # An edge size of 0 (kernel dimension of 1) must be skipped
        # explicitly: ``-0 == 0``, so ``h[..., -0:]`` would select the whole
        # tensor rather than nothing and duplicate the input along that axis.
        if pad_cols:
            h = torch.cat(
                [h[:, :, :, -pad_cols:], h, h[:, :, :, :pad_cols]],
                dim=3,
            )
        if pad_rows:
            h = torch.cat(
                [h[:, :, -pad_rows:], h, h[:, :, :pad_rows]],
                dim=2,
            )
        h = self.conv(h)
        if self.bn is not None:
            h = self.bn(h)
        return h