Source code for fusionlab.layers.squeeze_excitation.se
from typing import Callable
import torch
from torch import Tensor
import torch.nn as nn
from fusionlab.layers import ConvND, AdaptiveAvgPool
[docs]
class SEModule(nn.Module):
"""
source: https://github.com/pytorch/vision/blob/main/torchvision/ops/misc.py#L224
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
Args:
input_channels (int): Number of channels in the input image
squeeze_channels (int): Number of squeeze channels
activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
"""
def __init__(
self,
input_channels: int,
squeeze_channels: int,
act_layer: Callable[..., torch.nn.Module] = torch.nn.ReLU,
scale_layer: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
spatial_dims: int = 2,
) -> None:
super().__init__()
self.avgpool = AdaptiveAvgPool(spatial_dims, 1)
self.fc1 = ConvND(spatial_dims, input_channels, squeeze_channels, kernel_size=1)
self.fc2 = ConvND(spatial_dims, squeeze_channels, input_channels, kernel_size=1)
self.act_layer = act_layer()
self.scale_layer = scale_layer()
def _scale(self, input: Tensor) -> Tensor:
scale = self.avgpool(input)
scale = self.fc1(scale)
scale = self.act_layer(scale)
scale = self.fc2(scale)
return self.scale_layer(scale)
[docs]
def forward(self, input: Tensor) -> Tensor:
scale = self._scale(input)
return scale * input
if __name__ == '__main__':
print('SEModule')
inputs = torch.normal(0, 1, (1, 256, 16, 16))
layer = SEModule(256)
outputs = layer(inputs)
assert list(outputs.shape) == [1, 256, 16, 16]
inputs = torch.normal(0, 1, (1, 256, 16, 16, 16))
layer = SEModule(256, spatial_dims=3)
outputs = layer(inputs)
assert list(outputs.shape) == [1, 256, 16, 16, 16]