Source code for fusionlab.classification.vgg

# VGG Classifier
import torch
from torch import nn
from fusionlab.classification.base import CNNClassificationModel
from fusionlab.encoders import VGG16, VGG19
from fusionlab.layers import AdaptiveAvgPool

class VGG16Classifier(CNNClassificationModel):
    """Classifier assembled from a VGG16 encoder, a global average pool and a linear head.

    This class only wires up the sub-modules; no forward pass is defined
    here. It is presumably inherited from ``CNNClassificationModel`` and
    driven by the ``encoder``, ``globalpooling`` and ``head`` attributes
    set below -- TODO confirm against the base class.

    Args:
        cin: Forwarded to ``VGG16`` as its first positional argument
            (presumably the number of input channels -- confirm against
            ``fusionlab.encoders``).
        num_cls: Output size of the linear head; also stored on the
            instance as ``num_cls``.
        spatial_dims: Forwarded to both ``VGG16`` and ``AdaptiveAvgPool``.
            Defaults to 2.
    """

    def __init__(self, cin, num_cls, spatial_dims=2):
        super().__init__()
        # NOTE(review): 512 is presumably the channel width of the VGG16
        # encoder's final feature map -- verify against the encoder.
        head_in_features = 512

        self.num_cls = num_cls
        self.encoder = VGG16(cin, spatial_dims)
        # Pool to an output size of 1.
        self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
        self.head = nn.Linear(
            in_features=head_in_features,
            out_features=num_cls,
        )
class VGG19Classifier(CNNClassificationModel):
    """Classifier assembled from a VGG19 encoder, a global average pool and a linear head.

    This class only wires up the sub-modules; no forward pass is defined
    here. It is presumably inherited from ``CNNClassificationModel`` and
    driven by the ``encoder``, ``globalpooling`` and ``head`` attributes
    set below -- TODO confirm against the base class.

    Args:
        cin: Forwarded to ``VGG19`` as its first positional argument
            (presumably the number of input channels -- confirm against
            ``fusionlab.encoders``).
        num_cls: Output size of the linear head; also stored on the
            instance as ``num_cls``.
        spatial_dims: Forwarded to both ``VGG19`` and ``AdaptiveAvgPool``.
            Defaults to 2.
    """

    def __init__(self, cin, num_cls, spatial_dims=2):
        super().__init__()
        # NOTE(review): 512 is presumably the channel width of the VGG19
        # encoder's final feature map -- verify against the encoder.
        head_in_features = 512

        self.num_cls = num_cls
        # The backbone here is VGG19 (not VGG16).
        self.encoder = VGG19(cin, spatial_dims)
        # Pool to an output size of 1.
        self.globalpooling = AdaptiveAvgPool(spatial_dims, 1)
        self.head = nn.Linear(
            in_features=head_in_features,
            out_features=num_cls,
        )
if __name__ == '__main__':
    # Smoke test: for each classifier, feed a random (1, 3, 224) tensor
    # through a model built with spatial_dims=1 and two classes, then check
    # that the output has shape (1, 2).
    for classifier_cls in (VGG16Classifier, VGG19Classifier):
        inputs = torch.randn(1, 3, 224)
        model = classifier_cls(cin=3, num_cls=2, spatial_dims=1)
        outputs = model(inputs)
        assert tuple(outputs.shape) == (1, 2)