# Source code for modulation.classification.procedure
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 02-12-2020
"""
from typing import Optional

import torch
from torch.nn import Module
from torch.nn.functional import nll_loss
from torch.utils.data import DataLoader

from draugr.numpy_utilities import SplitEnum
from draugr.torch_utilities import (
    TorchEvalSession,
    TorchTrainSession,
    global_torch_device,
)
from draugr.visualisation import progress_bar
from draugr.writers import Writer
__all__ = ['single_epoch_fitting', 'single_epoch_evaluation']
def single_epoch_fitting(
    model: torch.nn.Module,
    optimiser,
    train_loader_,
    *,
    epoch: Optional[int] = None,
    writer: Writer = None,
    device_: torch.device = global_torch_device(),
) -> None:
    """Run one full training epoch over ``train_loader_`` and optionally log the mean loss.

    The model's forward output is fed to ``nll_loss``, so it is expected to
    produce log-probabilities (e.g. via ``log_softmax``) —
    TODO confirm against the models used with this procedure.

    :param model: network being trained; moved data is fed through it directly.
    :param optimiser: optimiser whose ``zero_grad``/``step`` drive the update.
    :param train_loader_: iterable of ``(data, target)`` batches; must support ``len()``.
    :param epoch: global step index passed to the writer (may be ``None``).
    :param writer: optional scalar writer; when given, the epoch-mean loss is
        logged under the tag ``"loss"``.
    :param device_: device the batches are moved to.
        NOTE(review): the default is evaluated once at import time, so a later
        change of the global device does not affect this default — confirm intended.
    :return: ``None``; progress is reported only through ``writer``.
    """
    accum_loss = 0.0
    num_batches = len(train_loader_)
    with TorchTrainSession(model):  # presumably puts the model in train mode for the scope
        for batch_idx, (data, target) in progress_bar(
            enumerate(train_loader_), description="train batch #", total=num_batches
        ):
            # Negative log-likelihood; squeeze collapses a singleton output dim
            # (model output assumed shaped (batch, 1, n_output) per original comment).
            loss = nll_loss(model(data.to(device_)).squeeze(), target.to(device_))
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            accum_loss += loss.item()
    # Guard num_batches: an empty loader would otherwise raise ZeroDivisionError.
    if writer and num_batches:
        writer.scalar("loss", accum_loss / num_batches, epoch)
def single_epoch_evaluation(
    model: Module,
    evaluation_loader: DataLoader,
    subset: SplitEnum,
    *,
    epoch: Optional[int] = None,
    writer: Writer = None,
    device: torch.device = global_torch_device(),
) -> float:
    """Compute top-1 accuracy of ``model`` over ``evaluation_loader`` and optionally log it.

    Predictions are taken as ``argmax`` over the last output dimension and
    compared element-wise against the targets.

    :param model: network to evaluate.
    :param evaluation_loader: loader of ``(data, target)`` batches; its
        ``.dataset`` length is used as the accuracy denominator.
    :param subset: split being evaluated; used only in the progress description
        and the writer tag ``f"{subset}_accuracy"``.
    :param epoch: global step index passed to the writer (may be ``None``).
    :param writer: optional scalar writer for the accuracy value.
    :param device: device the batches are moved to.
        NOTE(review): default is bound once at import time — see fitting procedure.
    :return: fraction of correctly classified samples in ``[0, 1]``.
    """
    correct = 0
    num_batches = len(evaluation_loader)
    # NOTE(review): assuming TorchEvalSession also disables grad tracking;
    # if it only toggles eval mode, wrapping in torch.no_grad() would save memory — confirm.
    with TorchEvalSession(model):
        for data, target in progress_bar(
            evaluation_loader, description=f"{subset} batch #", total=num_batches
        ):
            predictions = model(data.to(device)).argmax(dim=-1).squeeze()
            correct += predictions.eq(target.to(device)).sum().item()
    # Guard against an empty dataset to avoid ZeroDivisionError.
    total_samples = len(evaluation_loader.dataset)
    acc = correct / total_samples if total_samples else 0.0
    if writer:
        writer.scalar(f"{subset}_accuracy", acc, epoch)
    return acc