# Source code for fusionlab.layers.selfattention.selfattention

import torch
import torch.nn as nn
from fusionlab.layers import Rearrange


class SelfAttention(nn.Module):
    """
    A self-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
    <https://arxiv.org/abs/2010.11929>"

    source code: https://github.com/Project-MONAI/MONAI/blob/main/monai/networks/blocks/selfattention.py#L22
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        qkv_bias: bool = False,
        save_attn: bool = False,
    ) -> None:
        """
        Args:
            hidden_size (int): dimension of hidden layer.
            num_heads (int): number of attention heads; must be positive and
                divide ``hidden_size``.
            dropout_rate (float, optional): fraction of the input units to drop.
                Applied both to the attention weights and to the projected
                output. Defaults to 0.0.
            qkv_bias (bool, optional): bias term for the qkv linear layer.
                Defaults to False.
            save_attn (bool, optional): to make accessible the attention
                matrix (stored, detached, in ``self.att_mat`` on every forward
                pass). Defaults to False.

        Raises:
            ValueError: if ``dropout_rate`` is outside [0, 1], ``num_heads`` is
                not positive, or ``hidden_size`` is not divisible by
                ``num_heads``.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        # Checked before the modulo below: 0 would raise ZeroDivisionError and a
        # negative value would pass the divisibility test and fail much later.
        if num_heads <= 0:
            raise ValueError("num_heads should be a positive integer.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        self.num_heads = num_heads
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        # A single projection produces q, k and v concatenated on the last axis.
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        # b: batch size, h: num_patches, l: num_heads, d: head_dim
        self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
        # NOTE: in this pattern the letters are swapped relative to the one
        # above: h is num_heads and l is num_patches. Input is
        # (b, num_heads, num_patches, head_dim); output is
        # (b, num_patches, num_heads * head_dim).
        self.out_rearrange = Rearrange("b h l d -> b l (h d)")
        self.drop_output = nn.Dropout(dropout_rate)
        self.drop_weights = nn.Dropout(dropout_rate)
        self.head_dim = hidden_size // num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot-product scores.
        self.scale = self.head_dim**-0.5
        self.save_attn = save_attn
        # Placeholder until the first forward pass with save_attn=True.
        self.att_mat = torch.Tensor()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input of shape (b, num_patches, hidden_size), as implied by the
                ``input_rearrange`` pattern.

        Returns:
            Tensor of shape (b, num_patches, hidden_size).
        """
        qkv = self.input_rearrange(self.qkv(x))
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (b, l, h, d)
        # Scaled dot-product scores over patches, normalised per query row.
        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)  # (b, l, h, h)
        if self.save_attn:
            # Saved before dropout so the stored matrix is the true softmax.
            self.att_mat = att_mat.detach()

        att_mat = self.drop_weights(att_mat)
        x = torch.einsum("blxy,blyd->blxd", att_mat, v)  # (b, l, h, d)
        x = self.out_rearrange(x)
        x = self.out_proj(x)
        x = self.drop_output(x)
        return x
class SRAttention(nn.Module):
    """
    Spatial Reduction Attention (SR-Attention) block, based on "Wang et al.,
    Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction
    without Convolutions <https://arxiv.org/abs/2102.12122v2>"

    source code: https://github.com/sithu31296/semantic-segmentation/blob/main/semseg/models/backbones/mit.py
    """

    def __init__(self, dim, head, sr_ratio):
        """
        Args:
            dim (int): input (and output) embedding dimension.
            head (int): number of attention heads; must be positive and
                divide ``dim``.
            sr_ratio (int): spatial reduction ratio. When greater than 1, keys
                and values are computed from the input after it is
                downsampled by a ``Conv2d`` whose kernel size and stride are
                both ``sr_ratio``, followed by a ``LayerNorm``.

        Raises:
            ValueError: if ``head`` is not positive or ``dim`` is not
                divisible by ``head``.
        """
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in
        # forward (and a silently floored scale) for incompatible sizes.
        if head <= 0:
            raise ValueError("head should be a positive integer.")
        if dim % head != 0:
            raise ValueError("dim should be divisible by head.")

        self.head = head
        self.sr_ratio = sr_ratio
        # 1/sqrt(head_dim) scaling of the dot-product scores.
        self.scale = (dim // head) ** -0.5
        self.q = nn.Linear(dim, dim)
        # Keys and values come from one projection, split in forward.
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """
        Args:
            x: token sequence of shape (B, N, C), where C equals ``dim``.
            H: height of the token grid; only used when ``sr_ratio > 1``.
            W: width of the token grid; only used when ``sr_ratio > 1``.
                In that case ``N`` must equal ``H * W``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        head_dim = C // self.head
        # (B, N, C) -> (B, head, N, head_dim)
        q = self.q(x).reshape(B, N, self.head, head_dim).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # Fold tokens into a (B, C, H, W) map, downsample spatially, then
            # flatten back to a shorter token sequence (B, N', C).
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)

        # (B, N', 2C) -> (2, B, head, N', head_dim), then split into k and v.
        k, v = self.kv(x).reshape(B, -1, 2, self.head, head_dim).permute(2, 0, 3, 1, 4)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, head, N, N')
        attn = attn.softmax(dim=-1)

        # (B, head, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x