Source code for neodroidvision.utilities.torch_utilities.layers.separable_conv

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "heider"
__doc__ = r"""

           Created on 5/5/22
           """

from torch import nn

__all__ = ["SeparableConv2d"]


[docs]class SeparableConv2d(nn.Module): """description"""
[docs] def __init__( self, in_channels: int, out_channels: int, kernel_size: int = 1, stride: int = 1, padding: int = 0, onnx_compatible: bool = False, ): """ :param in_channels: :param out_channels: :param kernel_size: :param stride: :param padding: :param onnx_compatible:""" super().__init__() ReLU = nn.ReLU if onnx_compatible else nn.ReLU6 self.conv = nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels, stride=stride, padding=padding, ), nn.BatchNorm2d(in_channels), ReLU(), nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=1 ), )
[docs] def forward(self, x): """ :param x: :return: """ return self.conv(x)