Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions problems/linalg/eigh_py/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,31 @@
_SORT_RTOL_FACTOR = 100.0


def _as_plain_fp64(value: torch.Tensor) -> torch.Tensor:
# HARDENING (b): promote to FP64 through a path the output object cannot
# override. Calling value.double() / value.to(...) dispatches through the
# object's own type, so a torch.Tensor subclass with __torch_function__ can
# intercept the promotion and return whatever it likes (e.g. lazily run the
# genuine solve here, untimed). Strip to an exact base torch.Tensor FIRST via
# as_subclass(torch.Tensor) -- which yields a plain-Tensor view with no
# __torch_function__ override -- and only then promote with the unbound base
# method torch.Tensor.double, so the residual math operates on data the
# submission cannot re-point.
#
# Every step uses the UNBOUND base-class method (torch.Tensor.detach /
# .as_subclass / .double), never the bound form value.detach(): a bound call
# dispatches through the object and a plain torch.Tensor still carries a
# per-instance __dict__, so an instance attribute override of .detach()/.to()
# would otherwise fire here and lazily run the genuine solve untimed even
# though type(value) is exactly torch.Tensor. Unbound calls go through the
# type and bypass any per-instance override; detach() also drops autograd
# wrapping.
plain = torch.Tensor.as_subclass(torch.Tensor.detach(value), torch.Tensor)
return torch.Tensor.double(plain)


def _matrix_l1_norm(value: torch.Tensor) -> torch.Tensor:
return torch.linalg.matrix_norm(value.double(), ord=1, dim=(-2, -1))
return torch.linalg.matrix_norm(_as_plain_fp64(value), ord=1, dim=(-2, -1))


def _property_rtol(n: int, factor: float) -> float:
Expand Down Expand Up @@ -276,8 +299,20 @@ def ref_kernel(data: input_t) -> output_t:


def _check_tensor(name: str, value: torch.Tensor, shape: tuple[int, ...], device: torch.device) -> str | None:
if not isinstance(value, torch.Tensor):
return f"{name} must be a torch.Tensor"
# HARDENING (a): require an EXACT plain torch.Tensor, not a subclass. A
# torch.Tensor subclass can override __torch_function__ / .double() / .to()
# and so substitute what the FP64 residual math below sees -- e.g. return
# cheap placeholders from custom_kernel (timed) that lazily run the real
# solve only when the checker promotes them to FP64 (untimed recheck). The
# output contract is a plain FP32 tensor; isinstance() would admit any
# subclass, so gate on the exact type.
if type(value) is not torch.Tensor:
if not isinstance(value, torch.Tensor):
return f"{name} must be a torch.Tensor"
return (
f"{name} must be a plain torch.Tensor, got subclass "
f"{type(value).__module__}.{type(value).__qualname__}"
)
if value.shape != shape:
return f"{name} shape must be {shape}, got {tuple(value.shape)}"
if value.dtype != torch.float32:
Expand Down Expand Up @@ -328,9 +363,9 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]:
if not good:
return False, message

a_check = a.double()
q_check = q.double()
values_check = values.double()
a_check = _as_plain_fp64(a)
q_check = _as_plain_fp64(q)
values_check = _as_plain_fp64(values)
aq = a_check @ q_check
ql = q_check * values_check.unsqueeze(-2)
if not torch.isfinite(aq).all().item() or not torch.isfinite(ql).all().item():
Expand Down