diff --git a/problems/linalg/eigh_py/reference.py b/problems/linalg/eigh_py/reference.py index 0cf1cc1a..95b1f025 100644 --- a/problems/linalg/eigh_py/reference.py +++ b/problems/linalg/eigh_py/reference.py @@ -13,8 +13,31 @@ _SORT_RTOL_FACTOR = 100.0 +def _as_plain_fp64(value: torch.Tensor) -> torch.Tensor: + # HARDENING (b): promote to FP64 through a path the output object cannot + # override. Calling value.double() / value.to(...) dispatches through the + # object's own type, so a torch.Tensor subclass with __torch_function__ can + # intercept the promotion and return whatever it likes (e.g. lazily run the + # genuine solve here, untimed). Strip to an exact base torch.Tensor FIRST via + # as_subclass(torch.Tensor) -- which yields a plain-Tensor view with no + # __torch_function__ override -- and only then promote with the unbound base + # method torch.Tensor.double, so the residual math operates on data the + # submission cannot re-point. + # + # Every step uses the UNBOUND base-class method (torch.Tensor.detach / + # .as_subclass / .double), never the bound form value.detach(): a bound call + # dispatches through the object and a plain torch.Tensor still carries a + # per-instance __dict__, so an instance attribute override of .detach()/.to() + # would otherwise fire here and lazily run the genuine solve untimed even + # though type(value) is exactly torch.Tensor. Unbound calls go through the + # type and bypass any per-instance override; detach() also drops autograd + # wrapping. + plain = torch.Tensor.as_subclass(torch.Tensor.detach(value), torch.Tensor) + return torch.Tensor.double(plain) + + def _matrix_l1_norm(value: torch.Tensor) -> torch.Tensor: - return torch.linalg.matrix_norm(value.double(), ord=1, dim=(-2, -1)) + return torch.linalg.matrix_norm(_as_plain_fp64(value), ord=1, dim=(-2, -1)) def _property_rtol(n: int, factor: float) -> float: @@ -276,8 +299,20 @@ def ref_kernel(data: input_t) -> output_t: def _check_tensor(name: str, value: torch.Tensor, shape: tuple[int, ...], device: torch.device) -> str | None: - if not isinstance(value, torch.Tensor): - return f"{name} must be a torch.Tensor" + # HARDENING (a): require an EXACT plain torch.Tensor, not a subclass. A + # torch.Tensor subclass can override __torch_function__ / .double() / .to() + # and so substitute what the FP64 residual math below sees -- e.g. return + # cheap placeholders from custom_kernel (timed) that lazily run the real + # solve only when the checker promotes them to FP64 (untimed recheck). The + # output contract is a plain FP32 tensor; isinstance() would admit any + # subclass, so gate on the exact type. + if type(value) is not torch.Tensor: + if not isinstance(value, torch.Tensor): + return f"{name} must be a torch.Tensor" + return ( + f"{name} must be a plain torch.Tensor, got subclass " + f"{type(value).__module__}.{type(value).__qualname__}" + ) if value.shape != shape: return f"{name} shape must be {shape}, got {tuple(value.shape)}" if value.dtype != torch.float32: @@ -328,9 +363,9 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]: if not good: return False, message - a_check = a.double() - q_check = q.double() - values_check = values.double() + a_check = _as_plain_fp64(a) + q_check = _as_plain_fp64(q) + values_check = _as_plain_fp64(values) aq = a_check @ q_check ql = q_check * values_check.unsqueeze(-2) if not torch.isfinite(aq).all().item() or not torch.isfinite(ql).all().item():