# Source code for neodroidvision.classification.architectures.resnet_retrain
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Callable, List, Tuple

import torch
import torchvision
# Module metadata; the module docstring is set via an explicit ``__doc__``
# assignment rather than a leading string literal.
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 11/11/2019
"""
from draugr.torch_utilities import (
set_all_parameter_requires_grad,
set_first_n_parameter_requires_grad,
trainable_parameters,
)
# Public API: ``from <this module> import *`` exports only ``resnet_retrain``.
__all__ = ["resnet_retrain"]
from torch.nn.parameter import Parameter
from torchvision.models import ResNet
def resnet_retrain(
    num_classes: int,
    freeze_first_num: int = 6,
    pretrained: bool = True,
    resnet_factory: Callable[..., ResNet] = torchvision.models.resnet18,
) -> Tuple[ResNet, List[Parameter]]:
    """
    Build a ResNet for transfer learning with a freshly initialised classifier head.

    The backbone is created with ``resnet_factory(pretrained=pretrained)``.
    Some or all of its parameters are then frozen according to
    ``freeze_first_num``. Finally ``model.fc`` is replaced by a new
    ``torch.nn.Linear`` with ``num_classes`` outputs.

    Args:
      num_classes: Number of output features of the new ``fc`` layer.
      freeze_first_num: Selects how the backbone is frozen *before* the head is
        replaced:

        * ``0`` calls ``set_all_parameter_requires_grad(model)`` with its
          default arguments. This presumably freezes every backbone parameter,
          so only the new head trains (confirm against draugr).
        * A positive value is forwarded to
          ``set_first_n_parameter_requires_grad(model, freeze_first_num)``.
        * A negative value leaves every ``requires_grad`` flag untouched.
      pretrained: Forwarded to ``resnet_factory`` as the ``pretrained`` keyword.
        NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour
        of ``weights=``; it is kept here for backward compatibility.
      resnet_factory: Callable that accepts a ``pretrained`` keyword and
        returns a model exposing an ``fc`` layer with an ``in_features``
        attribute, e.g. ``torchvision.models.resnet18``.

    Returns:
      A ``(model, params)`` tuple: the modified model and
      ``trainable_parameters(model)`` computed after the head swap.
    """
    model = resnet_factory(pretrained=pretrained)

    if freeze_first_num == 0:
        # 0 acts as the "freeze the whole backbone" setting (feature extraction).
        set_all_parameter_requires_grad(model)
    elif freeze_first_num > 0:
        set_first_n_parameter_requires_grad(model, freeze_first_num)
    # Negative values: no freezing at all; the full network is fine-tuned.

    # The head is swapped *after* freezing, so the new layer keeps torch's
    # default requires_grad=True and is always trainable.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

    return model, trainable_parameters(model)