# Source code for neodroidvision.classification.architectures.resnet_retrain

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Callable, List, Tuple

import torch
import torchvision

# Module metadata: author and a minimal module docstring (currently only the
# creation date), assigned explicitly via __doc__.
__author__ = "Christian Heider Nielsen"
__doc__ = r"""

           Created on 11/11/2019
           """

from draugr.torch_utilities import (
    set_all_parameter_requires_grad,
    set_first_n_parameter_requires_grad,
    trainable_parameters,
)

# Public API: only resnet_retrain is exported by `from ... import *`.
__all__ = ["resnet_retrain"]

from torch.nn.parameter import Parameter
from torchvision.models import ResNet


def resnet_retrain(
    num_classes: int,
    freeze_first_num: int = 6,
    pretrained: bool = True,
    resnet_factory: Callable[..., ResNet] = torchvision.models.resnet18,
) -> Tuple[ResNet, List[Parameter]]:
    """Build a ResNet for fine-tuning with a fresh ``num_classes``-way head.

    The model is created by ``resnet_factory``. Some of its parameters may
    then be passed through the draugr ``requires_grad`` helpers. Finally, its
    ``fc`` attribute is replaced by a newly constructed ``torch.nn.Linear``
    with the same ``in_features``. The helpers run *before* the head is
    replaced, so the new ``fc`` parameters keep PyTorch's default
    ``requires_grad=True``.

    Args:
        num_classes: Number of output features of the new ``fc`` layer.
        freeze_first_num: Selects which parameters are touched before the
            head is swapped:

            * ``0``: ``set_all_parameter_requires_grad(model)`` is called on
              the whole model (presumably freezing the entire backbone so
              only the new head trains; confirm against draugr).
            * ``> 0``: ``set_first_n_parameter_requires_grad`` is called
              with ``freeze_first_num`` (presumably freezing the first
              ``freeze_first_num`` parameters).
            * ``< 0``: no parameters are modified.
        pretrained: Forwarded to ``resnet_factory`` as the ``pretrained``
            keyword argument.
        resnet_factory: Callable accepting a ``pretrained`` keyword and
            returning a model whose ``fc`` attribute exposes
            ``in_features``, e.g. ``torchvision.models.resnet18`` (the
            default).

    Returns:
        A ``(model, params)`` tuple. ``params`` is a list of whatever
        ``trainable_parameters(model)`` produces, presumably the parameters
        with ``requires_grad=True``.
    """
    model = resnet_factory(pretrained=pretrained)

    # NOTE: 0 does not mean "freeze nothing". It routes to the
    # all-parameters helper; pass a negative value to skip freezing entirely.
    if freeze_first_num == 0:
        set_all_parameter_requires_grad(model)
    elif freeze_first_num > 0:
        set_first_n_parameter_requires_grad(model, freeze_first_num)

    # Replace the classification head. Parameters of a freshly constructed
    # Linear layer require grad by default, so the head is always trainable.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

    # Materialise into a list so the result honours the annotated
    # List[Parameter] return type even if trainable_parameters yields lazily.
    return model, list(trainable_parameters(model))