Source code for fusionlab.metrics.dicescore.dice

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from fusionlab.functional import dice_score

BINARY_MODE = "binary"
MULTICLASS_MODE = "multiclass"


class DiceScore(nn.Module):
    """Dice score metric module for binary or multiclass predictions."""

    def __init__(
        self,
        mode="multiclass",  # binary, multiclass
        from_logits=True,
        reduction="none",  # mean, none
    ):
        """
        Compute dice score for binary or multiclass input

        Args:
            mode: "binary" or "multiclass"
            from_logits: if True, assumes input is raw logits
            reduction: "mean" or "none", if "none" returns dice score for
                each channels, else returns mean
        """
        super().__init__()
        self.mode = mode
        self.from_logits = from_logits
        self.reduction = reduction

    def forward(self, y_pred, y_true) -> torch.Tensor:
        """
        Compute the dice score between predictions and targets.

        :param y_pred: (N, C, *) predictions; raw logits when
            ``from_logits`` is True (C is 1 in binary mode)
        :param y_true: (N, *) targets; in multiclass mode these are class
            indices that get one-hot encoded with ``C`` classes
        :return: the output of ``dice_score`` unchanged when
            ``reduction == "none"``, otherwise its mean as a scalar tensor
        :raises ValueError: if ``self.mode`` is neither "binary" nor
            "multiclass"
        """
        assert y_true.size(0) == y_pred.size(0)

        # Reject unknown modes up front. Previously the fallback branch only
        # constructed an exception without raising it, so an unsupported mode
        # silently produced a score from un-reshaped tensors.
        if self.mode not in (BINARY_MODE, MULTICLASS_MODE):
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected "
                f"{BINARY_MODE!r} or {MULTICLASS_MODE!r}"
            )

        num_classes = y_pred.size(1)
        dims = (0, 2)  # dimensions to sum over (N, C, *)

        if self.from_logits:
            # get [0..1] class probabilities
            if self.mode == MULTICLASS_MODE:
                y_pred = F.softmax(y_pred, dim=1)
            else:
                y_pred = torch.sigmoid(y_pred)

        # Flatten everything after the channel axis so both tensors are
        # (N, C, L) before scoring.
        if self.mode == BINARY_MODE:
            y_true = rearrange(y_true, "N ... -> N 1 (...)")
            y_pred = rearrange(y_pred, "N 1 ... -> N 1 (...)")
        else:  # MULTICLASS_MODE
            y_pred = rearrange(y_pred, "N C ... -> N C (...)")
            y_true = F.one_hot(y_true, num_classes)  # (N, *) -> (N, *, C)
            y_true = rearrange(y_true, "N ... C -> N C (...)")

        scores = dice_score(y_pred, y_true.type_as(y_pred), dims=dims)

        if self.reduction == "none":
            return scores
        return scores.mean()
# Alias: JaccardScore is the very same class object as DiceScore.
# NOTE(review): the Jaccard index (IoU) is normally a different quantity from
# the Dice coefficient; through this alias callers receive Dice values under
# the Jaccard name — confirm this is intentional.
JaccardScore = DiceScore