Source code for fusionlab.encoders.alexnet.alexnet

import torch
import torch.nn as nn
from fusionlab.layers.factories import ConvND, MaxPool

# Official pytorch ref: https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py
[docs] class AlexNet(nn.Module): def __init__(self, cin=3, spatial_dims=2): super().__init__() self.features = nn.Sequential( ConvND(spatial_dims, cin, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=3, stride=2), ConvND(spatial_dims, 64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=3, stride=2), ConvND(spatial_dims, 192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), ConvND(spatial_dims, 256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), MaxPool(spatial_dims, kernel_size=3, stride=2), )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: return self.features(x)
if __name__ == '__main__': img_size = 224 inputs = torch.normal(0, 1, (1, 3, img_size, img_size)) output = AlexNet(3, spatial_dims=2)(inputs) print(output.shape) assert list(output.shape) == [1, 256, 6, 6] inputs = torch.normal(0, 1, (1, 3, img_size)) output = AlexNet(3, spatial_dims=1)(inputs) print(output.shape) assert list(output.shape) == [1, 256, 6] img_size = 128 inputs = torch.normal(0, 1, (1, 3, img_size, img_size, img_size)) output = AlexNet(3, spatial_dims=3)(inputs) print(output.shape) assert list(output.shape) == [1, 256, 3, 3, 3]