Source code for neodroidvision.classification.mechanims.attention.self.spectral_norm
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Factory helpers that build ``torch.nn`` layers (``Conv2d``, ``Linear``,
``Embedding``) and return them wrapped with ``torch.nn.utils.spectral_norm``.
"""
from torch import nn
from torch.nn.utils import spectral_norm
# Public factory functions exported by ``from <module> import *``.
__all__ = [
    "spectral_norm_conv2d",
    "spectral_norm_linear",
    "spectral_norm_embedding",
]
def spectral_norm_conv2d(
    in_channels,
    out_channels,
    kernel_size,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
):
    """
    Build an ``nn.Conv2d`` layer and wrap it with ``torch.nn.utils.spectral_norm``.

    Every argument is passed through unchanged to ``nn.Conv2d``; see
    ``torch.nn.Conv2d`` for their meaning.

    Args:
      in_channels: ``in_channels`` of the convolution.
      out_channels: ``out_channels`` of the convolution.
      kernel_size: ``kernel_size`` of the convolution.
      stride: ``stride`` of the convolution (default 1).
      padding: ``padding`` of the convolution (default 0).
      dilation: ``dilation`` of the convolution (default 1).
      groups: ``groups`` of the convolution (default 1).
      bias: whether the convolution has a bias term (default True).

    Returns:
      The ``nn.Conv2d`` module returned by ``spectral_norm``, which is called
      with its default settings.
    """
    # Everything after the three required positional arguments is keyword-only
    # configuration for the convolution.
    conv_options = dict(
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
    return spectral_norm(convolution)
def spectral_norm_linear(in_features, out_features, bias=True):
    """
    Build an ``nn.Linear`` layer and wrap it with ``torch.nn.utils.spectral_norm``.

    Args:
      in_features: ``in_features`` of the linear layer.
      out_features: ``out_features`` of the linear layer.
      bias: whether the linear layer has a bias term (default True). This
        mirrors the ``bias`` option of ``spectral_norm_conv2d``.

    Returns:
      The ``nn.Linear`` module returned by ``spectral_norm``, which is called
      with its default settings.
    """
    return spectral_norm(
        nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
    )
def spectral_norm_embedding(num_embeddings, embedding_dim):
    """
    Build an ``nn.Embedding`` table and wrap it with ``torch.nn.utils.spectral_norm``.

    Args:
      num_embeddings: ``num_embeddings`` of the embedding table.
      embedding_dim: ``embedding_dim`` of the embedding table.

    Returns:
      The ``nn.Embedding`` module returned by ``spectral_norm``, which is
      called with its default settings.
    """
    embedding = nn.Embedding(num_embeddings, embedding_dim)
    return spectral_norm(embedding)