From 3dcbac4e791491c9328924e427289cab1f842aa7 Mon Sep 17 00:00:00 2001 From: Vishnu Kannaujia Date: Sun, 14 Jun 2026 21:09:25 -0700 Subject: [PATCH 1/2] Use compact [1,-2,1] kernel for BendingEnergyLoss second derivatives The previous implementation differentiated the central first-order helper twice, yielding a wide [1, 0, -2, 0, 1] / 4 stencil for pure second derivatives that spans four voxels per axis. This is less accurate than the standard compact [1, -2, 1] stencil and forced the validation "all spatial dims > 4". Switch to compact stencils computed directly from pred: - Pure d^2/dx_i^2 uses x[i+1] - 2 * x[i] + x[i-1]. - Mixed d^2/(dx_i dx_j) uses (x[i+1,j+1] - x[i+1,j-1] - x[i-1,j+1] + x[i-1,j-1]) / 4. Both span three voxels per axis, so the spatial-size validation relaxes to "> 2", matching DiffusionLoss. The public API (__init__(normalize, reduction), forward(pred)) and normalize semantics are unchanged. For quadratic inputs both old and new stencils return the exact analytic second derivative (= 2 for f(x) = x^2), so the existing TEST_CASES expected values (0.0, 4.0, 100.0) are invariant under this change; only test_ill_shape is updated to trigger on shape 2 instead of 4, and a new TEST_CASES row exercises the relaxed guard on shape (1, 3, 3, 3, 3). Fixes #5939. Signed-off-by: Vishnu Kannaujia --- monai/losses/deform.py | 76 +++++++++++++++++----- tests/losses/deform/test_bending_energy.py | 7 +- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 37e4468d4b..235f5a6b27 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -44,9 +44,58 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor: return (x[slicing_s] - x[slicing_e]) / 2.0 +def spatial_gradient_squared(x: torch.Tensor, dim_1: int, dim_2: int) -> torch.Tensor: + """ + Calculate the second-order partial derivative of ``x`` with respect to spatial dims + ``dim_1`` and ``dim_2`` using compact central finite differences. + + For ``dim_1 == dim_2`` the pure second derivative uses the ``[1, -2, 1]`` stencil: + ``d2x[i] = x[i+1] - 2 * x[i] + x[i-1]``. + + For ``dim_1 != dim_2`` the mixed partial uses the compact 4-point stencil: + ``d2x[i, j] = (x[i+1, j+1] - x[i+1, j-1] - x[i-1, j+1] + x[i-1, j-1]) / 4``. + + Every spatial dimension is sliced to ``[1:-1]`` so the output shape is independent of + ``(dim_1, dim_2)``; this lets terms be summed together. Requires ``x.shape[d] > 2`` + for every spatial dim ``d``. + + Args: + x: the shape should be BCH(WD). + dim_1: first spatial dimension index. + dim_2: second spatial dimension index. + + Returns: + Tensor with batch and channel axes preserved and every spatial axis sliced to + ``[1:-1]``. + """ + slice_inner = slice(1, -1) + slice_plus = slice(2, None) + slice_minus = slice(None, -2) + slice_all = slice(None) + + def _idx(overrides: dict) -> list: + out: list = [slice_all, slice_all] + for d in range(2, x.ndim): + out.append(overrides.get(d, slice_inner)) + return out + + if dim_1 == dim_2: + return x[_idx({dim_1: slice_plus})] - 2 * x[_idx({})] + x[_idx({dim_1: slice_minus})] + return ( + x[_idx({dim_1: slice_plus, dim_2: slice_plus})] + - x[_idx({dim_1: slice_plus, dim_2: slice_minus})] + - x[_idx({dim_1: slice_minus, dim_2: slice_plus})] + + x[_idx({dim_1: slice_minus, dim_2: slice_minus})] + ) / 4.0 + + class BendingEnergyLoss(_Loss): """ - Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference. + Calculate the bending energy based on second-order differentiation of ``pred``. + + Pure second derivatives use the compact ``[1, -2, 1]`` stencil; mixed partials use a + compact 4-point central scheme. Both span three voxels per axis, so each spatial + dimension of ``pred`` only needs to be greater than 2. For more information, see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb. @@ -79,41 +128,36 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. ValueError: When ``pred`` is not 3-d, 4-d or 5-d. - ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): - if pred.shape[-i - 1] <= 4: - raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " f"does not match number of spatial dimensions, {pred.ndim - 2}" ) - # first order gradient - first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] - # spatial dimensions in a shape suited for broadcasting below if self.normalize: spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) energy = torch.tensor(0) - for dim_1, g in enumerate(first_order_gradient): - dim_1 += 2 + for dim_1 in range(2, pred.ndim): + d2 = spatial_gradient_squared(pred, dim_1, dim_1) if self.normalize: - g *= pred.shape[dim_1] / spatial_dims - energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 - else: - energy = energy + spatial_gradient(g, dim_1) ** 2 + d2 = d2 * (pred.shape[dim_1] ** 2 / spatial_dims) + energy = energy + d2**2 for dim_2 in range(dim_1 + 1, pred.ndim): + d2_mixed = spatial_gradient_squared(pred, dim_1, dim_2) if self.normalize: - energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 - else: - energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 + d2_mixed = d2_mixed * (pred.shape[dim_1] * pred.shape[dim_2] / spatial_dims) + energy = energy + 2 * d2_mixed**2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/tests/losses/deform/test_bending_energy.py b/tests/losses/deform/test_bending_energy.py index 2e8ab32dbd..5e713b3e47 100644 --- a/tests/losses/deform/test_bending_energy.py +++ b/tests/losses/deform/test_bending_energy.py @@ -23,6 +23,7 @@ TEST_CASES = [ [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + [{}, {"pred": torch.ones((1, 3, 3, 3, 3), device=device)}, 0.0], [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], [ {"normalize": False}, @@ -64,11 +65,11 @@ def test_ill_shape(self): with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 5, 4, 5))) + loss.forward(torch.ones((1, 3, 5, 2, 5))) with self.assertRaisesRegex(ValueError, "All spatial dimensions"): - loss.forward(torch.ones((1, 3, 5, 5, 4))) + loss.forward(torch.ones((1, 3, 5, 5, 2))) # number of vector components unequal to number of spatial dims with self.assertRaisesRegex(ValueError, "Number of vector components"): From 51e60dac348006c1993c39e9981361809d5d6371 Mon Sep 17 00:00:00 2001 From: Vishnu Kannaujia Date: Sun, 14 Jun 2026 22:05:30 -0700 Subject: [PATCH 2/2] Initialize bending-energy accumulator as float on pred.device The compact pure-derivative stencil added in this branch (`x[i+1] - 2 * x[i] + x[i-1]`) intentionally avoids the `/2.0` factor of the old first-order helper. That means an integer-dtype `pred` (e.g. the arange-squared inputs in ``test_shape_5``: a 1-d input with no mixed terms to force float promotion via the `/4.0` factor) now stays Long all the way through the accumulator, and `torch.mean(energy)` raises: RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long Initialize the accumulator explicitly as a float and bind it to `pred.device` so it broadcasts against integer-dtype inputs and matches GPU tensors. Also matches the device pattern already used a few lines above for `spatial_dims`. `DiffusionLoss` is left unchanged: its first-order helper still divides by 2.0, so its accumulator is always promoted to float by the first addition; updating it here would expand the scope of #5939. Signed-off-by: Vishnu Kannaujia --- monai/losses/deform.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 235f5a6b27..80da8bafcc 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -147,7 +147,12 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: if self.normalize: spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) - energy = torch.tensor(0) + # Initialize on pred.device so a GPU `pred` does not get added to a CPU + # accumulator, and as a float so an integer-dtype `pred` still produces a + # floating-point energy (the compact pure-derivative stencil has no + # division, so a Long input would otherwise propagate as Long and fail + # `torch.mean` at the reduction step). + energy = torch.tensor(0.0, device=pred.device) for dim_1 in range(2, pred.ndim): d2 = spatial_gradient_squared(pred, dim_1, dim_1) if self.normalize: