# Source code for fusionlab.segmentation.transunet.transunet

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from collections import OrderedDict
import math

# Adapted from the reference implementation: https://github.com/Beckschen/TransUNet

# TODO: extract layer to utils
class StdConv2d(nn.Conv2d):
    """2-D convolution that standardizes its kernel on every call.

    Each output filter is shifted to zero mean and scaled to unit variance
    (statistics taken over the input-channel and spatial axes) and the
    standardized kernel is used in place of the stored weight.
    """

    def forward(self, x):
        """Convolve ``x`` with the per-filter standardized weight."""
        kernel = self.weight
        # Population (biased) variance and mean, one pair per output filter.
        variance, mean = torch.var_mean(
            kernel, dim=[1, 2, 3], keepdim=True, unbiased=False
        )
        # The small epsilon keeps the division finite for constant filters.
        standardized = (kernel - mean) / torch.sqrt(variance + 1e-5)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """Return a 3x3 ``StdConv2d`` from ``cin`` to ``cout`` channels, padding 1."""
    options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": bias,
        "groups": groups,
    }
    return StdConv2d(cin, cout, **options)
def conv1x1(cin, cout, stride=1, bias=False):
    """Return a 1x1 ``StdConv2d`` from ``cin`` to ``cout`` channels, no padding."""
    options = {
        "kernel_size": 1,
        "stride": stride,
        "padding": 0,
        "bias": bias,
    }
    return StdConv2d(cin, cout, **options)
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Three weight-standardized convolutions (1x1 -> 3x3 -> 1x1), each followed
    by GroupNorm, with a residual connection. Note that ``forward`` applies
    conv -> norm -> ReLU in that order.

    Args:
        cin: Number of input channels.
        cout: Number of output channels; falls back to ``cin`` when falsy.
        cmid: Bottleneck width; falls back to ``cout // 4`` when falsy.
        stride: Stride of the 3x3 convolution and of the shortcut projection.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        if not cout:
            cout = cin
        if not cmid:
            cmid = cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        # The stride sits on the 3x3 conv; the code this was ported from
        # notes that the original placed it on conv1.
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # The shortcut needs a projection whenever the residual branch
        # changes the spatial size or the channel count.
        needs_projection = stride != 1 or cin != cout
        if needs_projection:
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)
        else:
            # Kept as a plain ``False`` so ``forward`` can test it directly.
            self.downsample = False

    def forward(self, x):
        """Return ``relu(shortcut(x) + residual_branch(x))``."""
        shortcut = x
        if self.downsample:
            shortcut = self.gn_proj(self.downsample(x))

        out = self.conv1(x)
        out = self.relu(self.gn1(out))
        out = self.conv2(out)
        out = self.relu(self.gn2(out))
        out = self.gn3(self.conv3(out))

        return self.relu(shortcut + out)
# TODO: Extract to encoder
class ResNetV2(nn.Module):
    """Implementation of the pre-activation (v2) ResNet model.

    The network is a stem (``root`` + ``pool``) followed by three stages of
    ``PreActBottleneck`` units stored in ``body``.

    Args:
        block_units: Sequence of three ints, the number of bottleneck units
            in each stage.
        width_factor: Multiplier applied to the base width of 64 channels.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.body = nn.Sequential(OrderedDict([
            ('block1', self._make_stage(width, width * 4, width, block_units[0], stride=1)),
            ('block2', self._make_stage(width * 4, width * 8, width * 2, block_units[1], stride=2)),
            ('block3', self._make_stage(width * 8, width * 16, width * 4, block_units[2], stride=2)),
        ]))

    @staticmethod
    def _make_stage(cin, cout, cmid, num_units, stride):
        """Build one stage: a (possibly strided) first unit plus identity-shaped units."""
        units = [('unit1', PreActBottleneck(cin=cin, cout=cout, cmid=cmid, stride=stride))]
        units += [
            (f'unit{i:d}', PreActBottleneck(cin=cout, cout=cout, cmid=cmid))
            for i in range(2, num_units + 1)
        ]
        return nn.Sequential(OrderedDict(units))

    def forward(self, x, return_features=True):
        """Run the backbone and optionally collect skip features.

        Args:
            x: Input batch; its last two dimensions are treated as height
                and width.
            return_features: When true, also return the intermediate maps.

        Returns:
            The output of the last stage, or ``(output, features)`` where
            ``features`` lists the skip maps from deepest to shallowest
            (last collected stage first, stem output last).
        """
        features = []
        in_h, in_w = x.size()[2:]
        x = self.root(x)
        features.append(x)
        x = self.pool(x)
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            # Resolution the decoder expects for this skip connection.
            target_h = int(in_h / 4 / (i + 1))
            target_w = int(in_w / 4 / (i + 1))
            cur_h, cur_w = x.size()[2:]
            if (cur_h, cur_w) != (target_h, target_w):
                pad_h = target_h - cur_h
                pad_w = target_w - cur_w
                assert 0 <= pad_h < 3 and 0 <= pad_w < 3, \
                    f"x {x.size()} should {(target_h, target_w)}"
                # Zero-pad on the bottom/right only. F.pad keeps the dtype and
                # device of ``x``, so half/double activations are not silently
                # converted, and each spatial axis is padded independently.
                feat = F.pad(x, (0, pad_w, 0, pad_h))
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        if return_features:
            return x, features[::-1]
        return x
# TODO: extract to layers
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        num_attention_heads: Number of attention heads.
        hidden_size: Size of the last dimension of the input and output.
        attention_dropout_rate: Dropout probability applied both to the
            attention probabilities and to the projected output.
    """

    def __init__(
        self,
        num_attention_heads=12,
        hidden_size=768,
        attention_dropout_rate=0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        # The merged heads have ``all_head_size`` features, which is smaller
        # than ``hidden_size`` when the latter is not divisible by the head
        # count; project from that width so such configurations still work.
        self.out = nn.Linear(self.all_head_size, hidden_size)
        self.attn_dropout = nn.Dropout(attention_dropout_rate)
        self.proj_dropout = nn.Dropout(attention_dropout_rate)

        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last axis into heads and move the head axis to position 1."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention to ``hidden_states``.

        Returns:
            ``(attention_output, weights)`` where ``weights`` is always
            ``None`` (attention maps are not exposed).
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Scaled dot-product scores between every pair of positions.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)

        # Weighted sum of values, then merge the heads back into one axis.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        weights = None
        return attention_output, weights
class MLP(nn.Module):
    """Two-layer feed-forward network with GELU and dropout.

    Args:
        hidden_size: Input and output feature size.
        mlp_dim: Width of the intermediate layer.
        dropout_rate: Dropout probability used after each linear layer.
    """

    def __init__(self, hidden_size=768, mlp_dim=3072, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, hidden_size)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero normally distributed biases."""
        linears = (self.fc1, self.fc2)
        # Weights first, biases second: this keeps the random draws in the
        # same order for a given seed.
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)
        for linear in linears:
            nn.init.normal_(linear.bias, std=1e-6)

    def forward(self, x):
        """Return ``dropout(fc2(dropout(gelu(fc1(x)))))``."""
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.

    Args:
        patch_size: Patch size as an int or an ``(h, w)`` pair.
        img_size: Input image size as an int or an ``(h, w)`` pair.
        num_layers: Units per ResNet stage, e.g. ``(3, 4, 9)``.
        width_factor: ResNet width multiplier.
        in_channels: Channels fed to the patch projection; overwritten with
            the backbone's output width because ``hybrid`` is always true.
        hidden_size: Embedding dimension.
        dropout_rate: Dropout probability applied to the embeddings.
    """

    def __init__(
        self,
        patch_size,
        img_size,
        num_layers,  # resnet num layers: (3, 4, 9)
        width_factor=1,  # resnet width factor: 1
        in_channels=3,
        hidden_size=768,  # embedding output dim
        dropout_rate=0.1,  # embedding dropout rate
    ):
        super().__init__()
        self.hybrid = True
        patch_size, n_patches = self._patch_geometry(img_size, patch_size)

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=num_layers, width_factor=width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = nn.Conv2d(in_channels=in_channels,
                                          out_channels=hidden_size,
                                          kernel_size=patch_size,
                                          stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size))

        self.dropout = nn.Dropout(dropout_rate)

    @staticmethod
    def _patch_geometry(img_size, patch_size):
        """Return the feature-map patch size and the resulting patch count."""
        img_size = _pair(img_size)
        # Accept a bare int as well as a pair, mirroring ``img_size``.
        patch_size = _pair(patch_size)
        grid_size = (int(img_size[0] / patch_size[0]), int(img_size[1] / patch_size[1]))
        # The patch projection runs on a feature map taken to be 16x smaller
        # than the image, so the patch size is re-expressed in that map.
        patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
        patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
        n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
        return patch_size, n_patches

    def forward(self, x):
        """Return ``(embeddings, features)`` for the input batch ``x``.

        ``features`` are the backbone skip maps, or ``None`` when the hybrid
        backbone is disabled.
        """
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)  # (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
class TransformerEncoderBlock(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each with a residual.

    Args:
        hidden_size: Feature size of the tokens.
        mlp_dim: Width of the MLP's intermediate layer.
        dropout_rate: Dropout probability used inside the MLP.
        num_attention_heads: Number of attention heads.
        attention_dropout_rate: Dropout probability used inside the
            attention module.
    """

    def __init__(
        self,
        hidden_size=768,
        mlp_dim=3072,
        dropout_rate=0.1,
        num_attention_heads=12,
        attention_dropout_rate=0.1,
    ):
        super().__init__()
        self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = MLP(
            hidden_size,
            mlp_dim,
            dropout_rate,
        )
        # Build the attention with this block's width; relying on the
        # attention defaults only works when hidden_size happens to be 768.
        self.attn = Attention(
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_dropout_rate=attention_dropout_rate,
        )

    def forward(self, x):
        """Return ``(output, attention_weights)`` for the token batch ``x``."""
        skip = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + skip

        skip = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + skip
        return x, weights
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderBlock`` layers followed by a LayerNorm.

    Args:
        num_layers: Number of encoder blocks.
        hidden_size: Feature size of the tokens.
        mlp_dim: Width of each block's MLP.
        dropout_rate: Dropout probability passed to each block.
    """

    def __init__(
        self,
        num_layers=12,
        hidden_size=768,
        mlp_dim=3072,
        dropout_rate=0.1,
    ):
        super().__init__()
        self.layer = nn.ModuleList(
            [
                TransformerEncoderBlock(hidden_size, mlp_dim, dropout_rate)
                for _ in range(num_layers)
            ]
        )
        self.encoder_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, hidden_states):
        """Return ``(encoded, attn_weights)``.

        ``attn_weights`` is always an empty list: the per-block weights are
        discarded rather than collected.
        """
        attn_weights = []
        tokens = hidden_states
        for block in self.layer:
            tokens, _ = block(tokens)
        return self.encoder_norm(tokens), attn_weights
class Transformer(nn.Module):
    """Hybrid embedding stage followed by a 12-layer, 768-wide encoder.

    Args:
        in_channels: Forwarded to ``Embeddings``.
        img_size: Input image size, forwarded to ``Embeddings``.
        patch_size: Patch size, forwarded to ``Embeddings``.
    """

    def __init__(
        self,
        in_channels,
        img_size,
        patch_size=(16, 16),
    ):
        super().__init__()
        embedding_options = {
            "patch_size": patch_size,
            "img_size": img_size,
            "in_channels": in_channels,
            "num_layers": (3, 4, 9),
        }
        self.embeddings = Embeddings(**embedding_options)
        self.encoder = TransformerEncoder(num_layers=12, hidden_size=768)

    def forward(self, input_ids):
        """Return ``(encoded, attn_weights, features)`` for the input batch."""
        tokens, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(tokens)  # (B, n_patch, hidden)
        return encoded, attn_weights, features
class Conv2dReLU(nn.Sequential):
    """Convolution, optional batch normalization, then ReLU.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        use_batchnorm: When true, add ``BatchNorm2d`` and drop the conv bias;
            when false, keep the bias and skip normalization.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        # Honor the flag: an Identity placeholder keeps the layer indices
        # (conv=0, norm=1, relu=2) stable when normalization is disabled.
        norm = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
        relu = nn.ReLU(inplace=True)

        super().__init__(conv, norm, relu)
class DecoderBlock(nn.Module):
    """Upsample by 2, optionally concatenate a skip map, then apply two 3x3 convs.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        skip_channels: Channels of the skip map concatenated before the
            first convolution (0 when there is none).
        use_batchnorm: Forwarded to both ``Conv2dReLU`` layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        conv_options = {
            "kernel_size": 3,
            "padding": 1,
            "use_batchnorm": use_batchnorm,
        }
        self.conv1 = Conv2dReLU(in_channels + skip_channels, out_channels, **conv_options)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **conv_options)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        """Return the decoded map for ``x``, fusing ``skip`` along dim 1 if given."""
        upsampled = self.up(x)
        if skip is None:
            merged = upsampled
        else:
            merged = torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))
class SegmentationHead(nn.Sequential):
    """Final convolution producing class logits, with optional upsampling.

    Args:
        in_channels: Channels of the decoder output.
        out_channels: Number of output channels (classes).
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        upsampling: Bilinear scale factor applied when greater than 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)
class TransUNetDecoder(nn.Module):
    """Cascaded upsampling decoder that turns encoder tokens into a feature map.

    Args:
        decoder_channels: Output channels of each decoder block.
        hidden_size: Feature size of the incoming tokens.
        n_skip: Number of leading decoder blocks that receive a skip map.
        skip_channels: Channels of the skip map for each decoder block; the
            caller's sequence is never modified.
        head_channels: Output channels of the convolution applied before
            the decoder blocks.
    """

    def __init__(
        self,
        decoder_channels,
        hidden_size,
        n_skip,
        skip_channels,
        head_channels=512,
    ):
        super().__init__()
        self.conv_more = Conv2dReLU(
            hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        skip_channels = self._select_skip_channels(skip_channels, n_skip)

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.n_skip = n_skip

    @staticmethod
    def _select_skip_channels(skip_channels, n_skip):
        """Return a new list with the channels of unused skip connections zeroed."""
        if n_skip == 0:
            return [0, 0, 0, 0]
        # Work on a copy: writing into the caller's list would corrupt a
        # list shared between models (for example a default argument).
        selected = list(skip_channels)
        for i in range(4 - n_skip):  # re-select the skip channels according to n_skip
            selected[3 - i] = 0
        return selected

    def forward(self, hidden_states, features=None):
        """Decode ``hidden_states`` of shape (B, n_patch, hidden).

        The tokens are reshaped to a square map, so ``n_patch`` is expected
        to be a perfect square. ``features`` supplies the skip maps, one per
        decoder block that uses a skip connection.
        """
        B, n_patch, hidden = hidden_states.size()
        # (B, n_patch, hidden) -> (B, hidden, h, w)
        h, w = int(math.sqrt(n_patch)), int(math.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
# TODO: ND version
class TransUNet(nn.Module):
    """TransUNet segmentation model: hybrid transformer encoder plus CNN decoder.

    Args:
        in_channels: Forwarded to the transformer stage.
        img_size: Input image size.
        num_classes: Number of output channels of the segmentation head.
        zero_head: Stored on the instance; not otherwise used here.
        decoder_channels: Output channels of each decoder block.
        hidden_size: Token feature size expected by the decoder.
            NOTE(review): the transformer stage is built without this value,
            so it presumably has to match that stage's own width - confirm.
        n_skip: Number of decoder blocks that receive a skip connection.
        skip_channels: Channels of the skip map for each decoder block.
        patch_size: Patch size forwarded to the transformer stage.
    """

    def __init__(
        self,
        in_channels=3,
        img_size=224,
        num_classes=2,
        zero_head=False,
        decoder_channels=(256, 128, 64, 16),
        hidden_size=768,
        n_skip=3,
        skip_channels=(512, 256, 64, 16),
        patch_size=(16, 16),
    ):
        super().__init__()
        # Take private copies: the defaults are immutable tuples and a list
        # supplied by the caller must not be altered by the decoder setup.
        decoder_channels = list(decoder_channels)
        skip_channels = list(skip_channels)

        self.num_classes = num_classes
        self.zero_head = zero_head
        self.transformer = Transformer(
            in_channels,
            img_size,
            patch_size)
        self.decoder = TransUNetDecoder(
            decoder_channels,
            hidden_size,
            n_skip,
            skip_channels,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=num_classes,
            kernel_size=3,
        )

    def forward(self, x):
        """Return segmentation logits for the image batch ``x``.

        Single-channel inputs are repeated to three channels first.
        """
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, _, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits