Source code for fusionlab.losses.tversky.tftversky

import tensorflow as tf
from einops import rearrange
from fusionlab.configs import EPS

# Public API of this module: only the Keras loss class is exported by
# ``from ... import *``.
__all__ = ["TFTverskyLoss"]

# Accepted values for the ``mode`` argument of ``TFTverskyLoss``.
BINARY_MODE = "binary"
MULTICLASS_MODE = "multiclass"


class TFTverskyLoss(tf.keras.losses.Loss):
    """Tversky loss for image segmentation, usable as a Keras loss.

    Supports the "binary" and "multiclass" modes. For each class the
    Tversky index ``TP / (TP + alpha * FP + beta * FN)`` is computed by
    ``tf_tversky_score``. The loss is ``1 - index``, or ``-log(index)`` when
    ``log_loss=True``, averaged over classes. With ``alpha == beta == 0.5``
    the index equals the Dice coefficient.

    ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py
    ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
    """

    def __init__(
        self,
        alpha,
        beta,
        mode="binary",  # binary, multiclass
        log_loss=False,
        from_logits=False,
    ):
        """
        Args:
            alpha: controls the penalty for false positives (FP).
            beta: controls the penalty for false negatives (FN).
            mode: Metric mode {'binary', 'multiclass'}
            log_loss: If True, loss computed as `-log(tversky)`;
                otherwise `1 - tversky`
            from_logits: If True assumes input is raw logits (sigmoid is
                applied in binary mode, softmax over the last axis in
                multiclass mode)

        Raises:
            ValueError: if ``mode`` is not 'binary' or 'multiclass'.
        """
        # Fail fast: an unknown mode previously went unnoticed until call().
        if mode not in (BINARY_MODE, MULTICLASS_MODE):
            raise ValueError(
                f"Unsupported mode {mode!r}; expected "
                f"{BINARY_MODE!r} or {MULTICLASS_MODE!r}"
            )
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.mode = mode
        self.from_logits = from_logits
        self.log_loss = log_loss

    def call(self, y_true, y_pred):
        """Compute the Tversky loss.

        :param y_true: (N, *) ground truth. Binary mode: targets with as many
            elements as ``y_pred``. Multiclass mode: class indices (cast to
            int32 and one-hot encoded here).
        :param y_pred: (N, *, C) predictions; probabilities, or logits when
            ``from_logits=True``.
        :return: scalar, the loss averaged over classes.
        :raises ValueError: if the statically known batch sizes differ, or if
            ``self.mode`` is not a supported mode.
        """
        y_true_shape = y_true.shape.as_list()
        y_pred_shape = y_pred.shape.as_list()

        # Only compare batch sizes that are known at trace time; in graph
        # mode the batch dimension is frequently None.
        n_true, n_pred = y_true_shape[0], y_pred_shape[0]
        if n_true is not None and n_pred is not None and n_true != n_pred:
            raise ValueError(
                f"Batch size mismatch: y_true has {n_true}, "
                f"y_pred has {n_pred}"
            )

        num_classes = y_pred_shape[-1]
        if num_classes is None:
            # Channel dimension not statically known: use the runtime size.
            num_classes = tf.shape(y_pred)[-1]

        # After flattening, axis 0 spans every sample/pixel, so reducing over
        # it yields one Tversky score per class.
        axis = [0]

        if self.from_logits:
            # get [0..1] class probabilities
            if self.mode == MULTICLASS_MODE:
                y_pred = tf.nn.softmax(y_pred, axis=-1)
            else:
                y_pred = tf.nn.sigmoid(y_pred)

        if self.mode == BINARY_MODE:
            y_true = rearrange(y_true, "... -> (...) 1")
            y_pred = rearrange(y_pred, "... -> (...) 1")
        elif self.mode == MULTICLASS_MODE:
            y_true = tf.cast(y_true, tf.int32)
            y_true = tf.one_hot(y_true, num_classes)
            y_true = rearrange(y_true, "... C -> (...) C")
            y_pred = rearrange(y_pred, "... C -> (...) C")
        else:
            # Unreachable for instances built via __init__ (mode is validated
            # there), but guards against self.mode being reassigned later.
            # The original only constructed the exception without raising it.
            raise ValueError(f"Mode {self.mode!r} is not implemented")

        scores = tf_tversky_score(
            y_pred, tf.cast(y_true, y_pred.dtype), self.alpha, self.beta, axis=axis
        )

        if self.log_loss:
            # Clip away from 0 so log() stays finite.
            loss = -tf.math.log(
                tf.clip_by_value(
                    scores, clip_value_min=1e-7, clip_value_max=scores.dtype.max
                )
            )
        else:
            loss = 1.0 - scores

        return tf.math.reduce_mean(loss)
def tf_tversky_score(pred, target, alpha, beta, axis=None):
    """Compute the Tversky index ``TP / (TP + alpha * FP + beta * FN)``.

    Shape:
        - pred: :math:`(N, *, C)` where :math:`*` means any number of
          additional dimensions
        - target: :math:`(N, *, C)`, same shape as the input
        - Output: the index reduced over ``axis``. A scalar when ``axis`` is
          None; otherwise ``pred``'s shape with the ``axis`` dimensions
          removed (e.g. ``(C,)`` for an ``(M, C)`` input and ``axis=[0]``).

    Args:
        pred: predicted probabilities.
        target: ground truth with the same shape as ``pred``; it is cast to
            ``pred.dtype`` so integer or boolean masks are accepted.
        alpha: weight of the false-positive term in the denominator.
        beta: weight of the false-negative term in the denominator.
        axis: axes to sum over (passed to ``tf.reduce_sum``); None sums over
            all elements.

    The denominator is clipped to be at least ``EPS`` (from
    ``fusionlab.configs``) to avoid division by zero when both prediction and
    target are empty (assuming ``EPS`` is a small positive constant).
    """
    # No-op when dtypes already match; avoids float - int dtype errors below.
    target = tf.cast(target, pred.dtype)

    true_pos = tf.reduce_sum(pred * target, axis=axis)
    false_pos = tf.reduce_sum(pred * (1.0 - target), axis=axis)
    false_neg = tf.reduce_sum((1.0 - pred) * target, axis=axis)

    denominator = true_pos + alpha * false_pos + beta * false_neg
    denominator = tf.clip_by_value(
        denominator, clip_value_min=EPS, clip_value_max=denominator.dtype.max
    )
    return true_pos / denominator


if __name__ == "__main__":
    print("Multiclass")
    pred = tf.convert_to_tensor([[
        [1., 2., 3., 4.],
        [2., 6., 4., 4.],
        [9., 6., 3., 4.],
    ]])
    pred = rearrange(pred, "N C H -> N H C")
    true = tf.convert_to_tensor([[2, 1, 0, 2]])
    loss_fn = TFTverskyLoss(0.5, 0.5, "multiclass", from_logits=True)
    loss = loss_fn(true, pred)
    print(float(loss))

    print("Binary")
    pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
    pred = tf.reshape(pred, [1, 2, 2, 1])
    true = tf.convert_to_tensor([0, 1, 0, 1])
    true = tf.reshape(true, [1, 2, 2])
    loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True)
    loss = loss_fn(true, pred)
    print(float(loss))

    print("Binary Log loss")
    pred = tf.convert_to_tensor([0.4, 0.2, 0.3, 0.5])
    pred = tf.reshape(pred, [1, 2, 2, 1])
    true = tf.convert_to_tensor([0, 1, 0, 1])
    true = tf.reshape(true, [1, 2, 2])
    loss_fn = TFTverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True)
    loss = loss_fn(true, pred)
    print(float(loss))