Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 66 additions & 17 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,58 @@ def spatial_gradient(x: torch.Tensor, dim: int) -> torch.Tensor:
return (x[slicing_s] - x[slicing_e]) / 2.0


def spatial_gradient_squared(x: torch.Tensor, dim_1: int, dim_2: int) -> torch.Tensor:
"""
Calculate the second-order partial derivative of ``x`` with respect to spatial dims
``dim_1`` and ``dim_2`` using compact central finite differences.

For ``dim_1 == dim_2`` the pure second derivative uses the ``[1, -2, 1]`` stencil:
``d2x[i] = x[i+1] - 2 * x[i] + x[i-1]``.

For ``dim_1 != dim_2`` the mixed partial uses the compact 4-point stencil:
``d2x[i, j] = (x[i+1, j+1] - x[i+1, j-1] - x[i-1, j+1] + x[i-1, j-1]) / 4``.

Every spatial dimension is sliced to ``[1:-1]`` so the output shape is independent of
``(dim_1, dim_2)``; this lets terms be summed together. Requires ``x.shape[d] > 2``
for every spatial dim ``d``.

Args:
x: the shape should be BCH(WD).
dim_1: first spatial dimension index.
dim_2: second spatial dimension index.

Returns:
Tensor with batch and channel axes preserved and every spatial axis sliced to
``[1:-1]``.
"""
slice_inner = slice(1, -1)
slice_plus = slice(2, None)
slice_minus = slice(None, -2)
slice_all = slice(None)

def _idx(overrides: dict) -> list:
out: list = [slice_all, slice_all]
for d in range(2, x.ndim):
out.append(overrides.get(d, slice_inner))
return out

if dim_1 == dim_2:
return x[_idx({dim_1: slice_plus})] - 2 * x[_idx({})] + x[_idx({dim_1: slice_minus})]
return (
x[_idx({dim_1: slice_plus, dim_2: slice_plus})]
- x[_idx({dim_1: slice_plus, dim_2: slice_minus})]
- x[_idx({dim_1: slice_minus, dim_2: slice_plus})]
+ x[_idx({dim_1: slice_minus, dim_2: slice_minus})]
) / 4.0


class BendingEnergyLoss(_Loss):
"""
Calculate the bending energy based on second-order differentiation of ``pred`` using central finite difference.
Calculate the bending energy based on second-order differentiation of ``pred``.

Pure second derivatives use the compact ``[1, -2, 1]`` stencil; mixed partials use a
compact 4-point central scheme. Both span three voxels per axis, so each spatial
dimension of ``pred`` only needs to be greater than 2.

For more information,
see https://github.com/Project-MONAI/tutorials/blob/main/modules/bending_energy_diffusion_loss_notes.ipynb.
Expand Down Expand Up @@ -79,41 +128,41 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
ValueError: When any spatial dimension of ``pred`` has size less than or equal to 4.
ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.

"""
if pred.ndim not in [3, 4, 5]:
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
for i in range(pred.ndim - 2):
if pred.shape[-i - 1] <= 4:
raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}")
if pred.shape[-i - 1] <= 2:
raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
f"does not match number of spatial dimensions, {pred.ndim - 2}"
)

# first order gradient
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

# spatial dimensions in a shape suited for broadcasting below
if self.normalize:
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))

energy = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
# Initialize on pred.device so a GPU `pred` does not get added to a CPU
# accumulator, and as a float so an integer-dtype `pred` still produces a
# floating-point energy (the compact pure-derivative stencil has no
# division, so a Long input would otherwise propagate as Long and fail
# `torch.mean` at the reduction step).
energy = torch.tensor(0.0, device=pred.device)
for dim_1 in range(2, pred.ndim):
d2 = spatial_gradient_squared(pred, dim_1, dim_1)
if self.normalize:
g *= pred.shape[dim_1] / spatial_dims
energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2
else:
energy = energy + spatial_gradient(g, dim_1) ** 2
d2 = d2 * (pred.shape[dim_1] ** 2 / spatial_dims)
energy = energy + d2**2
for dim_2 in range(dim_1 + 1, pred.ndim):
d2_mixed = spatial_gradient_squared(pred, dim_1, dim_2)
if self.normalize:
energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2
else:
energy = energy + 2 * spatial_gradient(g, dim_2) ** 2
d2_mixed = d2_mixed * (pred.shape[dim_1] * pred.shape[dim_2] / spatial_dims)
energy = energy + 2 * d2_mixed**2

if self.reduction == LossReduction.MEAN.value:
energy = torch.mean(energy) # the batch and channel average
Expand Down
7 changes: 4 additions & 3 deletions tests/losses/deform/test_bending_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

TEST_CASES = [
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
[{}, {"pred": torch.ones((1, 3, 3, 3, 3), device=device)}, 0.0],
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0],
[
{"normalize": False},
Expand Down Expand Up @@ -64,11 +65,11 @@ def test_ill_shape(self):
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 4, 5)))
loss.forward(torch.ones((1, 3, 5, 2, 5)))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 5, 4)))
loss.forward(torch.ones((1, 3, 5, 5, 2)))

# number of vector components unequal to number of spatial dims
with self.assertRaisesRegex(ValueError, "Number of vector components"):
Expand Down
Loading