# Source code for fusionlab.functional.dice

import torch
from fusionlab.configs import EPS


class _ShapeMismatchError(ValueError, AssertionError):
    """Raised by :func:`dice_score` when ``pred`` and ``target`` shapes differ.

    Inherits from ``ValueError`` (the idiomatic type for a bad argument) and
    from ``AssertionError`` so callers written against the former ``assert``
    check keep catching it.
    """


def _reduce_sum(x, dims):
    """Sum ``x`` over ``dims``; ``None`` means reduce over all dimensions."""
    # Some older torch releases do not accept ``dim=None``; use the no-dim
    # overload in that case (same 0-dim result on releases that do).
    return x.sum() if dims is None else x.sum(dim=dims)


def dice_score(pred, target, dims=None):
    """Compute the Dice score ``2 * sum(pred * target) / sum(pred + target)``.

    Shape:
        - pred: :math:`(N, C, *)`
        - target: :math:`(N, C, *)`, identical to ``pred``'s shape.
        - Output: a scalar (0-dim tensor) when ``dims`` is ``None``;
          otherwise ``pred``'s shape with the dimensions in ``dims`` reduced.

    Args:
        pred: predicted tensor.
        target: reference tensor with the same shape as ``pred``.
        dims: dimension(s) to reduce over (int or tuple of ints). ``None``
            reduces over all dimensions.

    Returns:
        Tensor of Dice scores. The denominator is clamped below at ``EPS``,
        presumably a small positive constant, so the division is well defined
        even when ``pred + target`` sums to zero (the score is then 0).

    Raises:
        ValueError: if ``pred.size() != target.size()``. The exception is
            also an ``AssertionError`` for backward compatibility.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if pred.size() != target.size():
        raise _ShapeMismatchError(
            "pred and target must have the same shape, "
            f"got {tuple(pred.size())} and {tuple(target.size())}"
        )

    intersection = _reduce_sum(pred * target, dims)
    cardinality = _reduce_sum(pred + target, dims)
    return (2.0 * intersection) / cardinality.clamp(min=EPS)