Source code for neodroidvision.multitask.fission.skip_hourglass.compress
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 19-09-2021
"""
from typing import Callable, Tuple

import torch
from torch import nn
__all__ = ["Compress"]
class Compress(nn.Module):
    """
    Encoder ("compress") stage: two 3x3 convolutions, each followed by
    ``activation``, optionally followed by a 2x2 max pool.

    Both convolutions use ``stride=1`` and ``padding=1``, so they only change the
    channel count (``in_channels`` -> ``out_channels`` -> ``out_channels``) and
    preserve the spatial size. When ``pooling`` is enabled, a
    ``MaxPool2d(kernel_size=2, stride=2)`` then halves the spatial size
    (rounding down for odd sizes).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        pooling: bool = True,
        activation: Callable[[torch.Tensor], torch.Tensor] = torch.relu,
    ):
        """
        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels produced by both convolutions.
            pooling: If True, apply 2x2 / stride-2 max pooling after the
                convolutions. If False, no pooling layer is created and
                ``forward`` returns the activated output unpooled.
            activation: Callable applied after each convolution
                (``torch.relu`` by default).
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.activation = activation

        # 3x3 kernel with padding 1 and stride 1 keeps H and W unchanged.
        self.conv1 = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=1,
        )
        self.conv2 = nn.Conv2d(
            self.out_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=1,
        )

        # The pool layer only exists when pooling is enabled; forward() checks
        # self.pooling before touching self.pool.
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run conv -> activation -> conv -> activation, then optionally max-pool.

        Args:
            x: Input tensor with ``in_channels`` channels, e.g. shape
                (N, in_channels, H, W).

        Returns:
            A tuple ``(out, before_pool)``:

            - ``before_pool``: activated output of the second convolution,
              shape (N, out_channels, H, W).
            - ``out``: ``before_pool`` max-pooled to
              (N, out_channels, H // 2, W // 2) when ``pooling`` is True.
              Otherwise it is the very same tensor object as ``before_pool``.
        """
        before_pool = self.activation(self.conv2(self.activation(self.conv1(x))))
        # before_pool is returned separately, presumably so a decoder stage can
        # use it as a skip connection (module lives under skip_hourglass) --
        # verify against callers.
        out = self.pool(before_pool) if self.pooling else before_pool
        return out, before_pool