import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from collections import OrderedDict
import math
# source code: https://github.com/Beckschen/TransUNet
# TODO: extract layer to utils
class StdConv2d(nn.Conv2d):
    """2-D convolution that standardizes its kernel on every forward pass.

    Each output filter is shifted to zero mean and scaled to unit variance
    (statistics taken over the input-channel and spatial axes) before the
    convolution is applied.  The stored ``weight`` parameter is not modified.
    """

    def forward(self, x):
        """Convolve ``x`` with the per-filter standardized kernel."""
        kernel = self.weight
        # var_mean returns (variance, mean); the biased estimator is used.
        variance, mean = torch.var_mean(
            kernel, dim=[1, 2, 3], keepdim=True, unbiased=False
        )
        standardized = (kernel - mean) / torch.sqrt(variance + 1e-5)
        return F.conv2d(
            x,
            standardized,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    """Build a 3x3 weight-standardized convolution with padding 1."""
    return StdConv2d(
        cin,
        cout,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=bias,
        groups=groups,
    )
def conv1x1(cin, cout, stride=1, bias=False):
    """Build a 1x1 (pointwise) weight-standardized convolution, no padding."""
    return StdConv2d(
        cin,
        cout,
        kernel_size=1,
        stride=stride,
        padding=0,
        bias=bias,
    )
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Three convolutions (1x1 -> 3x3 -> 1x1), each followed by GroupNorm, with
    ReLU after the first two and after the residual sum.  When the stride is
    not 1 or the channel count changes, the skip path is projected by a 1x1
    convolution followed by its own GroupNorm.

    Args:
        cin: Number of input channels.
        cout: Number of output channels; falls back to ``cin`` when falsy.
        cmid: Bottleneck width; falls back to ``cout // 4`` when falsy.
        stride: Stride of the 3x3 convolution and of the skip projection.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        # Original code has the stride on conv1!!
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # ``downsample`` is either False (identity skip) or the projection conv.
        needs_projection = stride != 1 or cin != cout
        if needs_projection:
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)
        else:
            self.downsample = needs_projection

    def forward(self, x):
        """Return ``relu(skip(x) + branch(x))``."""
        # Skip path: identity, or projection + GroupNorm.
        shortcut = x
        if self.downsample:
            shortcut = self.gn_proj(self.downsample(x))

        # Main bottleneck path.
        out = self.conv1(x)
        out = self.relu(self.gn1(out))
        out = self.conv2(out)
        out = self.relu(self.gn2(out))
        out = self.gn3(self.conv3(out))

        return self.relu(shortcut + out)
# TODO: Extract to encoder
class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model.

    The network is a stem (``root``), a max-pool and three stages of
    ``PreActBottleneck`` units (``body``).  ``forward`` can additionally
    return the intermediate feature maps used as decoder skip connections.

    Args:
        block_units: Sequence of three ints, the number of bottleneck units
            in ``block1``, ``block2`` and ``block3``.
        width_factor: Multiplier applied to the base width of 64 channels.
    """

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
        ]))
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width))
                 for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2))
                 for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4))
                 for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x, return_features=True):
        """Run the backbone.

        Args:
            x: Input tensor of shape ``(B, C, H, W)``.  The expected skip
                sizes are derived from ``H`` only and used for both spatial
                axes, so a square input is assumed.
            return_features: When true, also return the skip features.

        Returns:
            ``x`` after the last stage, or ``(x, features)`` where
            ``features`` lists the skip tensors from deepest to shallowest
            (last skip stage first, stem output last).
        """
        features = []
        b, _, in_size, _ = x.size()

        x = self.root(x)
        features.append(x)
        x = self.pool(x)

        # Every stage but the last contributes a skip feature.
        *skip_stages, last_stage = self.body
        for i, stage in enumerate(skip_stages):
            x = stage(x)
            right_size = int(in_size / 4 / (i + 1))
            if x.size(2) != right_size:
                pad = right_size - x.size(2)
                assert pad < 3 and pad > 0, f"x {x.size()} should {right_size}"
                # The un-padded max-pool can leave the map a pixel or two
                # short; zero-pad bottom/right up to the expected size.
                # new_zeros keeps the dtype AND device of the activations, so
                # half/bfloat16/double inputs no longer get a float32 skip.
                feat = x.new_zeros((b, x.size(1), right_size, right_size))
                feat[:, :, 0:x.size(2), 0:x.size(3)] = x
            else:
                feat = x
            features.append(feat)

        x = last_stage(x)

        if return_features:
            return x, features[::-1]
        return x
# TODO: extract to layers
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        num_attention_heads: Number of attention heads.
        hidden_size: Width of the input and output token embeddings.
        attention_dropout_rate: Dropout probability applied both to the
            attention probabilities and to the projected output.
    """

    def __init__(
        self,
        num_attention_heads=12,
        hidden_size=768,
        attention_dropout_rate=0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        # The output projection consumes the concatenated heads, whose width
        # is all_head_size.  That equals hidden_size only when hidden_size is
        # divisible by the number of heads; using all_head_size keeps the
        # layer shapes consistent in the non-divisible case as well.
        self.out = nn.Linear(self.all_head_size, hidden_size)
        self.attn_dropout = nn.Dropout(attention_dropout_rate)
        self.proj_dropout = nn.Dropout(attention_dropout_rate)

        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last axis into heads: ``(..., N, A) -> (B, heads, N, head_size)``."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: Tensor whose last axis has size ``hidden_size``;
                the permutes below require it to be 3-D ``(B, N, hidden)``.

        Returns:
            ``(attention_output, weights)`` where ``attention_output`` has the
            same shape as ``hidden_states`` and ``weights`` is always ``None``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, heads, N, head_size) -> (B, N, heads * head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        weights = None
        return attention_output, weights
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        hidden_size: Width of the input and output features.
        mlp_dim: Width of the intermediate layer.
        dropout_rate: Dropout probability used after each linear layer.
    """

    def __init__(
        self,
        # config
        hidden_size=768,
        mlp_dim=3072,
        dropout_rate=0.1,
    ):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, hidden_size)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std 1e-6) normal biases."""
        layers = (self.fc1, self.fc2)
        # Weights first, then biases, so the random-number stream is consumed
        # in a fixed order: fc1.weight, fc2.weight, fc1.bias, fc2.bias.
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight)
        for layer in layers:
            nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x):
        """Apply the feed-forward block to ``x`` (last axis = ``hidden_size``)."""
        hidden = self.dropout(self.act_fn(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.

    The image is first passed through a ``ResNetV2`` backbone; its output is
    projected to ``hidden_size`` channels by a strided convolution, flattened
    into a token sequence and summed with learned position embeddings.

    Args:
        patch_size: Patch size as an int or an ``(h, w)`` pair.
        img_size: Input image size as an int or an ``(h, w)`` pair.
        num_layers: Units per ResNet stage, passed to ``ResNetV2`` as
            ``block_units`` (e.g. ``(3, 4, 9)``).
        width_factor: ResNet width factor.
        in_channels: Overwritten with the backbone output width while
            ``self.hybrid`` is true, which this constructor always sets.
        hidden_size: Embedding output dimension.
        dropout_rate: Dropout probability applied to the embeddings.
    """

    def __init__(
        self,
        patch_size,
        img_size,
        num_layers,  # resnet num layers: (3, 4, 9)
        width_factor=1,  # resnet width factor: 1
        in_channels=3,
        hidden_size=768,  # embedding output dim
        dropout_rate=0.1,  # embedding dropout rate
    ):
        super().__init__()
        self.hybrid = True
        patch_size, n_patches = self._patch_geometry(img_size, patch_size)

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=num_layers, width_factor=width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size))
        self.dropout = nn.Dropout(dropout_rate)

    @staticmethod
    def _patch_geometry(img_size, patch_size):
        """Return ``(feature_map_patch_size, n_patches)`` for the hybrid model.

        Both arguments may be an int or an ``(h, w)`` pair.  The returned
        patch size is expressed on a grid 16x coarser than the image, which
        is the factor hard-coded in the arithmetic below.
        """
        img_size = _pair(img_size)
        # Normalise patch_size the same way as img_size so that a plain int
        # is accepted instead of failing on ``patch_size[0]``.
        patch_size = _pair(patch_size)
        grid_size = (int(img_size[0] / patch_size[0]), int(img_size[1] / patch_size[1]))
        patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
        patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
        n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
        return patch_size, n_patches

    def forward(self, x):
        """Embed an image batch.

        Returns:
            ``(embeddings, features)`` where ``embeddings`` has shape
            ``(B, n_patches, hidden)`` and ``features`` is the list of skip
            tensors from the backbone (``None`` when ``self.hybrid`` is false).
        """
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)  # (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
class Conv2dReLU(nn.Sequential):
    """Conv2d -> (optional) BatchNorm2d -> ReLU as a three-slot ``nn.Sequential``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        use_batchnorm: When true, a ``BatchNorm2d`` follows the convolution
            and the convolution has no bias.  When false, the convolution
            keeps its bias and no normalisation is applied.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        relu = nn.ReLU(inplace=True)

        # Honour use_batchnorm: previously BatchNorm2d was inserted
        # unconditionally.  Identity keeps the slot, so the conv stays at
        # index 0 and the ReLU at index 2 either way.
        bn = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()

        super().__init__(conv, bn, relu)
class DecoderBlock(nn.Module):
    """Decoder stage: 2x bilinear upsample, optional skip concat, two 3x3 convs.

    Args:
        in_channels: Channels of the tensor coming from the previous stage.
        out_channels: Channels produced by this stage.
        skip_channels: Channels of the skip tensor concatenated after
            upsampling (0 when the stage has no skip connection).
        use_batchnorm: Forwarded to both ``Conv2dReLU`` layers.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        shared = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        self.conv1 = Conv2dReLU(in_channels + skip_channels, out_channels, **shared)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **shared)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        """Upsample ``x``, concatenate ``skip`` on the channel axis if given, refine."""
        upsampled = self.up(x)
        merged = upsampled if skip is None else torch.cat([upsampled, skip], dim=1)
        return self.conv2(self.conv1(merged))
class SegmentationHead(nn.Sequential):
    """Final projection to class logits, optionally followed by upsampling.

    Args:
        in_channels: Channels of the decoder output.
        out_channels: Number of output channels (classes).
        kernel_size: Kernel size of the projection; padding is
            ``kernel_size // 2``.
        upsampling: Bilinear upsampling factor; values ``<= 1`` disable it.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        projection = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize)
class TransUNetDecoder(nn.Module):
    """Cascaded upsampling decoder fed by transformer tokens and CNN skips.

    Args:
        decoder_channels: Output channels of each decoder stage, in order.
        hidden_size: Channel width of the incoming token sequence.
        n_skip: Number of leading decoder stages that receive a skip tensor.
        skip_channels: Channels of the skip tensor for each stage.  The
            sequence is read but never modified.
        head_channels: Channels produced by the initial 3x3 convolution.
    """

    def __init__(
        self,
        decoder_channels,
        hidden_size,
        n_skip,
        skip_channels,
        head_channels=512,
    ):
        super().__init__()
        self.conv_more = Conv2dReLU(
            hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        skip_channels = self._select_skip_channels(
            skip_channels, n_skip, len(decoder_channels)
        )

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.n_skip = n_skip

    @staticmethod
    def _select_skip_channels(skip_channels, n_skip, n_blocks):
        """Return a new list with skip widths zeroed for stages without a skip.

        Stages at index ``>= n_skip`` get 0.  A fresh list is built so the
        caller's sequence (often a shared default argument) is never mutated.
        """
        if n_skip == 0:
            return [0] * n_blocks
        return [ch if i < n_skip else 0 for i, ch in enumerate(skip_channels)]

    def forward(self, hidden_states, features=None):
        """Decode a token sequence into a feature map.

        Args:
            hidden_states: Tensor of shape ``(B, n_patch, hidden)``;
                ``n_patch`` must be a perfect square because the tokens are
                reshaped into a square ``h x w`` grid.
            features: Optional sequence of skip tensors; entry ``i`` is used
                by stage ``i`` for ``i < n_skip``.

        Returns:
            The output of the last decoder stage.
        """
        B, n_patch, hidden = hidden_states.size()
        # (B, n_patch, hidden) -> (B, hidden, h, w)
        h = w = int(math.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None and i < self.n_skip:
                skip = features[i]
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
# TODO: ND version
class TransUNet(nn.Module):
    """TransUNet segmentation model: transformer encoder + cascaded decoder + head.

    Args:
        in_channels: Forwarded to ``Transformer``.
        img_size: Input image size, forwarded to ``Transformer``.
        num_classes: Number of output channels of the segmentation head.
        zero_head: Stored on the instance; not otherwise used in this class.
        decoder_channels: Output channels of each decoder stage.
        hidden_size: Token width expected by the decoder.
        n_skip: Number of decoder stages that receive a skip connection.
        skip_channels: Skip tensor channels for each decoder stage.
        patch_size: Patch size, forwarded to ``Transformer``.

    NOTE(review): ``Transformer`` is not defined in the visible part of this
    file, and ``hidden_size`` is not forwarded to it; presumably its default
    width matches ``hidden_size`` -- confirm against its definition.
    """

    def __init__(
        self,
        in_channels=3,
        img_size=224,
        num_classes=2,
        zero_head=False,
        decoder_channels=(256, 128, 64, 16),
        hidden_size=768,
        n_skip=3,
        skip_channels=(512, 256, 64, 16),
        patch_size=(16, 16),
    ):
        super().__init__()
        # Work on private copies: the decoder may edit skip_channels in
        # place, which must never leak into the caller's sequence or into
        # the defaults shared by every instance.
        decoder_channels = list(decoder_channels)
        skip_channels = list(skip_channels)

        self.num_classes = num_classes
        self.zero_head = zero_head
        self.transformer = Transformer(
            in_channels,
            img_size,
            patch_size,
        )
        self.decoder = TransUNetDecoder(
            decoder_channels,
            hidden_size,
            n_skip,
            skip_channels,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=num_classes,
            kernel_size=3,
        )

    def forward(self, x):
        """Return segmentation logits for an image batch ``x`` of shape ``(B, C, H, W)``."""
        # Single-channel input is tiled to three channels.
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, _, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits