Source code for neodroidvision.classification.mechanims.attention.self.self_attention

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from torch import nn

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
           """

from torch.nn.init import xavier_uniform_

from neodroidvision.classification.mechanims.attention.self.spectral_norm import (
    spectral_norm_conv2d,
)

__all__ = ["init_weights", "SelfAttentionModule"]


def init_weights(m):
    """Xavier-initialise linear and convolutional layers; biases are zeroed.

    Args:
        m: module to initialise; only ``nn.Linear`` and ``nn.Conv2d`` are touched.
    """
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        xavier_uniform_(m.weight)
        m.bias.data.fill_(0.0)
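
# Usage sketch (illustrative addition, not part of the original module): the
# initialiser is typically applied recursively via ``nn.Module.apply``, e.g.
# ``model.apply(init_weights)``. Note it assumes the matched layers were created
# with a bias term.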


class SelfAttentionModule(nn.Module):
    """Self-attention layer."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        # 1x1 convolutions with spectral normalisation for the theta ("query"),
        # phi ("key") and g ("value") projections, plus the output projection.
        self.spectral_norm_conv1x1_theta = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=in_channels // 8,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.spectral_norm_conv1x1_phi = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=in_channels // 8,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.spectral_norm_conv1x1_g = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=in_channels // 2,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.spectral_norm_conv1x1_attn = spectral_norm_conv2d(
            in_channels=in_channels // 2,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(
            dim=-1
        )  # TODO: use log_softmax?, Check dim maybe it should be 1
        self.sigma = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x):
        """
        Args:
            x: input feature maps of shape (B, C, H, W)

        Returns:
            out: self-attention value added to the input feature maps, shape (B, C, H, W)
        """
        _, ch, h, w = x.size()

        # Theta ("query") path: (B, C//8, H*W)
        theta = self.spectral_norm_conv1x1_theta(x)
        theta = theta.reshape(-1, ch // 8, h * w)

        # Phi ("key") path, downsampled by max-pooling: (B, C//8, H*W//4)
        phi = self.spectral_norm_conv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.reshape(-1, ch // 8, h * w // 4)

        # Attention map: (B, H*W, H*W//4)
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)

        # G ("value") path, downsampled by max-pooling: (B, C//2, H*W//4)
        g = self.spectral_norm_conv1x1_g(x)
        g = self.maxpool(g)
        g = g.reshape(-1, ch // 2, h * w // 4)

        # Attention-weighted values, reshaped and projected back to C channels
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.reshape(-1, ch // 2, h, w)
        attn_g = self.spectral_norm_conv1x1_attn(attn_g)

        # Residual connection scaled by the learnable parameter sigma
        out = x + self.sigma * attn_g
        return out
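

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, assuming ``spectral_norm_conv2d``
    # behaves like a spectrally-normalised 1x1 ``nn.Conv2d``): run the attention
    # block on a random feature map and verify the output shape matches the input.
    batch, channels, height, width = 2, 64, 16, 16  # channels must be divisible by 8
    feature_map = torch.randn(batch, channels, height, width)
    attention = SelfAttentionModule(in_channels=channels)
    out = attention(feature_map)
    assert out.shape == feature_map.shape
    print(out.shape)  # torch.Size([2, 64, 16, 16])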