From 28471e12c4b88f96bd3b6324789949421bdde0f8 Mon Sep 17 00:00:00 2001 From: Sinatras Date: Fri, 1 May 2026 02:05:01 +0300 Subject: [PATCH] fix(detector): catch LAST_CALL_REPLAY via flag/counter-only gate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a fourth pass to detect_last_call_replay covering the case where the saved-output replay is gated entirely on a module-level counter or flag, with no identity comparison against the input — the smoking gun is that the early-return path doesn't reference the input parameter at all. Pattern caught: _OUT = None _N = 0 def custom_kernel(data): global _OUT, _N _N += 1 if _N > 1 and _OUT is not None: return _OUT _OUT = data.clone() return _OUT The new pass requires: a 'global' declaration in the entrypoint, at least one declared global mutated by the body (counter increment or flag toggle), an If whose test references those mutated globals but no parameter names, no calls in the if-body, and a return rooted on a mutated global. All four guards are needed — without them legitimate workspace allocators (if _buf is None: _buf = torch.empty_like(data)) and shape-dispatch helpers would false-positive. Verified locally: catches counter_skip and first_call_init_lock; does not trigger on workspace_alloc_lazy, plain_kernel, shape_dispatch_legit, or global_compile_cache_legit. --- kernelguard.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/kernelguard.py b/kernelguard.py index f086847..78c2fcc 100644 --- a/kernelguard.py +++ b/kernelguard.py @@ -2459,6 +2459,80 @@ def _has_ver(expr: ast.AST | None) -> bool: ), }] + # Quaternary pass: flag/counter-gated replay (no identity comparison). + # Pattern: + # _OUT = None + # _N = 0 + # def custom_kernel(data): + # global _OUT, _N + # _N += 1 + # if _N > 1 and _OUT is not None: + # return _OUT + # _OUT = data.clone() + # return _OUT + # The earlier passes all require an identity-based test that pairs a + # captured slot against the input. This pattern has no such pairing — + # the gate is purely on a module-level counter or flag, and the early + # return drops the input on the floor. + param_names = {a.arg for a in node.args.args} + param_names.update(a.arg for a in getattr(node.args, "posonlyargs", []) or []) + param_names.update(a.arg for a in getattr(node.args, "kwonlyargs", []) or []) + + global_names: set[str] = set() + for n in ast.walk(node): + if isinstance(n, ast.Global): + global_names.update(n.names) + + if global_names: + global_assigned: set[str] = set() + for child in ast.walk(node): + if isinstance(child, ast.Assign): + for target in child.targets: + root = _ast_root_name(target) + if root and root in global_names: + global_assigned.add(root) + elif isinstance(child, ast.AugAssign): + root = _ast_root_name(child.target) + if root and root in global_names: + global_assigned.add(root) + + if global_assigned: + for child in ast.walk(node): + if not isinstance(child, ast.If): + continue + if _body_has_calls(child.body): + continue + + test_names = _expr_names(child.test) + # Test must reference at least one mutated global, and must + # NOT reference any input parameter — that's the smoking gun: + # the early-return decision ignores the input entirely. + if not (test_names & global_assigned): + continue + if test_names & param_names: + continue + + saved_return = False + for stmt in child.body: + if not isinstance(stmt, ast.Return) or stmt.value is None: + continue + rroot = _ast_root_name(stmt.value) + if rroot and rroot in global_assigned: + saved_return = True + break + if not saved_return: + continue + + return [{ + "pattern": "LAST_CALL_REPLAY", + "severity": "critical", + "evidence": ( + f"{entrypoint_name} returns saved output gated only " + f"by a module-level counter/flag — the early-return " + f"path does not reference the input" + ), + }] + return []