diff --git a/code/utils/losses.py b/code/utils/losses.py index 94a0bac..b55e4b0 100644 --- a/code/utils/losses.py +++ b/code/utils/losses.py @@ -169,12 +169,16 @@ def _one_hot_encoder(self, input_tensor): def _dice_loss(self, score, target): target = target.float() smooth = 1e-5 - intersect = torch.sum(score * target) - y_sum = torch.sum(target * target) - z_sum = torch.sum(score * score) - loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) - loss = 1 - loss - return loss + # Compute per-sample dice to match CrossEntropyLoss mean reduction + batch_size = score.shape[0] + loss = 0.0 + for i in range(batch_size): + intersect = torch.sum(score[i] * target[i]) + y_sum = torch.sum(target[i] * target[i]) + z_sum = torch.sum(score[i] * score[i]) + sample_loss = 1 - (2 * intersect + smooth) / (z_sum + y_sum + smooth) + loss += sample_loss + return loss / batch_size def forward(self, inputs, target, weight=None, softmax=False): if softmax: