From 3308ecf36f07bf9c2c10ea9ea7e7e500bc24edc6 Mon Sep 17 00:00:00 2001 From: rtmalikian Date: Thu, 18 Jun 2026 06:11:31 -0700 Subject: [PATCH] fix: use per-sample dice computation in DiceLoss to match CrossEntropyLoss mean reduction (#74) The DiceLoss._dice_loss() computed the dice coefficient using torch.sum across all batch elements simultaneously. While the dice coefficient is a ratio (bounded 0-1), computing it over the entire batch instead of per-sample means that samples with different characteristics (common in medical imaging) can mask each other's contribution to the loss. This changes _dice_loss() to compute the dice coefficient per-sample and return the batch average, consistent with CrossEntropyLoss's default 'mean' reduction behavior. When batch samples have different anatomy or pathology, the old approach could produce loss values that don't reflect individual sample quality. For example, a well-predicted sample can dilute the loss from a poorly-predicted sample in the same batch. Per-sample computation ensures each sample contributes equally to the training signal. --- code/utils/losses.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/code/utils/losses.py b/code/utils/losses.py index 94a0bac..b55e4b0 100644 --- a/code/utils/losses.py +++ b/code/utils/losses.py @@ -169,12 +169,16 @@ def _one_hot_encoder(self, input_tensor): def _dice_loss(self, score, target): target = target.float() smooth = 1e-5 - intersect = torch.sum(score * target) - y_sum = torch.sum(target * target) - z_sum = torch.sum(score * score) - loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) - loss = 1 - loss - return loss + # Compute per-sample dice to match CrossEntropyLoss mean reduction + batch_size = score.shape[0] + loss = 0.0 + for i in range(batch_size): + intersect = torch.sum(score[i] * target[i]) + y_sum = torch.sum(target[i] * target[i]) + z_sum = torch.sum(score[i] * score[i]) + sample_loss = 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth) + loss += sample_loss + return loss / batch_size def forward(self, inputs, target, weight=None, softmax=False): if softmax: