Source code for fusionlab.encoders.vgg.vgg

import torch
import torch.nn as nn
from fusionlab.layers import ConvND, MaxPool

# Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py
[docs] class VGG16(nn.Module): def __init__(self, cin=3, spatial_dims=2): super().__init__() ksize = 3 self.features = nn.Sequential( ConvND(spatial_dims, cin, 64, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 64, 64, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 64, 128, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 128, 128, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 128, 256, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 256, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), )
[docs] def forward(self, x): return self.features(x)
[docs] class VGG19(nn.Module): def __init__(self, cin=3, spatial_dims=2): super().__init__() ksize = 3 self.features = nn.Sequential( ConvND(spatial_dims, cin, 64, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 64, 64, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 64, 128, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 128, 128, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 128, 256, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 256, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 512, 512, ksize, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=2, stride=2), )
[docs] def forward(self, x): return self.features(x)
if __name__ == '__main__': # VGG16 inputs = torch.normal(0, 1, (1, 3, 224)) output = VGG16(cin=3, spatial_dims=1)(inputs) shape = list(output.shape) assert shape[2:] == [7] # VGG19 inputs = torch.normal(0, 1, (1, 3, 224)) output = VGG19(cin=3, spatial_dims=1)(inputs) shape = list(output.shape) assert shape[2:] == [7] # VGG16 inputs = torch.normal(0, 1, (1, 3, 224, 224)) output = VGG16(cin=3, spatial_dims=2)(inputs) shape = list(output.shape) assert shape[2:] == [7, 7] # VGG19 inputs = torch.normal(0, 1, (1, 3, 224, 224)) output = VGG19(cin=3, spatial_dims=2)(inputs) shape = list(output.shape) assert shape[2:] == [7, 7]