Source code for neodroidvision.classification.architectures.other_retrain
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 13/11/2019
"""
__all__ = ["other_retrain"]
[docs]def other_retrain(
arch: str, model: torch.nn.Module, num_classes: int
) -> torch.nn.Module:
"""
Inplace op but returns the model anyway
"""
if arch.startswith("alexnet"):
model.classifier[6] = torch.nn.Linear(
model.classifier[6].in_features, num_classes
)
print(f"=> reshaped AlexNet classifier layer with: {str(model.classifier[6])}")
elif arch.startswith("vgg"):
model.classifier[6] = torch.nn.Linear(
model.classifier[6].in_features, num_classes
)
print(f"=> reshaped VGG classifier layer with: {str(model.classifier[6])}")
elif arch.startswith("densenet"):
model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
print(f"=> reshaped DenseNet classifier layer with: {str(model.classifier)}")
elif arch.startswith("inception"):
model.AuxLogits.fc = torch.nn.Linear(
model.AuxLogits.fc.in_features, num_classes
)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
return model