# Source code for fusionlab.encoders.vit.vit

from typing import Callable, Sequence, Union

import torch
import torch.nn as nn

from fusionlab.layers import (
    PatchEmbedding,
    SelfAttention,
)

class MLPBlock(nn.Module):
    """
    A multi-layer perceptron block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>"

    The block applies, in order: linear -> activation -> dropout -> linear -> dropout.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        dropout_rate: float = 0.0,
        act: Union[nn.Module, Callable[..., nn.Module]] = nn.GELU,
    ) -> None:
        """
        Args:
            hidden_size: dimension of hidden layer.
            mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
            dropout_rate: fraction of the input units to drop.
            act: activation to use. Either a zero-argument factory (for example
                an ``nn.Module`` subclass such as ``nn.GELU``) that is called
                once to build the activation, or an already constructed
                ``nn.Module`` instance that is used as-is. Defaults to nn.GELU.

        Raises:
            ValueError: if `dropout_rate` is not within [0, 1].
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        # A falsy mlp_dim (0) means "same width as the hidden layer".
        mlp_dim = mlp_dim or hidden_size

        self.linear1 = nn.Linear(hidden_size, mlp_dim)
        self.linear2 = nn.Linear(mlp_dim, hidden_size)
        # Accept a ready-made module as well as a factory: calling an instance
        # with no arguments would invoke its forward() and fail.
        self.act = act if isinstance(act, nn.Module) else act()
        self.drop1 = nn.Dropout(dropout_rate)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the two-layer feedforward transform to `x` and return the result."""
        x = self.act(self.linear1(x))
        x = self.drop1(x)
        x = self.linear2(x)
        x = self.drop2(x)
        return x
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>"

    Each of the two sub-layers (self attention, then MLP) is preceded by its
    own LayerNorm and wrapped in a residual connection.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            hidden_size (int): dimension of hidden layer.
            mlp_dim (int): dimension of feedforward layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.

        Raises:
            ValueError: if `dropout_rate` is not within [0, 1], or if
                `hidden_size` is not divisible by `num_heads`.
        """
        super().__init__()

        rate_is_valid = 0 <= dropout_rate <= 1
        if not rate_is_valid:
            raise ValueError("dropout_rate should be between 0 and 1.")

        leftover = hidden_size % num_heads
        if leftover != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        # Submodules are registered in the same order as before so that
        # parameter initialisation order and state_dict layout are unchanged.
        self.mlp = MLPBlock(
            hidden_size=hidden_size,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
        )
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = SelfAttention(
            hidden_size,
            num_heads,
            dropout_rate,
            qkv_bias,
            save_attn,
        )
        self.norm2 = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Run the attention and MLP sub-layers on `x`, each with a residual add."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        fed_forward = self.mlp(self.norm2(x))
        return x + fed_forward
class ViT(nn.Module):
    """
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>"

    The model is a patch embedding, followed by a stack of transformer blocks,
    followed by a final LayerNorm. It has no classification head.

    source code: https://github.com/Project-MONAI/MONAI/blob/main/monai/networks/nets/vit.py

    NOTE(review): the upstream docstring states "ViT supports Torchscript but
    only works for Pytorch after 1.8". `forward` here returns either a tensor
    or a tuple depending on `return_features`, which may prevent scripting --
    confirm before relying on it.
    """

    def __init__(
        self,
        in_channels: int,
        img_size: Union[Sequence[int], int],
        patch_size: Union[Sequence[int], int],
        hidden_size: int = 768,
        mlp_dim: int = 3072,
        num_layers: int = 12,
        num_heads: int = 12,
        pos_embed: str = "conv",
        dropout_rate: float = 0.0,
        spatial_dims: int = 2,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            pos_embed (str, optional): position embedding layer type, forwarded to
                ``PatchEmbedding`` as ``pos_embed_type``. Defaults to "conv".
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 2.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_heads`.
        """
        super().__init__()

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.patch_embedding = PatchEmbedding(
            in_channels=in_channels,
            img_size=img_size,
            patch_size=patch_size,
            hidden_size=hidden_size,
            pos_embed_type=pos_embed,
            dropout_rate=dropout_rate,
            spatial_dims=spatial_dims,
        )
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
                for _ in range(num_layers)
            ]
        )
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x, return_features=False):
        """
        Embed `x` into patches, run every transformer block, then normalise.

        Args:
            x: input passed to the patch embedding layer.
            return_features: when True, also return the output of every
                transformer block.

        Returns:
            The normalised output of the last block. When `return_features`
            is True, a tuple ``(output, features)`` where ``features`` is a
            list holding each block's output in order, taken before the final
            LayerNorm is applied.
        """
        x = self.patch_embedding(x)

        features = []
        for block in self.blocks:
            x = block(x)
            # Only keep references to intermediate activations when the caller
            # asked for them, so they can be freed early otherwise.
            if return_features:
                features.append(x)

        x = self.norm(x)

        if return_features:
            return x, features
        return x
# Public alias so the model can be referred to by its full name.
VisionTransformer = ViT


if __name__ == '__main__':
    # Smoke test: push one random 3-channel 224x224 image through a small ViT
    # and print the resulting shapes.
    inputs = torch.randn(1, 3, 224, 224)
    model = ViT(
        in_channels=3,
        img_size=224,
        patch_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_layers=2,  # fewer blocks than the default of 12
        num_heads=12,
    )

    outputs = model(inputs)
    print(outputs.shape)

    outputs, hidden = model(inputs, return_features=True)
    print(outputs.shape)
    # One entry per transformer block.
    for feature in hidden:
        print(feature.shape)