#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Build a COCO-pretrained torchvision Mask R-CNN whose box and mask heads are
replaced for a custom number of categories.

           Created on 07/03/2020
           """

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

__all__ = ["get_model_instance_segmentation"]

[docs]def get_model_instance_segmentation( num_categories: int, hidden_layer: int = 256 ) -> torch.nn.Module: """ Args: num_categories: hidden_layer: Returns: """ # load an instance segmentation model pre-trained pre-trained on COCO model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) # get number of input features for the classifier in_features = model.roi_heads.box_predictor.cls_score.in_features # replace the pre-trained head with a new one model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_categories) # now get the number of input features for the mask classifier in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels # and replace the mask predictor with a new one model.roi_heads.mask_predictor = MaskRCNNPredictor( in_features_mask, hidden_layer, num_categories ) return model