Source code for neodroidvision.utilities.torch_utilities.persistence.check_pointer

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 23/03/2020
           """

import logging
from pathlib import Path
from typing import Any

import torch
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel

from neodroidvision.utilities.torch_utilities.persistence.custom_model_caching import (
    custom_cache_url,
)

__all__ = ["CheckPointer"]


class CheckPointer:
    """Saves and restores model, optimiser and scheduler state dicts, tagging the most recently written checkpoint."""

    _last_checkpoint_name = "last_checkpoint.txt"
    def __init__(
        self,
        model: Module,
        optimiser: torch.optim.Optimizer = None,
        scheduler: torch.optim.lr_scheduler._LRScheduler = None,
        save_dir: Path = Path.cwd(),
        save_to_disk: bool = None,
        logger: logging.Logger = None,
    ):
        """
        :param model: model whose state dict is checkpointed
        :type model: Module
        :param optimiser: optional optimiser to checkpoint alongside the model
        :type optimiser: torch.optim.Optimizer
        :param scheduler: optional learning-rate scheduler to checkpoint
        :type scheduler: torch.optim.lr_scheduler._LRScheduler
        :param save_dir: directory checkpoints are written to
        :type save_dir: Path
        :param save_to_disk: whether checkpoints are actually written, e.g. only on the main process
        :type save_to_disk: bool
        :param logger: logger to use; defaults to a module-level logger
        :type logger: logging.Logger"""
        self.model = model
        self.optimiser = optimiser
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
        if logger is None:
            logger = logging.getLogger(__name__)
        self.logger = logger
    def save(self, name, **kwargs) -> None:
        """
        Serialise the model (unwrapping DistributedDataParallel), optimiser and scheduler state to
        ``{save_dir}/{name}.pth`` and tag it as the last checkpoint.

        Args:
            name: file stem of the checkpoint
            **kwargs: extra entries (e.g. an iteration count) stored in the checkpoint dict
        """
        if not self.save_dir:
            return

        if not self.save_to_disk:
            return

        data = {}
        if isinstance(self.model, DistributedDataParallel):
            data["model"] = self.model.module.state_dict()
        else:
            data["model"] = self.model.state_dict()

        if self.optimiser is not None:
            data["optimiser"] = self.optimiser.state_dict()
        if self.scheduler is not None:
            data["scheduler"] = self.scheduler.state_dict()
        data.update(kwargs)

        save_file = self.save_dir / f"{name}.pth"
        self.logger.info(f"Saving checkpoint to {save_file}")
        torch.save(data, save_file)
        self.tag_last_checkpoint(save_file)
    def load(self, f: Path = None, use_latest: bool = True) -> dict:
        """
        Restore model, optimiser and scheduler state from a checkpoint.

        Args:
            f: checkpoint path or url; may be omitted when a tagged checkpoint exists
            use_latest: prefer the checkpoint tagged in ``last_checkpoint.txt`` over ``f``

        Returns:
            any remaining checkpoint entries, e.g. an iteration count
        """
        if (self.save_dir / self._last_checkpoint_name).exists() and use_latest:
            # override argument with existing checkpoint
            f = self.get_checkpoint_file()

        if f is None or str(f) in ("", "None"):
            # no checkpoint could be found
            self.logger.info("No checkpoint found.")
            return {}

        f = str(f)
        self.logger.info(f"Loading checkpoint from {f}")
        checkpoint = self._load_file(f)
        model = self.model
        if isinstance(model, DistributedDataParallel):
            model = self.model.module
        model.load_state_dict(checkpoint.pop("model"))
        if "optimiser" in checkpoint and self.optimiser:
            self.logger.info(f"Loading optimiser from {f}")
            self.optimiser.load_state_dict(checkpoint.pop("optimiser"))
        if "scheduler" in checkpoint and self.scheduler:
            self.logger.info(f"Loading scheduler from {f}")
            self.scheduler.load_state_dict(checkpoint.pop("scheduler"))

        # return any further checkpoint data
        return checkpoint
    def get_checkpoint_file(self) -> str:
        """
        Returns:
            the checkpoint path recorded in ``last_checkpoint.txt``, or an empty string if none is recorded
        """
        try:
            with open(str(self.save_dir / self._last_checkpoint_name), "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if the file doesn't exist, it may have just been
            # deleted by a separate process
            last_saved = ""
        return last_saved
    def tag_last_checkpoint(self, last_filename) -> None:
        """
        Record ``last_filename`` in ``last_checkpoint.txt`` so subsequent loads can find it.

        Args:
            last_filename: path of the checkpoint that was just written
        """
        with open(str(self.save_dir / self._last_checkpoint_name), "w") as f:
            f.write(str(last_filename))  # last_filename may be a Path; write() expects str
    def _load_file(self, f: str) -> Any:
        if f.startswith("http"):
            # if the file is a url path, download it and cache it
            cached = custom_cache_url(f)
            self.logger.info(f"url {f} cached in {cached}")
            f = cached
        return torch.load(f, map_location=torch.device("cpu"))
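
# A minimal round-trip sketch, not part of the original module: the torchvision
# resnet18 model, the SGD optimiser and the "checkpoints" directory below are
# illustrative assumptions chosen only to exercise save() and load().
if __name__ == "__main__":
    from torchvision.models import resnet18

    save_dir = Path.cwd() / "checkpoints"  # hypothetical location
    save_dir.mkdir(exist_ok=True)

    model = resnet18()
    optimiser = torch.optim.SGD(model.parameters(), lr=1e-2)
    check_pointer = CheckPointer(
        model, optimiser=optimiser, save_dir=save_dir, save_to_disk=True
    )

    check_pointer.save("model_0001", iteration=1)  # writes model_0001.pth and tags it
    extra = check_pointer.load()  # restores model/optimiser state in place
    print(extra)  # leftover entries, here {"iteration": 1}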