# Source code for fusionlab.losses.tversky.tversky

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from fusionlab.configs import EPS

# Accepted values for the ``mode`` argument of TverskyLoss.
MULTICLASS_MODE = "multiclass"
BINARY_MODE = "binary"

# Names exported by ``from ... import *``.
__all__ = ["TverskyLoss"]


class TverskyLoss(nn.Module):
    """Tversky loss for image segmentation ("binary" and "multiclass" modes).

    The score is ``TP / (TP + alpha * FP + beta * FN)``, computed per class by
    ``tversky_score``. With ``alpha == beta == 0.5`` this equals the Dice
    coefficient; with ``alpha == beta == 1`` it equals the Jaccard index (IoU).

    ref: https://github.com/kornia/kornia/blob/master/kornia/losses/tversky.py
    ref: https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py
    """

    def __init__(
        self,
        alpha,
        beta,
        mode="multiclass",  # binary, multiclass
        log_loss=False,
        from_logits=True,
    ):
        """
        Args:
            alpha: weight of the false-positive (FP) term in the denominator.
            beta: weight of the false-negative (FN) term in the denominator.
            mode: Metric mode, one of {'binary', 'multiclass'}.
            log_loss: If True, the loss is ``-log(score)``; otherwise ``1 - score``.
            from_logits: If True, ``y_pred`` is treated as raw logits and passed
                through softmax (multiclass) or sigmoid (binary) first.

        Raises:
            ValueError: if ``mode`` is not 'binary' or 'multiclass'.
        """
        super().__init__()
        # Fail at construction instead of producing a silently wrong loss later.
        self._validate_mode(mode)
        self.alpha = alpha
        self.beta = beta
        self.mode = mode
        self.from_logits = from_logits
        self.log_loss = log_loss

    @staticmethod
    def _validate_mode(mode):
        """Raise ValueError unless ``mode`` is one of the supported modes."""
        if mode not in (BINARY_MODE, MULTICLASS_MODE):
            raise ValueError(
                f"Unsupported mode {mode!r}; expected {BINARY_MODE!r} or "
                f"{MULTICLASS_MODE!r}"
            )

    def forward(self, y_pred, y_true):
        """Compute the Tversky loss.

        Args:
            y_pred: predictions of shape (N, C, *); logits if ``from_logits``,
                otherwise probabilities. In binary mode C must be 1.
            y_true: targets of shape (N, *). In multiclass mode these are class
                indices in ``[0, C)``; in binary mode, 0/1 values.

        Returns:
            Scalar tensor: the per-class losses averaged together.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        assert y_true.size(0) == y_pred.size(0)
        # Re-check here: ``mode`` is a plain attribute and may have been reassigned.
        self._validate_mode(self.mode)

        num_classes = y_pred.size(1)
        # Inputs are flattened to (N, C, L); summing over batch and L yields one
        # score per class.
        dims = (0, 2)

        if self.from_logits:
            # Convert logits to [0..1] class probabilities.
            if self.mode == MULTICLASS_MODE:
                y_pred = F.softmax(y_pred, dim=1)
            else:
                y_pred = torch.sigmoid(y_pred)

        if self.mode == BINARY_MODE:
            y_true = rearrange(y_true, "N ... -> N 1 (...)")
            y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
        else:
            y_pred = rearrange(y_pred, "N C ... -> N C (...)")
            # one_hot only accepts int64 indices; (N, *) -> (N, *, C).
            y_true = F.one_hot(y_true.long(), num_classes)
            y_true = rearrange(y_true, "N ... C -> N C (...)")

        scores = tversky_score(
            y_pred, y_true.type_as(y_pred), self.alpha, self.beta, dims=dims
        )

        if self.log_loss:
            # Clamp so log() never sees 0.
            loss = -torch.log(scores.clamp_min(EPS))
        else:
            loss = 1.0 - scores
        return loss.mean()
def tversky_score(pred, target, alpha, beta, dims, eps=None):
    """Compute the Tversky index ``TP / (TP + alpha * FP + beta * FN)``.

    Args:
        pred: predicted probabilities, shape (N, C, *).
        target: ground truth with the same shape as ``pred`` (0/1 or one-hot
            values as produced by ``TverskyLoss.forward``).
        alpha: weight of the false-positive term.
        beta: weight of the false-negative term.
        dims: dimension(s) summed over when accumulating TP / FP / FN.
        eps: lower bound for the denominator to avoid division by zero;
            defaults to the module-level ``EPS``.

    Returns:
        Tensor of scores with the ``dims`` axes reduced away. For example, it
        has shape (C,) for (N, C, L) inputs with ``dims=(0, 2)``. It is not a
        scalar unless every axis is reduced.
    """
    assert pred.size() == target.size()
    if eps is None:
        eps = EPS
    true_pos = torch.sum(pred * target, dim=dims)
    false_pos = torch.sum(pred * (1.0 - target), dim=dims)
    false_neg = torch.sum((1.0 - pred) * target, dim=dims)
    denominator = true_pos + alpha * false_pos + beta * false_neg
    return true_pos / denominator.clamp(min=eps)


if __name__ == "__main__":
    print("multiclass")
    pred = torch.tensor([[
        [1., 2., 3., 4.],
        [2., 6., 4., 4.],
        [9., 6., 3., 4.],
    ]]).unsqueeze(-1)
    true = torch.tensor([[2, 1, 0, 2]]).view(1, 4).unsqueeze(-1)
    loss_fn = TverskyLoss(0.5, 0.5, "multiclass", from_logits=True)
    loss = loss_fn(pred, true)
    print(loss.item())

    print("Binary")
    pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
    true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
    loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True)
    loss = loss_fn(pred, true)
    print(loss.item())

    print("Binary Logloss")
    pred = torch.tensor([0.4, 0.2, 0.3, 0.5]).reshape(1, 1, 2, 2)
    true = torch.tensor([0, 1, 0, 1]).reshape(1, 2, 2)
    loss_fn = TverskyLoss(0.5, 0.5, "binary", from_logits=True, log_loss=True)
    loss = loss_fn(pred, true)
    print(loss.item())