# Source code for fusionlab.segmentation.base

import torch.nn as nn

class SegmentationModel(nn.Module):
    """Base PyTorch segmentation model: an Encoder -> Bridger -> Decoder -> Head pipeline.

    This class does not build any of the four stages itself. A subclass is
    expected to assign ``self.encoder``, ``self.bridger``, ``self.decoder`` and
    ``self.head`` (each a callable) before ``forward`` is invoked.
    """

    def forward(self, x):
        """Feed ``x`` through encoder, bridger, decoder and head, in that order.

        Each stage receives the previous stage's output, and the head's output
        is returned as-is.
        """
        out = x
        for stage_name in ("encoder", "bridger", "decoder", "head"):
            # Look each stage up only when it is about to run, so attribute
            # resolution and calls interleave exactly like an explicit call chain.
            out = getattr(self, stage_name)(out)
        return out
class HFSegmentationModel(nn.Module):
    """Hugging Face-style PyTorch wrapper around a segmentation model.

    ``forward`` returns a dict with ``'loss'``, ``'logits'`` and
    ``'hidden_states'`` keys (the layout intended for the Hugging Face
    ``Trainer``). The loss is computed only when labels are supplied.

    Args:
        model: wrapped ``nn.Module``, called as ``model(x)`` to produce logits.
            If it exposes a non-``None`` ``num_cls`` attribute (instance or
            class level), that value takes precedence over ``num_cls``.
        num_cls: number of classes, used when ``model`` does not provide one.
        loss_fct: loss module called as ``loss_fct(logits, labels)``. When
            omitted, a new ``nn.CrossEntropyLoss()`` is created for this wrapper.

    Raises:
        AssertionError: if neither ``model`` nor ``num_cls`` supplies a class count.
    """

    def __init__(self, model, num_cls=None, loss_fct=None):
        super().__init__()
        self.net = model
        # getattr (rather than inspecting model.__dict__) also finds class
        # attributes and properties; nn.Module.__getattr__ raises AttributeError
        # for unknown names, so the None default is returned in that case.
        model_num_cls = getattr(model, 'num_cls', None)
        self.num_cls = model_num_cls if model_num_cls is not None else num_cls
        if self.num_cls is None:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # AssertionError is kept so existing callers see the same exception type.
            raise AssertionError("num_cls is not defined")
        # Create the default loss per instance. A default-argument instance would
        # be built once at import time and shared, as a registered submodule,
        # by every wrapper.
        self.loss_fct = nn.CrossEntropyLoss() if loss_fct is None else loss_fct

    def forward(self, x, labels=None):
        """Run the wrapped model and optionally compute the loss.

        Args:
            x: input passed unchanged to the wrapped model.
            labels: optional targets; when given, ``loss_fct(logits, labels)``
                is computed.

        Returns:
            dict with ``'loss'`` (``None`` when ``labels`` is ``None``),
            ``'logits'`` (the wrapped model's output) and ``'hidden_states'``
            (always ``None``).
        """
        logits = self.net(x)
        # For segmentation the logits are per-pixel, e.g. [BATCH, NUM_CLS, H, W];
        # with the default CrossEntropyLoss the labels are presumably
        # [BATCH, H, W] class indices. TODO: confirm against callers.
        loss = self.loss_fct(logits, labels) if labels is not None else None
        return {'loss': loss, 'logits': logits, 'hidden_states': None}
if __name__ == '__main__':
    # Smoke test: wrap a ResUNet, run one random batch through it, and check
    # the wrapper's output keys and the logits shape.
    import torch
    from fusionlab.segmentation import ResUNet

    height = width = 224
    num_classes = 5

    batch = torch.normal(0, 1, (1, 3, height, width))
    wrapped = HFSegmentationModel(ResUNet(3, num_classes, 64), num_classes)
    result = wrapped(batch)

    assert list(result.keys()) == ['loss', 'logits', 'hidden_states']
    logits_shape = result['logits'].shape
    print(logits_shape)
    assert list(logits_shape) == [1, num_classes, height, width]