# Source code for fusionlab.functional.dice
from typing import Optional, Tuple

import torch

from fusionlab.configs import EPS
# [docs]
def dice_score(pred: torch.Tensor,
               target: torch.Tensor,
               dims: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
    """
    Computes the dice score

    The score is ``2 * sum(pred * target) / sum(pred + target)``, with the
    denominator clamped from below at ``EPS`` to avoid division by zero.

    Args:
        pred: (N, C, *)
        target: (N, C, *), must have exactly the same shape as ``pred``
        dims: dimensions to sum over; ``None`` (default) sums over every
            dimension and yields a scalar tensor

    Returns:
        Tensor holding the dice score, with the dimensions listed in
        ``dims`` reduced away.

    Raises:
        ValueError: if ``pred`` and ``target`` differ in shape.
    """
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let mismatched shapes broadcast silently.
    if pred.size() != target.size():
        raise ValueError(
            "pred and target must have the same shape, got "
            f"{tuple(pred.size())} and {tuple(target.size())}"
        )

    product = pred * target
    total = pred + target
    if dims is None:
        # Omit `dim` entirely rather than forwarding None, so the full
        # reduction also works on PyTorch releases that reject `dim=None`.
        intersection = torch.sum(product)
        cardinality = torch.sum(total)
    else:
        intersection = torch.sum(product, dim=dims)
        cardinality = torch.sum(total, dim=dims)

    # Clamp keeps the denominator at or above EPS so an all-zero
    # pred/target pair does not divide by zero.
    return (2.0 * intersection) / cardinality.clamp(min=EPS)