From d3f6cf31b076f56d0ca51afccd28ab19e1301fcc Mon Sep 17 00:00:00 2001
From: Bryce Adelstein Lelbach
Date: Sun, 28 Jun 2026 17:09:11 +0000
Subject: [PATCH] eigh_py: reject output-object deferral in the correctness check
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A submission can return cheap placeholder tensors from custom_kernel
(timed) and defer the real eigendecomposition into check_implementation
(untimed): the checker promotes the outputs to FP64 with value.double(),
and if the output is a torch.Tensor subclass with __torch_function__ —
or even a plain tensor with a per-instance .double()/.detach() override
— that promotion runs the genuine solve outside the timed region.
Confirmed live on the eigh B200 leaderboard (fabricated ~17 us, all
tests passing).

Two changes to reference.py:

1. Require an EXACT plain torch.Tensor output (type(value) is
   torch.Tensor), not merely isinstance() — the latter admits any
   subclass.
2. Promote to FP64 through an override-proof path (_as_plain_fp64):
   strip to a base-class view and call the UNBOUND torch.Tensor.detach /
   .as_subclass / .double, never the bound value.detach()/.double().
   Bound calls dispatch through the object and a plain tensor still
   carries a per-instance __dict__, so an instance-attribute override
   would otherwise fire; unbound calls go through the type and bypass
   it. (The bound .detach() was the gap a padded 'plain-tensor .detach
   override' variant used to survive an earlier draft of this fix; the
   unbound form closes it.)

Now the residual math always runs on data the submission cannot
re-point, so the real work cannot be deferred out of the timed region.
---
NOTE(review): bound vs. unbound only defeats per-instance attribute
overrides. As far as I know, torch.Tensor.detach / .as_subclass called
unbound still dispatch through __torch_function__ when the argument is a
Tensor subclass, so _as_plain_fp64 relies on the exact-type gate in
_check_tensor having run first -- TODO confirm, and confirm every caller of
_matrix_l1_norm passes an already-gated tensor.

 problems/linalg/eigh_py/reference.py | 47 ++++++++++++++++++++++++----
 1 file changed, 41 insertions(+), 6 deletions(-)

diff --git a/problems/linalg/eigh_py/reference.py b/problems/linalg/eigh_py/reference.py
index 0cf1cc1a..95b1f025 100644
--- a/problems/linalg/eigh_py/reference.py
+++ b/problems/linalg/eigh_py/reference.py
@@ -13,8 +13,31 @@
 _SORT_RTOL_FACTOR = 100.0
 
 
+def _as_plain_fp64(value: torch.Tensor) -> torch.Tensor:
+    # HARDENING (b): promote to FP64 through a path the output object cannot
+    # override. Calling value.double() / value.to(...) dispatches through the
+    # object's own type, so a torch.Tensor subclass with __torch_function__ can
+    # intercept the promotion and return whatever it likes (e.g. lazily run the
+    # genuine solve here, untimed). Strip to an exact base torch.Tensor FIRST via
+    # as_subclass(torch.Tensor) -- which yields a plain-Tensor view with no
+    # __torch_function__ override -- and only then promote with the unbound base
+    # method torch.Tensor.double, so the residual math operates on data the
+    # submission cannot re-point.
+    #
+    # Every step uses the UNBOUND base-class method (torch.Tensor.detach /
+    # .as_subclass / .double), never the bound form value.detach(): a bound call
+    # dispatches through the object and a plain torch.Tensor still carries a
+    # per-instance __dict__, so an instance attribute override of .detach()/.to()
+    # would otherwise fire here and lazily run the genuine solve untimed even
+    # though type(value) is exactly torch.Tensor. Unbound calls go through the
+    # type and bypass any per-instance override; detach() also drops autograd
+    # wrapping.
+    plain = torch.Tensor.as_subclass(torch.Tensor.detach(value), torch.Tensor)
+    return torch.Tensor.double(plain)
+
+
 def _matrix_l1_norm(value: torch.Tensor) -> torch.Tensor:
-    return torch.linalg.matrix_norm(value.double(), ord=1, dim=(-2, -1))
+    return torch.linalg.matrix_norm(_as_plain_fp64(value), ord=1, dim=(-2, -1))
 
 
 def _property_rtol(n: int, factor: float) -> float:
@@ -276,8 +299,20 @@ def ref_kernel(data: input_t) -> output_t:
 
 
 def _check_tensor(name: str, value: torch.Tensor, shape: tuple[int, ...], device: torch.device) -> str | None:
-    if not isinstance(value, torch.Tensor):
-        return f"{name} must be a torch.Tensor"
+    # HARDENING (a): require an EXACT plain torch.Tensor, not a subclass. A
+    # torch.Tensor subclass can override __torch_function__ / .double() / .to()
+    # and so substitute what the FP64 residual math below sees -- e.g. return
+    # cheap placeholders from custom_kernel (timed) that lazily run the real
+    # solve only when the checker promotes them to FP64 (untimed recheck). The
+    # output contract is a plain FP32 tensor; isinstance() would admit any
+    # subclass, so gate on the exact type.
+    if type(value) is not torch.Tensor:
+        if not isinstance(value, torch.Tensor):
+            return f"{name} must be a torch.Tensor"
+        return (
+            f"{name} must be a plain torch.Tensor, got subclass "
+            f"{type(value).__module__}.{type(value).__qualname__}"
+        )
     if value.shape != shape:
         return f"{name} shape must be {shape}, got {tuple(value.shape)}"
     if value.dtype != torch.float32:
@@ -328,9 +363,9 @@ def check_implementation(data: input_t, output: output_t) -> tuple[bool, str]:
     if not good:
         return False, message
 
-    a_check = a.double()
-    q_check = q.double()
-    values_check = values.double()
+    a_check = _as_plain_fp64(a)
+    q_check = _as_plain_fp64(q)
+    values_check = _as_plain_fp64(values)
     aq = a_check @ q_check
     ql = q_check * values_check.unsqueeze(-2)
     if not torch.isfinite(aq).all().item() or not torch.isfinite(ql).all().item():