#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 19-09-2021
"""
import torch
from torch import nn
from torch.nn.functional import interpolate
from torch.nn.init import xavier_uniform_
from neodroidvision.classification.mechanims.attention.self.self_attention import (
init_weights,
)
from neodroidvision.classification.mechanims.attention.self.spectral_norm import (
spectral_norm_conv2d,
spectral_norm_embedding,
spectral_norm_linear,
)
from neodroidvision.mixed.architectures.self_attention_network.self_attention_network import (
SelfAttentionModule,
)
class ConditionalBatchNorm2d(nn.Module):
    """Batch norm whose per-channel scale and shift are selected by a class label.

    Normalisation is done by a non-affine ``nn.BatchNorm2d``. The affine
    parameters come from one row of an ``nn.Embedding`` of width
    ``2 * num_features``. The first half of the row is the scale ``gamma`` and
    the second half is the shift ``beta``.

    See https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775
    """

    def __init__(self, num_features, num_classes):
        """
        Args:
          num_features: number of channels of the input feature map.
          num_classes: number of distinct labels (rows of the embedding table).
        """
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)
        self.embed = nn.Embedding(num_classes, num_features * 2)
        table = self.embed.weight.data
        # Start every class at the identity transform: scale 1, shift 0.
        table[:, :num_features].fill_(1.0)
        table[:, num_features:].zero_()

    def forward(self, x, y):
        """
        Args:
          x: 4-D feature map (N, num_features, H, W), as nn.BatchNorm2d requires.
          y: 1-D tensor of integer class indices, one per sample. It must be 1-D
             because the embedding output is split along dim 1.
        Returns:
          ``gamma[y] * bn(x) + beta[y]``, with gamma/beta broadcast over H and W.
        """
        normalized = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, 1)
        per_channel = (-1, self.num_features, 1, 1)
        return gamma.reshape(per_channel) * normalized + beta.reshape(per_channel)
class GenBlock(nn.Module):
    """Class-conditional residual block that upsamples its input 2x.

    Main path:  CBN -> ReLU -> 2x nearest upsample -> 3x3 SN conv
                -> CBN -> ReLU -> 3x3 SN conv
    Shortcut:   2x nearest upsample -> 1x1 SN conv
    The block returns the element-wise sum of both paths.
    """

    def __init__(self, in_channels, out_channels, num_classes):
        """
        Args:
          in_channels: channels of the incoming feature map.
          out_channels: channels produced by the block.
          num_classes: number of labels indexing the conditional batch norms.
        """
        super(GenBlock, self).__init__()
        same_3x3 = {"kernel_size": 3, "stride": 1, "padding": 1}
        pointwise = {"kernel_size": 1, "stride": 1, "padding": 0}

        # Submodules are created in the original order, so parameter
        # registration order (and state_dict order) is unchanged.
        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)
        self.relu = nn.ReLU(inplace=True)
        self.spectral_norm_conv2d1 = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=out_channels, **same_3x3
        )
        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)
        self.spectral_norm_conv2d2 = spectral_norm_conv2d(
            in_channels=out_channels, out_channels=out_channels, **same_3x3
        )
        self.spectral_norm_conv2d0 = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=out_channels, **pointwise
        )

    def forward(self, x, labels):
        """
        Args:
          x: 4-D feature map with ``in_channels`` channels.
          labels: class indices for the conditional batch norms, one per sample.
        Returns:
          Feature map with ``out_channels`` channels. It is upsampled 2x by
          nearest interpolation; the convolutions use stride 1 with same-padding.
        """
        # The in-place ReLU acts on the fresh CBN output, so ``x`` stays intact
        # for the shortcut below.
        h = self.relu(self.cond_bn1(x, labels))
        h = self.spectral_norm_conv2d1(interpolate(h, scale_factor=2, mode="nearest"))
        h = self.spectral_norm_conv2d2(self.relu(self.cond_bn2(h, labels)))
        skip = self.spectral_norm_conv2d0(
            interpolate(x, scale_factor=2, mode="nearest")
        )
        return h + skip
class Generator(nn.Module):
    """Class-conditional generator mapping a latent vector to a 3-channel image.

    A spectrally-normalised linear layer maps ``z`` to a 4x4 seed with
    ``16 * g_conv_dim`` channels. Five ``GenBlock``s then each double the
    spatial size (4 -> 128), with a self-attention module after the third
    block. The head is BatchNorm -> ReLU -> 3x3 SN conv to 3 channels -> tanh.
    """

    def __init__(self, z_dim, g_conv_dim, num_classes):
        """
        Args:
          z_dim: length of the latent vector.
          g_conv_dim: base channel width; block widths are multiples of it.
          num_classes: number of labels for the conditional batch norms.
        """
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.g_conv_dim = g_conv_dim
        # Channel widths: linear seed, then outputs of block1..block5.
        widths = [g_conv_dim * mult for mult in (16, 16, 8, 4, 2, 1)]

        self.spectral_norm_linear0 = spectral_norm_linear(
            in_features=z_dim, out_features=widths[0] * 4 * 4
        )
        self.block1 = GenBlock(widths[0], widths[1], num_classes)
        self.block2 = GenBlock(widths[1], widths[2], num_classes)
        self.block3 = GenBlock(widths[2], widths[3], num_classes)
        self.self_attn = SelfAttentionModule(widths[3])
        self.block4 = GenBlock(widths[3], widths[4], num_classes)
        self.block5 = GenBlock(widths[4], widths[5], num_classes)
        self.bn = nn.BatchNorm2d(widths[5], eps=1e-5, momentum=0.0001, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.spectral_norm_conv2d1 = spectral_norm_conv2d(
            in_channels=widths[5], out_channels=3, kernel_size=3, stride=1, padding=1
        )
        self.tanh = nn.Tanh()
        self.apply(init_weights)

    def forward(self, z, labels):
        """
        Args:
          z: latent batch of shape (n, z_dim).
          labels: class indices, one per sample, passed to every GenBlock.
        Returns:
          Images of shape (n, 3, 128, 128), squashed to [-1, 1] by tanh.
        """
        # (n, z_dim) -> (n, 16*g, 4, 4)
        h = self.spectral_norm_linear0(z).reshape(-1, self.g_conv_dim * 16, 4, 4)
        # Spatial size 4 -> 8 -> 16 -> 32.
        for block in (self.block1, self.block2, self.block3):
            h = block(h, labels)
        # Assumed shape-preserving, (n, 4*g, 32, 32), per the original annotations.
        h = self.self_attn(h)
        # Spatial size 32 -> 64 -> 128.
        for block in (self.block4, self.block5):
            h = block(h, labels)
        h = self.relu(self.bn(h))
        return self.tanh(self.spectral_norm_conv2d1(h))
class DiscriminatorOptBlock(nn.Module):
    """First discriminator block: residual block with 2x average-pool downsampling.

    Main path: 3x3 SN conv -> ReLU -> 3x3 SN conv -> 2x2 average pool.
    Shortcut:  2x2 average pool -> 1x1 SN conv.
    There is no activation before the first convolution, so the raw input is
    convolved directly.
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
          in_channels: channels of the input (e.g. 3 for RGB images).
          out_channels: channels produced by the block.
        """
        super().__init__()
        same_3x3 = {"kernel_size": 3, "stride": 1, "padding": 1}
        self.spectral_norm_conv2d1 = spectral_norm_conv2d(
            in_channels=in_channels, out_channels=out_channels, **same_3x3
        )
        self.relu = nn.ReLU(inplace=True)
        self.spectral_norm_conv2d2 = spectral_norm_conv2d(
            in_channels=out_channels, out_channels=out_channels, **same_3x3
        )
        self.down_sample = nn.AvgPool2d(2)
        self.spectral_norm_conv2d0 = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x):
        """
        Args:
          x: 4-D input with ``in_channels`` channels.
        Returns:
          Sum of the residual path and the shortcut: ``out_channels`` channels
          at half the spatial size.
        """
        # The in-place ReLU acts on the conv output, not on ``x``.
        h = self.relu(self.spectral_norm_conv2d1(x))
        h = self.down_sample(self.spectral_norm_conv2d2(h))
        skip = self.spectral_norm_conv2d0(self.down_sample(x))
        return h + skip
class DiscriminatorBlock(nn.Module):
    """Pre-activation residual discriminator block with optional 2x downsampling.

    Main path: ReLU -> 3x3 SN conv -> ReLU -> 3x3 SN conv [-> 2x2 average pool].
    Shortcut:  identity when channels match and no downsampling is requested;
               otherwise 1x1 SN conv [-> 2x2 average pool].
    """

    def __init__(self, in_channels, out_channels):
        """
        Args:
          in_channels: channels of the incoming feature map.
          out_channels: channels produced by the block.
        """
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.spectral_norm_conv2d1 = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.spectral_norm_conv2d2 = spectral_norm_conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.down_sample = nn.AvgPool2d(2)
        self.channel_mismatch = in_channels != out_channels
        # Always build the 1x1 shortcut conv. ``forward`` applies it whenever
        # ``down_sample`` is True, even when the channel counts match, so
        # creating it only on mismatch would raise AttributeError there.
        self.spectral_norm_conv2d0 = spectral_norm_conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x, down_sample: bool = True):
        """
        Args:
          x: 4-D feature map with ``in_channels`` channels. It is not modified.
          down_sample: if True, halve the spatial size of both paths with
            2x2 average pooling.
        Returns:
          Sum of the residual path and the shortcut, with ``out_channels`` channels.
        """
        # Out-of-place pre-activation. The module's in-place ReLU would
        # overwrite ``x`` itself: the shortcut would then see relu(x) instead
        # of x, and the caller's tensor would be mutated.
        h = torch.relu(x)
        h = self.spectral_norm_conv2d1(h)
        h = self.relu(h)  # in-place is safe here: h is a fresh conv output
        h = self.spectral_norm_conv2d2(h)

        shortcut = x
        if down_sample or self.channel_mismatch:
            shortcut = self.spectral_norm_conv2d0(shortcut)
        if down_sample:
            h = self.down_sample(h)
            shortcut = self.down_sample(shortcut)
        return h + shortcut
class Discriminator(nn.Module):
    """Class-conditional discriminator with a projection head.

    Residual downsampling blocks, with self-attention after ``block1``, reduce
    an (n, 3, 128, 128) input to ``16 * d_conv_dim`` feature maps. These are
    ReLU'd and summed over the spatial dims. The score per sample is
    ``linear(h) + sum(embed(labels) * h)``, i.e. an unconditional term plus
    a projection of the label embedding onto the features.
    """

    def __init__(self, d_conv_dim, num_classes):
        """
        Args:
          d_conv_dim: base channel width; block widths are multiples of it.
          num_classes: number of labels (rows of the projection embedding).
        """
        super(Discriminator, self).__init__()
        self.d_conv_dim = d_conv_dim
        self.opt_block1 = DiscriminatorOptBlock(3, d_conv_dim)
        self.block1 = DiscriminatorBlock(d_conv_dim, d_conv_dim * 2)
        self.self_attn = SelfAttentionModule(d_conv_dim * 2)
        self.block2 = DiscriminatorBlock(d_conv_dim * 2, d_conv_dim * 4)
        self.block3 = DiscriminatorBlock(d_conv_dim * 4, d_conv_dim * 8)
        self.block4 = DiscriminatorBlock(d_conv_dim * 8, d_conv_dim * 16)
        self.block5 = DiscriminatorBlock(d_conv_dim * 16, d_conv_dim * 16)
        self.relu = nn.ReLU(inplace=True)
        self.spectral_norm_linear1 = spectral_norm_linear(
            in_features=d_conv_dim * 16, out_features=1
        )
        self.spectral_norm_embedding1 = spectral_norm_embedding(
            num_classes, d_conv_dim * 16
        )
        self.apply(init_weights)
        # NOTE(review): if spectral_norm_embedding wraps the module with
        # torch.nn.utils.spectral_norm, ``.weight`` is recomputed from
        # ``weight_orig`` on each forward pass, so this init may not persist.
        # Confirm against the spectral_norm helper.
        xavier_uniform_(self.spectral_norm_embedding1.weight)

    def forward(self, x, labels):
        """
        Args:
          x: image batch of shape (n, 3, 128, 128).
          labels: 1-D tensor of class indices, one per sample.
        Returns:
          Tensor of shape (n,) with one unnormalised score per sample.
        """
        h = self.opt_block1(x)  # n x d_conv_dim x 64 x 64
        h = self.block1(h)  # n x d_conv_dim*2 x 32 x 32
        h = self.self_attn(h)  # n x d_conv_dim*2 x 32 x 32
        h = self.block2(h)  # n x d_conv_dim*4 x 16 x 16
        h = self.block3(h)  # n x d_conv_dim*8 x 8 x 8
        h = self.block4(h)  # n x d_conv_dim*16 x 4 x 4
        # Final block keeps the 4x4 resolution. The keyword must match
        # DiscriminatorBlock.forward's parameter name ``down_sample``.
        h = self.block5(h, down_sample=False)  # n x d_conv_dim*16 x 4 x 4
        h = self.relu(h)
        features = torch.sum(h, dim=[2, 3])  # n x d_conv_dim*16

        # squeeze(1) drops only the size-1 output dim, so the batch dim is
        # kept even when n == 1.
        output1 = self.spectral_norm_linear1(features).squeeze(1)  # n
        # Projection term: inner product of label embedding and features.
        h_labels = self.spectral_norm_embedding1(labels)  # n x d_conv_dim*16
        output2 = torch.sum(torch.mul(features, h_labels), dim=1)  # n
        return output1 + output2  # n