Source code for fusionlab.segmentation.tfbase

from tensorflow.keras import Model

class TFSegmentationModel(Model):
    """Base TensorFlow/Keras class of the segmentation model.

    The forward pass chains four sub-modules in order:
    Encoder -> Bridger -> Decoder -> Head.

    This base class does not create the sub-modules itself. Subclasses
    are presumably expected to assign ``self.encoder``, ``self.bridger``,
    ``self.decoder`` and ``self.head`` (e.g. in ``__init__``) before the
    model is called. TODO(review): confirm against concrete subclasses.
    """

    def call(self, x, training=None):
        """Run the Encoder -> Bridger -> Decoder -> Head pipeline on ``x``.

        Args:
            x: input tensor, fed to ``self.encoder``.
            training: flag for BatchNormalization and Dropout, whether the
                layer should behave in training mode or in inference mode.
                It is forwarded unchanged to every sub-module.

        Returns:
            The output of ``self.head`` applied to the decoder output.
        """
        # ``training`` is passed by keyword rather than positionally: Keras
        # treats it as a special call argument and omits it for sub-layers
        # whose ``call`` does not declare it, whereas a positional value
        # would be bound to whatever the sub-layer's second parameter is.
        features = self.encoder(x, training=training)
        feature_fusion = self.bridger(features, training=training)
        decoder_output = self.decoder(feature_fusion, training=training)
        output = self.head(decoder_output, training=training)
        return output