# Source code for neodroidvision.classification.architectures.other_retrain
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
# Public API of this module.
__all__ = ["other_retrain"]

# Module metadata.
__doc__ = r"""
           Created on 13/11/2019
           """
__author__ = "Christian Heider Nielsen"
def other_retrain(
    arch: str, model: torch.nn.Module, num_classes: int
) -> torch.nn.Module:
    """
    Replace the final classification layer(s) of ``model`` with freshly
    initialised ``torch.nn.Linear`` layers producing ``num_classes`` outputs.

    The model is modified in place; it is also returned for convenience.
    Each new layer keeps the ``in_features`` of the layer it replaces.

    Architecture families, matched by the prefix of ``arch``:

    * ``alexnet`` / ``vgg`` -- replaces ``model.classifier[6]``
    * ``densenet``          -- replaces ``model.classifier``
    * ``inception``         -- replaces ``model.fc`` and, when the model has a
      non-``None`` ``AuxLogits`` sub-module, also ``model.AuxLogits.fc``

    For any other ``arch`` the model is returned unchanged.

    :param arch: architecture name, e.g. ``"vgg16"`` or ``"densenet121"``
    :param model: model whose classification head is replaced
    :param num_classes: number of outputs of the new head(s)
    :return: the same ``model`` instance that was passed in
    """

    def _fresh_linear(old: torch.nn.Module) -> torch.nn.Linear:
        # The new head must accept the same feature width as the old one.
        return torch.nn.Linear(old.in_features, num_classes)

    if arch.startswith(("alexnet", "vgg")):
        family = "AlexNet" if arch.startswith("alexnet") else "VGG"
        model.classifier[6] = _fresh_linear(model.classifier[6])
        print(f"=> reshaped {family} classifier layer with: {model.classifier[6]}")
    elif arch.startswith("densenet"):
        model.classifier = _fresh_linear(model.classifier)
        print(f"=> reshaped DenseNet classifier layer with: {model.classifier}")
    elif arch.startswith("inception"):
        # The auxiliary head is optional; e.g. torchvision's Inception3 built
        # with aux_logits=False stores AuxLogits = None. Only replace it when
        # present instead of crashing with AttributeError.
        aux = getattr(model, "AuxLogits", None)
        if aux is not None:
            aux.fc = _fresh_linear(aux.fc)
            print(f"=> reshaped Inception auxiliary classifier layer with: {aux.fc}")
        model.fc = _fresh_linear(model.fc)
        print(f"=> reshaped Inception classifier layer with: {model.fc}")
    return model