Source code for fusionlab.segmentation.segformer.segformer

from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from fusionlab.encoders import (
    MiT,
    MiTB0,
    MiTB1,
    MiTB2,
    MiTB3,
    MiTB4,
    MiTB5,
)

class MLP(nn.Module):
    """Linear embedding applied to every position of a flattened feature map.

    Args:
        dim: size of axis 1 (the channel axis) of the tensor given to
            :meth:`forward`.
        embed_dim: size of the projected embedding.
    """

    def __init__(self, dim, embed_dim):
        super().__init__()
        self.proj = nn.Linear(in_features=dim, out_features=embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` position by position.

        Every axis from index 2 onward is collapsed into a single axis, which
        is then swapped with axis 1 so that the linear layer acts on the
        channel axis.

        Args:
            x: tensor whose axis 1 has size ``dim``.

        Returns:
            Tensor of shape ``(x.shape[0], positions, embed_dim)``.
        """
        tokens = torch.flatten(x, start_dim=2)
        tokens = torch.transpose(tokens, 1, 2)
        return self.proj(tokens)
class ConvModule(nn.Module):
    """Bias-free 1x1 convolution followed by batch norm and an in-place ReLU.

    Args:
        c1: number of input channels.
        c2: number of output channels.
    """

    def __init__(self, c1, c2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=c1,
            out_channels=c2,
            kernel_size=1,
            bias=False,
        )
        # The original implementation uses SyncBN here.
        self.bn = nn.BatchNorm2d(num_features=c2)
        self.activate = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the convolution, normalisation and activation.

        Args:
            x: 4-D tensor with ``c1`` channels on axis 1.

        Returns:
            Tensor with ``c2`` channels and the same spatial size as ``x``.
        """
        hidden = self.conv(x)
        hidden = self.bn(hidden)
        return self.activate(hidden)
class SegFormerHead(nn.Module):
    """Decode head that fuses per-stage feature maps into class logits.

    Each feature map is embedded to ``embed_dim`` channels by its own
    :class:`MLP` (registered as ``linear_c1``, ``linear_c2``, ...), resized to
    the spatial size of the first map, concatenated, fused by a
    :class:`ConvModule` and finally mapped to ``num_classes`` channels.

    Args:
        dims: channel count of every feature map, in the order the maps are
            passed to :meth:`forward`.
        embed_dim: channel count used for the per-stage embeddings and for
            the fused map.
        num_classes: number of output channels of the prediction layer.
    """

    def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):
        super().__init__()
        dims = list(dims)
        for i, dim in enumerate(dims):
            self.add_module(f"linear_c{i+1}", MLP(dim, embed_dim))

        # The fuse layer receives one embed_dim-wide map per stage, so its
        # input width follows the number of stages instead of assuming four.
        self.linear_fuse = ConvModule(embed_dim * len(dims), embed_dim)
        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
        """Fuse ``features`` and predict per-pixel class scores.

        Args:
            features: one 4-D tensor per entry of ``dims``, in the same order.
                The first tensor fixes the batch size and the output
                resolution.

        Returns:
            Tensor of shape ``(B, num_classes, H, W)`` where ``B, H, W`` are
            taken from ``features[0]``.

        Raises:
            AttributeError: if more feature maps are given than there are
                ``linear_c*`` layers.
        """
        B, _, H, W = features[0].shape
        outs = []
        for stage, feature in enumerate(features, start=1):
            # Plain attribute lookup; no need to evaluate a code string.
            embed = getattr(self, f"linear_c{stage}")
            # MLP returns (B, positions, embed_dim); restore the 2-D layout.
            cf = embed(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])
            if stage > 1:
                # Every map after the first is resized to the first map's size.
                cf = F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(cf)

        # Channel order is last stage first; this fixes the meaning of the
        # fuse layer's input channels, so it must not change.
        seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))
        seg = self.linear_pred(self.dropout(seg))
        return seg
class SegFormer(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    <https://arxiv.org/abs/2105.15203>

    source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/segformer.py

    Args:
        num_classes (int): number of classes to segment
        mit_encoder_type (str): type of MiT encoder, one of
            ['B0', 'B1', 'B2', 'B3', 'B4', 'B5'] (case-insensitive)

    Raises:
        ValueError: if ``mit_encoder_type`` is not one of the supported types.
    """

    def __init__(
        self,
        num_classes: int = 6,
        mit_encoder_type: str = 'B0'
    ):
        super().__init__()
        # Explicit lookup table instead of eval(): an arbitrary string must
        # never be executed as code just to pick a class.
        encoder_types = {
            'B0': MiTB0,
            'B1': MiTB1,
            'B2': MiTB2,
            'B3': MiTB3,
            'B4': MiTB4,
            'B5': MiTB5,
        }
        encoder_cls = encoder_types.get(str(mit_encoder_type).upper())
        if encoder_cls is None:
            raise ValueError(
                f"Unknown mit_encoder_type {mit_encoder_type!r}; "
                f"expected one of {sorted(encoder_types)}"
            )

        self.encoder: MiT = encoder_cls()
        # The head embeds every stage to the width of the last encoder stage.
        embed_dim = self.encoder.channels[-1]
        self.decode_head = SegFormerHead(
            self.encoder.channels,
            embed_dim,
            num_classes,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Segment ``inputs`` and return logits at the input resolution.

        Args:
            inputs: 4-D image batch; its last two axes give the output size.

        Returns:
            Tensor with ``num_classes`` channels whose spatial size equals
            ``inputs.shape[2:]``.
        """
        _, features = self.encoder(inputs, return_features=True)
        # Per the original author's note, the head output is 4x smaller than
        # the input image, hence the final upsampling step.
        x = self.decode_head(features)
        x = F.interpolate(x, size=inputs.shape[2:], mode='bilinear', align_corners=False)
        return x
if __name__ == '__main__':
    # Smoke test: build the default model, push one random image batch
    # through it and show the shape of the resulting logits.
    net = SegFormer(num_classes=6)
    dummy_batch = torch.randn(1, 3, 128, 128)
    logits = net(dummy_batch)
    print(logits.shape)