# Source code for fusionlab.segmentation.resunet.resunet

import torch
import torch.nn as nn
from fusionlab.segmentation.base import SegmentationModel
from fusionlab.utils import autopad
from fusionlab.layers.factories import ConvND, ConvT, BatchNorm



class ResUNet(SegmentationModel):
    """Residual U-Net assembled from an encoder, bridger, decoder and head.

    Only construction is defined here. No ``forward`` is declared on this
    class, so the forward pass is presumably supplied by
    ``SegmentationModel`` -- TODO confirm against the base class.

    Args:
        cin: number of input channels, forwarded to ``Encoder``.
        num_cls: number of classes; stored on the instance and forwarded
            to ``Head``.
        base_dim: channel width handed to the encoder, decoder and head.
        spatial_dims: dimensionality selector forwarded to every
            sub-module (defaults to 2).
    """

    def __init__(
        self,
        cin,
        num_cls,
        base_dim=64,
        spatial_dims=2,
    ):
        super().__init__()
        self.num_cls = num_cls

        # Sub-modules are created in encoder -> bridger -> decoder -> head
        # order; keep that order so parameter initialisation draws from the
        # RNG in the same sequence.
        self.encoder = Encoder(cin, base_dim, spatial_dims=spatial_dims)
        self.bridger = Bridger()
        self.decoder = Decoder(base_dim, spatial_dims=spatial_dims)
        # NOTE(review): ``Head`` is neither defined nor imported in the
        # visible part of this file -- confirm where it comes from.
        self.head = Head(base_dim, num_cls, spatial_dims)
class Encoder(nn.Module):
    """Feature extractor made of a stem followed by three stride-2 stages.

    Args:
        cin: number of input channels fed to the stem.
        base_dim: channel width of the stem output; each following stage
            doubles it (``base_dim``, ``2x``, ``4x``, ``8x``).
        spatial_dims: dimensionality selector forwarded to ``Stem`` and
            ``ResConv`` (defaults to 2).
    """

    def __init__(self, cin, base_dim, spatial_dims=2):
        super().__init__()
        # One channel width per level, doubling at every step.
        c0, c1, c2, c3 = (base_dim * (2 ** level) for level in range(4))

        self.stem = Stem(cin, c0, spatial_dims)
        self.stage1 = ResConv(c0, c1, spatial_dims, stride=2)
        self.stage2 = ResConv(c1, c2, spatial_dims, stride=2)
        self.stage3 = ResConv(c2, c3, spatial_dims, stride=2)

    def forward(self, x):
        """Run the stem and the three stages in sequence.

        Returns:
            list: ``[s0, s1, s2, s3]`` -- the stem output followed by the
            output of each stage, every stage consuming the previous entry.
        """
        features = [self.stem(x)]
        for stage in (self.stage1, self.stage2, self.stage3):
            features.append(stage(features[-1]))
        return features
class Decoder(nn.Module):
    """Three-step decoder that merges four feature levels into one output."""

    def __init__(self, base_dim, spatial_dims=2):
        """
        Build the three decoder blocks.

        Args:
            base_dim (int): channel width of the first (index 0) feature
                level; the other levels are taken to be 2x, 4x and 8x
                this value, and the decoder output has ``base_dim``
                channels.
            spatial_dims (int): dimensionality selector forwarded to each
                ``DecoderBlock`` (defaults to 2).
        """
        super().__init__()
        # Same width schedule as the channel list built in Encoder.
        c0, c1, c2, c3 = (base_dim * (2 ** level) for level in range(4))

        # Blocks are named after the level they consume: d3 takes the
        # 8x-wide features and emits 4x-wide ones, and so on down to c0.
        self.d3 = DecoderBlock(c3, c2, spatial_dims)
        self.d2 = DecoderBlock(c2, c1, spatial_dims)
        self.d1 = DecoderBlock(c1, c0, spatial_dims)

    def forward(self, x):
        """Fuse the feature levels from the last one back to the first.

        Args:
            x: sequence of exactly four feature tensors ordered
                ``(s0, s1, s2, s3)``; any other length fails on unpacking.

        Returns:
            Output of the final decoder block (``d1``).
        """
        level0, level1, level2, level3 = x
        merged = self.d3(level3, level2)
        merged = self.d2(merged, level1)
        return self.d1(merged, level0)
class DecoderBlock(nn.Module):
    """Upsample one tensor, concatenate a second one, then apply a ResConv.

    Args:
        cin: channel count of the tensor that gets upsampled.
        cout: channel count after upsampling and of the block output.
        spatial_dims: dimensionality selector forwarded to ``ConvT`` and
            ``ResConv`` (defaults to 2).
    """

    def __init__(self, cin, cout, spatial_dims=2):
        super().__init__()
        # Transposed convolution with kernel size 2 and stride 2.
        self.upsample = ConvT(spatial_dims, cin, cout, 2, stride=2)
        # The fusing ResConv is built for 2 * cout input channels, so the
        # tensor concatenated in forward() is expected to contribute cout.
        self.conv = ResConv(2 * cout, cout, spatial_dims, stride=1)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate ``x2`` on dim 1 and fuse the result.

        Args:
            x1: tensor passed through the transposed convolution.
            x2: tensor concatenated after the upsampled ``x1``.

        Returns:
            Output of the fusing ``ResConv``.
        """
        upsampled = self.upsample(x1)
        stacked = torch.cat((upsampled, x2), dim=1)
        return self.conv(stacked)
class Bridger(nn.Module):
    """Pass-through bridge between encoder and decoder.

    The module holds no parameters and applies no transformation; it only
    repackages the incoming features into a fresh list.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the elements of ``x`` unchanged, in a new list.

        Args:
            x: iterable of feature tensors.

        Returns:
            list: the very same objects, in the same order.
        """
        # ``nn.Identity`` returns its argument untouched, so building a new
        # Identity module per element on every call only produced garbage.
        # Copying into a list yields exactly the same outputs.
        return list(x)
class Stem(nn.Module):
    """Input block: a two-convolution branch summed with a one-conv skip.

    Main branch: Conv(3) -> BatchNorm -> ReLU -> Conv(3).
    Skip branch: a single Conv(3) from ``cin`` to ``cout``.

    Args:
        cin: number of input channels.
        cout: number of output channels of both branches.
        spatial_dims: dimensionality selector forwarded to the layer
            factories (defaults to 2).
    """

    def __init__(self, cin, cout, spatial_dims=2):
        super().__init__()
        kernel = 3  # every convolution in this block uses the same size

        # Layer order fixes the Sequential indices (0..3) and therefore the
        # state_dict keys -- do not reorder.
        main_layers = [
            ConvND(spatial_dims, cin, cout, kernel, padding=autopad(kernel)),
            BatchNorm(spatial_dims, cout),
            nn.ReLU(),
            ConvND(spatial_dims, cout, cout, kernel, padding=autopad(kernel)),
        ]
        self.conv = nn.Sequential(*main_layers)

        skip_layers = [
            ConvND(spatial_dims, cin, cout, kernel, padding=autopad(kernel)),
        ]
        self.skip = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return the sum of the main branch and the skip branch on ``x``."""
        main_out = self.conv(x)
        skip_out = self.skip(x)
        return main_out + skip_out
class ResConv(nn.Module):
    """Residual block: a two-convolution branch summed with a conv skip.

    Main branch: BatchNorm -> ReLU -> Conv(3, stride) -> BatchNorm -> ReLU
    -> Conv(3).
    Skip branch: Conv(3, stride) -> BatchNorm.

    Args:
        cin: number of input channels.
        cout: number of output channels of both branches.
        spatial_dims: dimensionality selector forwarded to the layer
            factories (defaults to 2).
        stride: stride of the first main-branch convolution and of the
            skip convolution (defaults to 1).
    """

    def __init__(self, cin, cout, spatial_dims=2, stride=1):
        super().__init__()
        kernel = 3  # every convolution in this block uses the same size

        # Layer order fixes the Sequential indices (0..5) and therefore the
        # state_dict keys -- do not reorder.
        main_layers = [
            BatchNorm(spatial_dims, cin),
            nn.ReLU(),
            # stride is passed positionally, exactly as the factory is
            # called elsewhere in this block's original form.
            ConvND(spatial_dims, cin, cout, kernel, stride, padding=autopad(kernel)),
            BatchNorm(spatial_dims, cout),
            nn.ReLU(),
            ConvND(spatial_dims, cout, cout, kernel, padding=autopad(kernel)),
        ]
        self.conv = nn.Sequential(*main_layers)

        # The skip path applies the same stride so both branches produce
        # tensors that can be added together.
        skip_layers = [
            ConvND(spatial_dims, cin, cout, kernel, stride=stride, padding=autopad(kernel)),
            BatchNorm(spatial_dims, cout),
        ]
        self.skip = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return the sum of the main branch and the skip branch on ``x``."""
        main_out = self.conv(x)
        skip_out = self.skip(x)
        return main_out + skip_out
if __name__ == '__main__':
    # Manual smoke test: build each component, run it once and print the
    # resulting shapes. Nothing here runs on import.

    # --- 2D: full model ---------------------------------------------------
    H = W = 224
    cout = 64
    inputs = torch.normal(0, 1, (1, 3, H, W))
    model = ResUNet(3, 100, cout)
    output = model(inputs)
    print(output.shape)

    # --- 2D: single decoder block ------------------------------------------
    dblock = DecoderBlock(64, 128)
    inputs2 = torch.normal(0, 1, (1, 128, H, W))
    inputs1 = torch.normal(0, 1, (1, 64, H // 2, W // 2))
    outputs = dblock(inputs1, inputs2)
    print(outputs.shape)

    # --- 2D: encoder, then decoder on the encoder's features ----------------
    encoder = Encoder(3, cout)
    outputs = encoder(inputs)
    for feature in outputs:
        print(feature.shape)

    decoder = Decoder(cout)
    outputs = decoder(outputs)
    print("Encoder + Decoder ", outputs.shape)

    # --- 2D: stem keeps the spatial size ------------------------------------
    stem = Stem(3, cout)
    outputs = stem(inputs)
    print(outputs.shape)
    assert list(outputs.shape) == [1, cout, H, W]

    # --- 2D: ResConv output size for stride 1 and stride 2 ------------------
    for stride in (1, 2):
        resconv = ResConv(3, cout, stride=stride)
        outputs = resconv(inputs)
        print(outputs.shape)
        assert list(outputs.shape) == [1, cout, H // stride, W // stride]

    # --- 3D: full model -----------------------------------------------------
    print("3D ResUNet")
    D = H = W = 64
    cout = 32
    inputs = torch.rand(1, 3, D, H, W)
    model = ResUNet(3, 100, cout, spatial_dims=3)
    output = model(inputs)
    print(output.shape)

    # --- 1D: full model -----------------------------------------------------
    print("1D ResUNet")
    L = 64
    cout = 32
    inputs = torch.rand(1, 3, L)
    model = ResUNet(3, 100, cout, spatial_dims=1)
    output = model(inputs)
    print(output.shape)