Source code for neodroidvision.regression.vae.architectures.vae

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from torch import nn
from torch.nn.init import kaiming_normal_

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Variational auto-encoder scaffolding: a ``VAE`` class that records the latent
size and provides a ``View`` reshape module plus weight-initialisation helpers.
           """

from draugr.torch_utilities import VariationalAutoEncoder

# Public names exported by ``from <this module> import *``.
__all__ = ["VAE"]


class VAE(VariationalAutoEncoder):
    """Variational auto-encoder scaffold with shared layer-initialisation helpers.

    Stores the latent dimensionality in ``self._latent_size``. It also exposes:

    * a nested ``View`` module for reshaping inside layer stacks;
    * static initialisers usable on individual modules.

    Concrete encoder/decoder layers are presumably defined by subclasses
    (NOTE(review): confirm against the ``VariationalAutoEncoder`` base).
    """

    def __init__(self, latent_size=10):
        """
        Args:
          latent_size: Dimensionality of the latent code; stored as
            ``self._latent_size``.
        """
        super().__init__()
        self._latent_size = latent_size

    class View(nn.Module):
        """Module that reshapes its input to a fixed target shape."""

        def __init__(self, size):
            """
            Args:
              size: Target shape handed to ``Tensor.reshape`` (may contain a
                single ``-1`` for an inferred dimension).
            """
            super().__init__()
            self.size = size

        def forward(self, tensor):
            """
            Args:
              tensor: Input tensor; its element count must be compatible with
                ``self.size``.

            Returns:
              ``tensor`` reshaped to ``self.size``.
            """
            return tensor.reshape(self.size)

    @staticmethod
    def kaiming_init(m):
        """Kaiming-normal initialisation for a single module.

        * ``nn.Linear`` / ``nn.Conv2d``: the weight is drawn with
          ``kaiming_normal_`` and the bias (if present) is zeroed.
        * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: the weight (if present) is
          set to 1 and the bias (if present) is zeroed.
        * Any other module type is left untouched.

        Args:
          m: Module to initialise.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            kaiming_normal_(m.weight)
            # Layers built with ``bias=False`` have ``m.bias is None``.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # ``affine=False`` batch norms carry no weight/bias parameters.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    @staticmethod
    def normal_init(m, mean, std):
        """Gaussian initialisation for a single module.

        * ``nn.Linear`` / ``nn.Conv2d``: the weight is drawn from
          ``N(mean, std**2)`` and the bias (if present) is zeroed.
        * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: the weight (if present) is
          set to 1 and the bias (if present) is zeroed; ``mean``/``std`` are
          not used for these.
        * Any other module type is left untouched.

        Args:
          m: Module to initialise.
          mean: Mean of the normal distribution for the weights.
          std: Standard deviation of the normal distribution for the weights.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(m.weight, mean, std)
            # Check the parameter itself: ``m.bias.data`` would raise
            # AttributeError when the layer was built with ``bias=False``.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            # ``affine=False`` batch norms carry no weight/bias parameters.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def weight_init(self):
        """Apply ``kaiming_init`` to every module yielded by ``self.modules()``."""
        for m in self.modules():
            self.kaiming_init(m)