'''
Ref: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
Ref: https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py
'''
from einops import rearrange
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.ops import StochasticDepth
from fusionlab.layers import ConvND
class Block(nn.Module):
    r"""ConvNeXt block.

    There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv;
        all in (N, C, *spatial)
    (2) DwConv -> move channels last to (N, *spatial, C); LayerNorm (channels_last)
        -> Linear -> GELU -> Linear; move channels back
    (2) is used here, as it was found slightly faster in PyTorch.

    The block is residual: ``forward`` returns ``x + drop_path(branch(x))``.

    Args:
        dim (int): Number of input channels (the output has the same number).
        drop_path (float): Stochastic depth rate; 0 disables it. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value
            <= 0 disables Layer Scale. Default: 1e-6.
        spatial_dims (int): Number of spatial dimensions, forwarded to
            ``ConvND``. Default: 2
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, spatial_dims=2):
        super().__init__()
        # 7x7 (7^d) depthwise conv: groups == channels.
        self.dwconv = ConvND(spatial_dims, dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)  # channels_last layout
        # Pointwise (1x1) convs implemented as Linear layers on channels-last data,
        # with the 4x inverted-bottleneck expansion.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Per-channel Layer Scale; skipped entirely when the init value is <= 0.
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = StochasticDepth(drop_path, "row") if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the block to a channels-first tensor ``(N, C, *spatial)``."""
        shortcut = x
        x = self.dwconv(x)
        # Move channels last so LayerNorm / Linear / gamma act on the channel axis.
        x = rearrange(x, 'N C ... -> N ... C')
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = rearrange(x, 'N ... C -> N C ...')
        return shortcut + self.drop_path(x)
class ConvNeXt(nn.Module):
    r"""ConvNeXt feature extractor (no classification head).

    A PyTorch implementation of `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf, made dimension-agnostic via
    ``spatial_dims``.

    The network consists of a stem (kernel-4 / stride-4 conv + LayerNorm)
    followed by one stage of ``Block`` modules per entry of ``depths``. Every
    stage after the first is preceded by a LayerNorm + kernel-2 / stride-2 conv
    downsampling layer. ``forward`` returns the output of the last stage.

    Args:
        in_chans (int): Number of input channels. Default: 3
        depths (sequence of int): Number of blocks at each stage.
            Default: (3, 3, 9, 3)
        dims (sequence of int): Feature dimension at each stage; must be
            non-empty and the same length as ``depths``.
            Default: (96, 192, 384, 768)
        drop_path_rate (float): Final stochastic depth rate; per-block rates
            increase linearly from 0 to this value across all blocks. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        spatial_dims (int): Number of spatial dimensions, forwarded to
            ``ConvND`` and ``Block``. Default: 2

    Raises:
        ValueError: if ``dims`` is empty or ``depths`` and ``dims`` differ in length.
    """
    def __init__(
        self,
        in_chans=3,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.,
        layer_scale_init_value=1e-6,
        spatial_dims=2,
    ):
        super().__init__()
        if not dims or len(depths) != len(dims):
            raise ValueError(
                "depths and dims must be non-empty and of equal length, "
                f"got {len(depths)} and {len(dims)}")

        # Stem plus one downsampling layer in front of every later stage.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            ConvND(spatial_dims, in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            self.downsample_layers.append(nn.Sequential(
                LayerNorm(dim_in, eps=1e-6, data_format="channels_first"),
                ConvND(spatial_dims, dim_in, dim_out, kernel_size=2, stride=2),
            ))

        # One stage of residual blocks per resolution; the stochastic depth
        # rate grows linearly over all blocks of the whole network.
        self.stages = nn.ModuleList()
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        cur = 0
        for depth, dim in zip(depths, dims):
            self.stages.append(nn.Sequential(*[
                Block(dim=dim, drop_path=dp_rates[cur + j],
                      layer_scale_init_value=layer_scale_init_value,
                      spatial_dims=spatial_dims)
                for j in range(depth)
            ]))
            cur += depth

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise conv/linear layers as in the reference ConvNeXt code.

        Weights get a truncated normal with std 0.02 (truncated to [-2, 2]);
        biases, when present, are zeroed. Other module types are left untouched.
        """
        if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_features(self, x):
        """Run every (downsample, stage) pair in order and return the last feature map."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
        return x

    def forward(self, x):
        """Alias for :meth:`forward_features` (this model has no head)."""
        x = self.forward_features(x)
        return x
class LayerNorm(nn.Module):
    r"""Channel-wise layer normalization for channels-last or channels-first tensors.

    ``data_format`` tells the module where the channel axis is:

    * ``"channels_last"`` (default): input shaped (N, *spatial, C), e.g.
      (batch_size, height, width, channels); normalised directly.
    * ``"channels_first"``: input shaped (N, C, *spatial), e.g. (N, C, H, W);
      the channel axis is moved to the end, normalised, then moved back.

    Args:
        normalized_shape (int): Number of channels C.
        eps (float): Value added to the denominator for numerical stability.
            Default: 1e-6
        data_format (str): ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: for any other ``data_format``.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Learnable affine parameters, starting as the identity transform.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        supported = ("channels_last", "channels_first")
        if data_format not in supported:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            # Channel axis 1 -> last, normalise over it, then last -> 1.
            channels_last = x.movedim(1, -1)
            normed = F.layer_norm(channels_last, self.normalized_shape,
                                  self.weight, self.bias, self.eps)
            return normed.movedim(-1, 1)
class ConvNeXtTiny(ConvNeXt):
    """ConvNeXt-T: depths (3, 3, 9, 3), dims (96, 192, 384, 768).

    Args:
        cin (int): Number of input channels. Default: 3
        spatial_dims (int): Number of spatial dimensions. Default: 2
        drop_path_rate (float): Stochastic depth rate, forwarded to ``ConvNeXt``. Default: 0.
        layer_scale_init_value (float): Layer Scale init value, forwarded to
            ``ConvNeXt``. Default: 1e-6
    """
    def __init__(self, cin=3, spatial_dims=2, drop_path_rate=0., layer_scale_init_value=1e-6):
        super().__init__(depths=(3, 3, 9, 3),
                         dims=(96, 192, 384, 768),
                         in_chans=cin,
                         spatial_dims=spatial_dims,
                         drop_path_rate=drop_path_rate,
                         layer_scale_init_value=layer_scale_init_value)
class ConvNeXtSmall(ConvNeXt):
    """ConvNeXt-S: depths (3, 3, 27, 3), dims (96, 192, 384, 768).

    Args:
        cin (int): Number of input channels. Default: 3
        spatial_dims (int): Number of spatial dimensions. Default: 2
        drop_path_rate (float): Stochastic depth rate, forwarded to ``ConvNeXt``. Default: 0.
        layer_scale_init_value (float): Layer Scale init value, forwarded to
            ``ConvNeXt``. Default: 1e-6
    """
    def __init__(self, cin=3, spatial_dims=2, drop_path_rate=0., layer_scale_init_value=1e-6):
        super().__init__(depths=(3, 3, 27, 3),
                         dims=(96, 192, 384, 768),
                         in_chans=cin,
                         spatial_dims=spatial_dims,
                         drop_path_rate=drop_path_rate,
                         layer_scale_init_value=layer_scale_init_value)
class ConvNeXtBase(ConvNeXt):
    """ConvNeXt-B: depths (3, 3, 27, 3), dims (128, 256, 512, 1024).

    Args:
        cin (int): Number of input channels. Default: 3
        spatial_dims (int): Number of spatial dimensions. Default: 2
        drop_path_rate (float): Stochastic depth rate, forwarded to ``ConvNeXt``. Default: 0.
        layer_scale_init_value (float): Layer Scale init value, forwarded to
            ``ConvNeXt``. Default: 1e-6
    """
    def __init__(self, cin=3, spatial_dims=2, drop_path_rate=0., layer_scale_init_value=1e-6):
        super().__init__(depths=(3, 3, 27, 3),
                         dims=(128, 256, 512, 1024),
                         in_chans=cin,
                         spatial_dims=spatial_dims,
                         drop_path_rate=drop_path_rate,
                         layer_scale_init_value=layer_scale_init_value)
class ConvNeXtLarge(ConvNeXt):
    """ConvNeXt-L: depths (3, 3, 27, 3), dims (192, 384, 768, 1536).

    Args:
        cin (int): Number of input channels. Default: 3
        spatial_dims (int): Number of spatial dimensions. Default: 2
        drop_path_rate (float): Stochastic depth rate, forwarded to ``ConvNeXt``. Default: 0.
        layer_scale_init_value (float): Layer Scale init value, forwarded to
            ``ConvNeXt``. Default: 1e-6
    """
    def __init__(self, cin=3, spatial_dims=2, drop_path_rate=0., layer_scale_init_value=1e-6):
        super().__init__(depths=(3, 3, 27, 3),
                         dims=(192, 384, 768, 1536),
                         in_chans=cin,
                         spatial_dims=spatial_dims,
                         drop_path_rate=drop_path_rate,
                         layer_scale_init_value=layer_scale_init_value)
class ConvNeXtXLarge(ConvNeXt):
    """ConvNeXt-XL: depths (3, 3, 27, 3), dims (256, 512, 1024, 2048).

    Args:
        cin (int): Number of input channels. Default: 3
        spatial_dims (int): Number of spatial dimensions. Default: 2
        drop_path_rate (float): Stochastic depth rate, forwarded to ``ConvNeXt``. Default: 0.
        layer_scale_init_value (float): Layer Scale init value, forwarded to
            ``ConvNeXt``. Default: 1e-6
    """
    def __init__(self, cin=3, spatial_dims=2, drop_path_rate=0., layer_scale_init_value=1e-6):
        super().__init__(depths=(3, 3, 27, 3),
                         dims=(256, 512, 1024, 2048),
                         in_chans=cin,
                         spatial_dims=spatial_dims,
                         drop_path_rate=drop_path_rate,
                         layer_scale_init_value=layer_scale_init_value)
if __name__ == '__main__':
    # Smoke test: every variant maps a 64^d input to a 2^d feature map
    # (total stride 32) for d in {1, 2, 3}; for d == 2 the parameter
    # counts are also checked with torchinfo.
    import torchinfo

    # (variant class, expected output channels, expected 2D parameter count)
    variants = [
        (ConvNeXtTiny, 768, 27818592),
        (ConvNeXtSmall, 768, 49453152),
        (ConvNeXtBase, 1024, 87564416),
        (ConvNeXtLarge, 1536, 196227264),
        (ConvNeXtXLarge, 2048, 348143872),
    ]

    print('ConvNeXt')
    for spatial_dims in [1, 2, 3]:
        for model_cls, out_ch, n_params in variants:
            model = model_cls(spatial_dims=spatial_dims)
            input_size = (1, 3) + (64,) * spatial_dims
            with torch.no_grad():  # shape check only; skip autograd bookkeeping
                outputs = model(torch.randn(input_size))
            assert outputs.shape == torch.Size([1, out_ch] + [2] * spatial_dims)
            if spatial_dims == 2:
                log = torchinfo.summary(model, input_size=input_size, verbose=0)
                assert log.total_params == n_params