Skip to content

eigh_py: reject output-object deferral in the correctness check#161

Open
robobryce wants to merge 1 commit into
gpu-mode:mainfrom
robobryce:eigh-reject-output-deferral
Open

eigh_py: reject output-object deferral in the correctness check#161
robobryce wants to merge 1 commit into
gpu-mode:mainfrom
robobryce:eigh-reject-output-deferral

Conversation

@robobryce

Copy link
Copy Markdown
Contributor

Problem

The benchmark times custom_kernel but the correctness check runs outside the timed region. A submission can exploit that split: return cheap placeholder tensors from custom_kernel (fast, timed) and do the real eigendecomposition lazily inside check_implementation when it promotes the outputs to FP64.

check_implementation promotes with value.double(). That dispatches through the output object, so:

  • a torch.Tensor subclass with __torch_function__ can intercept .double() and run the genuine solve there, and
  • even a plain torch.Tensor carries a per-instance __dict__, so an instance-attribute override of .double() / .detach() does the same without being a subclass.

Confirmed live on the B200 eigh leaderboard: a __torch_function__ deferral was accepted with a fabricated ~17 µs time, all tests passing.

Fix (reference.py)

  1. Require an exact plain tensor. Gate on type(value) is torch.Tensor, not isinstance(...) — the latter admits any subclass.

  2. Promote through an override-proof path (_as_plain_fp64): strip the output to a base-class view and use the unbound torch.Tensor.detach / .as_subclass / .double, never the bound value.detach() / value.double(). Bound calls dispatch through the object (and hit a per-instance override); unbound calls go through the type and bypass it.

The second point matters specifically because a bound value.detach() is itself an interceptable call — a "plain-tensor .detach override" variant survives a fix that strips via value.detach(). Using unbound torch.Tensor.detach(value) closes that.

After this, the residual math always runs on data the submission cannot re-point, so real work can't be deferred out of the timed region. Honest kernels (which return plain FP32 tensors) are unaffected.

A submission can return cheap placeholder tensors from custom_kernel (timed)
and defer the real eigendecomposition into check_implementation (untimed): the
checker promotes the outputs to FP64 with value.double(), and if the output is a
torch.Tensor subclass with __torch_function__ — or even a plain tensor with a
per-instance .double()/.detach() override — that promotion runs the genuine
solve outside the timed region. Confirmed live on the eigh B200 leaderboard
(fabricated ~17 us, all tests passing).

Two changes to reference.py:

1. Require an EXACT plain torch.Tensor output (type(value) is torch.Tensor), not
   merely isinstance() — the latter admits any subclass.

2. Promote to FP64 through an override-proof path (_as_plain_fp64): strip to a
   base-class view and call the UNBOUND torch.Tensor.detach / .as_subclass /
   .double, never the bound value.detach()/.double(). Bound calls dispatch
   through the object and a plain tensor still carries a per-instance __dict__,
   so an instance-attribute override would otherwise fire; unbound calls go
   through the type and bypass it. (The bound .detach() was the gap a padded
   'plain-tensor .detach override' variant used to survive an earlier draft of
   this fix; the unbound form closes it.)

Now the residual math always runs on data the submission cannot re-point, so the
real work cannot be deferred out of the timed region.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants