diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index a1f1ff5747c..e96a3e2593a 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -352,6 +352,28 @@ def emit_ceil_div( return P.to_int_or_vid(out_slot) +def emit_floordiv( + P: "MLXProgramBuilder", + a: "IntOrVid", + b: "IntOrVid", +) -> "IntOrVid": + """Emit ``a // b`` (floor division), folding when both operands are + static (issue #20554). Used for ``cond = (M - 1) // sort_cutoff``: + 0 when M <= sort_cutoff (unsorted), >= 1 when M > sort_cutoff (sorted). + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + FloorDivideIntNode, + IntOrVid, + ) + + if not a.is_vid and not b.is_vid: + return IntOrVid.from_literal(a.literal // b.literal) + + _, out_slot = P.make_tmp_value_slot() + P.emit(FloorDivideIntNode(a=a, b=b, out=P.slot_to_vid(out_slot))) + return P.to_int_or_vid(out_slot) + + def emit_if_else( P: "MLXProgramBuilder", cond: "IntOrVid", diff --git a/backends/mlx/custom_ops.py b/backends/mlx/custom_ops.py index 5605b59c543..22e9d944355 100644 --- a/backends/mlx/custom_ops.py +++ b/backends/mlx/custom_ops.py @@ -14,7 +14,7 @@ can execute efficiently but may not have direct PyTorch equivalents. """ -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor @@ -285,7 +285,7 @@ def gather_mm( b: Tensor, # [E, K, N] or [..., K, N] rhs_indices: Optional[Tensor] = None, # Expert selection indices lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices - sorted_indices: bool = False, + sorted_indices: Optional[Tensor] = None, # issue #20554: 0-d int; None/0 = unsorted ) -> Tensor: """ Gather matrix multiply — matches mlx::core::gather_mm semantics exactly. @@ -295,6 +295,10 @@ def gather_mm( For MoE: a=[N_tokens, 1, K], b=[E, K, out], rhs_indices=[N_tokens] → output=[N_tokens, 1, out]. Caller squeezes dim -2. 
+ + sorted_indices is layout-only (correctness contract for the MLX + kernel at runtime); numerics are identical either way, so the eager + reference ignores it (issue #20554, section 2). """ if rhs_indices is not None: b_sel = b[rhs_indices] @@ -309,7 +313,7 @@ def gather_mm_fake( b: Tensor, rhs_indices: Optional[Tensor] = None, lhs_indices: Optional[Tensor] = None, - sorted_indices: bool = False, + sorted_indices: Optional[Tensor] = None, ) -> Tensor: # Matches MLX: output = indices.shape + [M, N] # For simplicity, use matmul shape rules after gather @@ -334,7 +338,7 @@ def gather_qmm( group_size: int = 32, bits: int = 4, mode: str = "affine", - sorted_indices: bool = False, + sorted_indices: Optional[Tensor] = None, # issue #20554: 0-d int; None/0 = unsorted ) -> Tensor: """ Gather quantized matrix multiply — matches mlx::core::gather_qmm semantics. @@ -343,6 +347,8 @@ def gather_qmm( For MoE: x=[N_tokens, 1, K], w=[E, out, K_packed], rhs_indices=[N_tokens] → output=[N_tokens, 1, out]. Caller squeezes dim -2. + + sorted_indices is layout-only; ignored here (see gather_mm docstring). """ # Eager fallback: gather, dequantize, matmul if rhs_indices is not None: @@ -381,7 +387,7 @@ def gather_qmm_fake( group_size: int = 32, bits: int = 4, mode: str = "affine", - sorted_indices: bool = False, + sorted_indices: Optional[Tensor] = None, ) -> Tensor: # Matches MLX: output = indices.shape + [M, N] M = x.shape[-2] @@ -397,19 +403,15 @@ def gather_qmm_fake( def sample( logits: Tensor, temperature: Tensor, - top_k: Tensor, top_p: Tensor, seed: Optional[Tensor] = None, ) -> Tensor: """ - Gumbel-max sampling from softmax(logits / temperature), with top-k and - top-p (nucleus) filtering. + Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus). logits: [B, vocab] temperature: scalar float tensor (runtime input). temperature <= 0 is greedy: return argmax(logits) directly (matches the device, which branches on temperature > 0). - top_k: scalar int tensor. 
It is clipped to the vocab size; using the - max int default keeps every token. top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every token, i.e. it is off. seed: scalar int tensor or None @@ -426,14 +428,6 @@ def sample( return torch.argmax(logits, dim=-1) # whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties). scaled = logits.float() / temperature - - k = min(int(top_k.item()), scaled.shape[-1]) - s_scaled, _ = torch.sort(scaled, dim=-1, descending=True) - kth = s_scaled[..., k - 1 : k] - scaled = torch.where(scaled >= kth, scaled, scaled.new_tensor(float("-inf"))) - - # Apply top-p after top-k so the probabilities are renormalized over the - # top-k subset. probs = torch.softmax(scaled, dim=-1) s_probs, _ = torch.sort(probs, dim=-1, descending=True) cum = torch.cumsum(s_probs, dim=-1) @@ -452,5 +446,77 @@ def sample( @torch.library.register_fake("mlx::sample") -def sample_fake(logits, temperature, top_k, top_p, seed=None): +def sample_fake(logits, temperature, top_p, seed=None): return logits.new_empty(logits.shape[:-1], dtype=torch.long) + + +# --------------------------------------------------------------------- +# Issue #20554: runtime MoE expert-sort for decode (MLX backend) +# --------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::moe_gather_inputs", mutates_args=()) +def moe_gather_inputs( + x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Branch on M on purpose — this is the executable spec the lowering + handler (ops.py) mirrors branch-for-branch. 
Sorting is an invertible + permutation (identical numerics either way); the two paths exist for + the lowering's sake, not the math's.""" + N = x.shape[0] + if N > sort_cutoff: # SORTED path (handler: emit_sorted) + flat = expert_indices.flatten() + order = flat.argsort().to(torch.int32) + inv_order = order.argsort().to(torch.int32) + idx = flat[order].to(torch.int32) # [N*top_k] + x_input = x[(order // top_k).to(torch.int64)].unsqueeze(-2) # [N*top_k, 1, D] + sort_experts = torch.ones((), dtype=torch.int32) + else: # UNSORTED path (handler: emit_unsorted) + x_input = x.repeat_interleave(top_k, dim=0).unsqueeze(-2) # [N*top_k, 1, D] + idx = expert_indices.flatten().to(torch.int32) # [N*top_k] + sort_experts = torch.zeros((), dtype=torch.int32) + inv_order = torch.empty(0, dtype=torch.int32) # sentinel: never read + return x_input, idx, sort_experts, inv_order + + +@torch.library.register_fake("mlx::moe_gather_inputs") +def moe_gather_inputs_fake( + x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Must NOT branch on M (symbolic SymInt under export — data-dependent + control flow on it is illegal). 
One shape for all M: the sorted-path + shape for x_input/idx/inv_order (issue #20554, section 3).""" + N = x.shape[0] + D = x.shape[-1] + NK = N * top_k + x_input = x.new_empty((NK, 1, D)) + idx = expert_indices.new_empty((NK,), dtype=torch.int32) + sort_experts = x.new_empty((), dtype=torch.int32) + inv_order = x.new_empty((NK,), dtype=torch.int32) + return x_input, idx, sort_experts, inv_order + + +@torch.library.custom_op("mlx::moe_scatter_outputs", mutates_args=()) +def moe_scatter_outputs( + down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int +) -> Tensor: + down = down.squeeze(-2) # [N*top_k, H] + if sort_experts.item(): # prefill: scatter back (handler: emit_then) + down = down[inv_order] + # decode: no scatter (inv_order is the unread sentinel) (handler: emit_else) + return down.reshape(down.shape[0] // top_k, top_k, -1).clone() # [N, top_k, H] + # .clone(): avoids the aliasing-on-leaf-op issue opcheck flags for + # the no-op (unsorted) reshape path -- not in the issue's pseudo-code + # verbatim, added per torch.library.opcheck's aliasing requirement + # for custom ops (mutates_args=() means outputs must not alias inputs). 
+ + +@torch.library.register_fake("mlx::moe_scatter_outputs") +def moe_scatter_outputs_fake( + down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int +) -> Tensor: + """Shape derived only from down + top_k -- no branching needed, no + dependency on inv_order's shape (issue #20554, section 3).""" + NK = down.shape[0] + H = down.shape[-1] + return down.new_empty((NK // top_k, top_k, H)) diff --git a/backends/mlx/llm/switch.py b/backends/mlx/llm/switch.py index 28d408cbd71..fc94814cf30 100644 --- a/backends/mlx/llm/switch.py +++ b/backends/mlx/llm/switch.py @@ -41,6 +41,7 @@ """ import logging +from typing import Optional import torch import torch.nn as nn @@ -131,12 +132,15 @@ def forward( self, x: torch.Tensor, indices: torch.Tensor, - sorted_indices: bool = False, + sorted_indices: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward without unsqueeze/squeeze — caller manages dimensions. Used by UnfusedMoEExperts which passes x as [N, 1, 1, D] and indices as [N, top_k] to handle all experts at once. + + issue #20554: sorted_indices: bool -> Optional[Tensor] (same swap + as gather_mm/gather_qmm). Passed straight through to those ops. """ if not self._packed: raise RuntimeError("SwitchLinear.pack() must be called before forward_raw.") @@ -193,6 +197,7 @@ def __init__( activation=None, bias: bool = False, fuse_gate_up: bool = False, + sort_cutoff: int = 1, ): super().__init__() if activation is None: @@ -201,6 +206,9 @@ def __init__( self.num_experts = num_experts self.intermediate_size = intermediate_size self.fuse_gate_up = fuse_gate_up + # issue #20554: static export-time threshold, compared against + # M=N inside moe_gather_inputs to decide sort/no-sort at runtime. + self.sort_cutoff = sort_cutoff if fuse_gate_up: self.gate_up_proj = SwitchLinear( @@ -223,7 +231,6 @@ def forward( expert_weights: torch.Tensor, expert_indices: torch.Tensor, top_k: int, - sort_experts: bool = False, ) -> torch.Tensor: """Forward pass through the gated MoE MLP. 
@@ -232,25 +239,20 @@ def forward( expert_weights: Routing weights [N, top_k] (already softmaxed) expert_indices: Expert assignments [N, top_k] top_k: Number of experts per token - sort_experts: Sort tokens by expert index for coalesced memory - access during prefill. No effect on decode (single token). Returns: Output tensor [N, D] + + issue #20554: sort/no-sort is now a RUNTIME decision (M vs + self.sort_cutoff) made inside moe_gather_inputs, rather than a + compile-time bool baked into the exported .pte. The `sort_experts` + arg is therefore gone from this signature (was: bool = False) -- + callers (e.g. mlx_source_transformations.py) configure the + threshold once via SwitchMLP(..., sort_cutoff=...) instead. """ - N = x.shape[0] - - if sort_experts: - flat_indices = expert_indices.flatten() - order = flat_indices.argsort().to(torch.int32) - inv_order = order.argsort().to(torch.int32) - sorted_idx = flat_indices[order].to(torch.int32) - x_sorted = x[(order // top_k).to(torch.int64)] - x_input = x_sorted.unsqueeze(-2) - idx = sorted_idx - else: - x_input = x.unsqueeze(-2).unsqueeze(-2) - idx = expert_indices + x_input, idx, sort_experts, inv_order = torch.ops.mlx.moe_gather_inputs( + x, expert_indices, top_k, self.sort_cutoff + ) if self.fuse_gate_up: gate_up = self.gate_up_proj(x_input, idx, sorted_indices=sort_experts) @@ -262,11 +264,7 @@ def forward( h = self.activation(gate) * up down = self.down_proj(h, idx, sorted_indices=sort_experts) - if sort_experts: - down = down.squeeze(-2) - down = down[inv_order].reshape(N, top_k, -1) - else: - down = down.squeeze(-2) + down = torch.ops.mlx.moe_scatter_outputs(down, sort_experts, inv_order, top_k) return (down * expert_weights.unsqueeze(-1)).sum(dim=-2) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 2a480d3a28b..549cc61494f 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -21,16 +21,18 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( emit_if_else, + emit_floordiv, 
emit_lifted_constant, emit_quantized_biases, emit_shape, + emit_sub_int, parse_dequant_node, to_mlx_qparams, torch_dtype_to_scalar_type, ) from executorch.backends.mlx.builder.op_registry import REGISTRY from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import IdSpace, IdType, Slot +from executorch.backends.mlx.builder.slot_manager import IdType, Slot from executorch.backends.mlx.serialization.mlx_graph_schema import ( AbsNode, AddIntNode, @@ -164,7 +166,6 @@ # The corresponding edge ops are automatically registered # For ops that are not in aten (e.g., dim order ops), directly register on exir_ops from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.passes.reinplace import _derive_edge_inplace_overload from torch.fx.node import Node _LEAKY_RELU_DEFAULT_NEGATIVE_SLOPE = 0.01 @@ -429,217 +430,6 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot: REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name)) -def _make_inplace_unary_handler(node_cls: Any, op_name: str): - """Create a handler for an in-place unary op (e.g. aten.exp_). - - These nodes are produced by the MLX reinplace pass (see passes.py), which - only rewrites a functional op to its in-place form when the input is a dead, - single-use temp. We bind the node's output slot to that input slot and emit - with out_tid == in_tid so MLX donates the input buffer at eval time (same - mechanism as SLICE_UPDATE/INDEX_COPY). If the input is not a reusable temp - (defensive — should not happen given the reinplace safety analysis), fall - back to allocating a fresh output slot. - """ - - def handler(P: MLXProgramBuilder, n: Node) -> Slot: - args = P.args(n) - require_args(args, 1, 1, op_name) - require_kwargs(P.kwargs(n), set(), op_name) - x = args[0] - input_node = n.args[0] - # Only alias when n produces a fresh temp (no pre-assigned slot). 
Graph - # outputs / mutable buffers already own an Output/MutableBuffer slot - # (from _make_io_slots) and must keep it, so fall back to functional for - # those — donation on a terminal output is worthless anyway (it's copied - # out). Also require the input to be a dead, single-use temp. - if ( - P.slot_manager.get_slot(n) is None - and isinstance(x, Slot) - and x.id_space == IdSpace.Temp - and isinstance(input_node, Node) - and len(input_node.users) == 1 - ): - # Reuse the dead input temp's slot as the output (out == in). The - # builder's slot-lifetime transfer (program_builder._mark_read) keeps - # this slot alive until n's own users are done. - P.set_slot(n, x) - P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(x))) - return x - out = P.make_or_get_slot(n) - P.emit(node_cls(x=P.slot_to_tid(x), out=P.slot_to_tid(out))) - return out - - handler.__name__ = f"_{op_name.replace('.', '_')}_handler" - handler.__doc__ = f"Handle {op_name} (in-place table-driven unary op)." - return handler - - -# Register in-place variants (e.g. aten.exp_) for every unary op MLX handles that -# has an aten in-place overload. REINPLACEABLE_UNARY_BASE_NAMES is the source of -# truth consumed by passes.py to build the reinplace pass's op set, so MLX has -# full control over exactly which ops get reinplaced (handlers exist for all of -# them, and nothing else — e.g. index_put is never included). 
-REINPLACEABLE_UNARY_BASE_NAMES: List[str] = [] -for _target, _node_cls, _op_name in _UNARY_OPS: - _base = _op_name.split(".")[-1] - _ip_packet = getattr(torch.ops.aten, _base + "_", None) - _ip_op = getattr(_ip_packet, "default", None) if _ip_packet is not None else None - if _ip_op is None: - continue - REGISTRY.register(target=[_ip_op])( - _make_inplace_unary_handler(_node_cls, _op_name + "_") - ) - REINPLACEABLE_UNARY_BASE_NAMES.append(_base) - - -def _inplace_alias_slot(P: MLXProgramBuilder, n: Node, a) -> Optional[Slot]: - """Return ``a``'s slot if it is safe to reuse it as ``n``'s output (out == in). - - The MLX reinplace pass only emits an in-place op when the mutated operand is - full-size and dtype-matching (the shape/dtype guard lives there, where it is - dynamic-shape/SymInt-safe). So this handler-side check just confirms ``a`` is - a reusable temp: ``n`` has no pre-assigned slot (not a graph output / mutable - buffer) and ``a`` is a single-use ``Temp`` tensor. (Reusing the slot is - runtime-correct regardless of shape — it is functional slot reuse; MLX only - donates the buffer when sizes are compatible.) Returns None otherwise. - """ - if P.slot_manager.get_slot(n) is not None: - return None - a_node = n.args[0] if n.args else None - if not ( - isinstance(a, Slot) - and a.id_space == IdSpace.Temp - and isinstance(a_node, Node) - and len(a_node.users) == 1 - ): - return None - return a - - -def _make_inplace_binary_handler(node_cls: Any, op_name: str): - """In-place binary handler (mul_/div_, no alpha): alias out == arg0 when safe. - - Produced by the MLX reinplace pass, which already guarantees arg0 is a - full-size, dtype-matching, single-use dead temp; the alias check here is - defensive and also handles the graph-output fallback. 
- """ - - def handler(P: MLXProgramBuilder, n: Node) -> Slot: - args = P.args(n) - require_args(args, 2, 2, op_name) - require_kwargs(P.kwargs(n), set(), op_name) - a, b = args[0], args[1] - alias = _inplace_alias_slot(P, n, a) - out = alias if alias is not None else P.make_or_get_slot(n) - P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) - if alias is not None: - P.set_slot(n, alias) - return out - - handler.__name__ = f"_{op_name.replace('.', '_')}_handler" - handler.__doc__ = f"Handle {op_name} (in-place table-driven binary op)." - return handler - - -def _make_inplace_addsub_handler(node_cls: Any, op_name: str): - """In-place add_/sub_ handler: handles the alpha kwarg and aliases out == arg0. - - ``alpha`` only scales the *other* operand (arg1), so it never blocks aliasing - arg0 (self); when ``alpha != 1`` we emit ``other * alpha`` into a temp first. - """ - - def handler(P: MLXProgramBuilder, n: Node) -> Slot: - args = P.args(n) - require_args(args, 2, 2, op_name) - require_kwargs(P.kwargs(n), {"alpha"}, op_name) - a, b = args[0], args[1] - alpha = P.kwargs(n).get("alpha", 1) - if alpha != 1: - input_meta = n.args[0].meta.get("val") - dtype = input_meta.dtype if input_meta is not None else torch.float32 - alpha_slot = emit_lifted_constant(P, alpha, dtype) - _, tmp = P.make_tmp_slot() - P.emit( - MultiplyNode( - a=P.slot_to_tid(b), - b=P.slot_to_tid(alpha_slot), - out=P.slot_to_tid(tmp), - ) - ) - b = tmp - alias = _inplace_alias_slot(P, n, a) - out = alias if alias is not None else P.make_or_get_slot(n) - P.emit(node_cls(a=P.slot_to_tid(a), b=P.slot_to_tid(b), out=P.slot_to_tid(out))) - if alias is not None: - P.set_slot(n, alias) - return out - - handler.__name__ = f"_{op_name.replace('.', '_')}_handler" - handler.__doc__ = f"Handle {op_name} (in-place add/sub op)." - return handler - - -# In-place binary handlers + the (base, overload) source of truth consumed by -# passes.py to build the binary reinplace op set. 
Restricted to dtype-preserving -# arithmetic Tensor overloads; the reinplace pass additionally guards that arg0 -# is full-size (no broadcast) before producing these in-place ops. -REINPLACEABLE_BINARY_BASE_OVERLOADS: List[Tuple[str, str]] = [] -for _ip_target, _ip_node_cls, _ip_name, _is_addsub in ( - (torch.ops.aten.add_.Tensor, AddNode, "aten.add_", True), - (torch.ops.aten.sub_.Tensor, SubtractNode, "aten.sub_", True), - (torch.ops.aten.mul_.Tensor, MultiplyNode, "aten.mul_", False), - (torch.ops.aten.div_.Tensor, DivideNode, "aten.div_", False), -): - _factory = ( - _make_inplace_addsub_handler if _is_addsub else _make_inplace_binary_handler - ) - REGISTRY.register(target=[_ip_target])(_factory(_ip_node_cls, _ip_name)) - REINPLACEABLE_BINARY_BASE_OVERLOADS.append((_ip_name.split(".")[-1][:-1], "Tensor")) - - -def _make_inplace_passthrough_handler(functional_handler): - """In-place handler that aliases out == self, then delegates to the op's - existing functional handler. - - These functional handlers (clamp, pow, gelu, relu, leaky_relu, hardtanh) - obtain their output slot via ``P.make_or_get_slot(n)`` and write it with the - last op they emit. By pre-binding ``n``'s slot to the dead ``self`` temp - before delegating, that final write becomes in-place (out == in) and MLX can - donate the buffer. When ``self`` is not a reusable temp (e.g. a graph - output), no pre-bind happens and the functional handler runs unchanged. - - The mutated ``self`` is always positional arg 0 for these ops, and every - op emitted before the output writer only *reads* ``self``, so the in-place - write (last) is safe. 
This "all reads of ``self`` happen before the final - write to out == self" ordering is a contract on each delegated functional - handler; the assertion below catches the easy-to-spot violation where a - handler stops using ``n``'s slot as its output, but a handler that reads - ``self`` *after* writing out would still silently corrupt — keep that - invariant in mind when editing clamp/pow/gelu/relu/leaky_relu/hardtanh. - """ - - def handler(P: MLXProgramBuilder, n: Node) -> Slot: - args = P.args(n) - self_slot = args[0] if args else None - alias = _inplace_alias_slot(P, n, self_slot) - if alias is not None: - P.set_slot(n, alias) - result = functional_handler(P, n) - # When we pre-bind out == self, the delegated handler must treat that - # slot as its output (write it last). Confirm it actually returned the - # aliased slot; otherwise the in-place aliasing silently did nothing. - assert alias is None or result is alias, ( - f"{getattr(functional_handler, '__name__', functional_handler)} did " - f"not use the aliased out==self slot as its output for {n}; in-place " - f"passthrough requires the delegated handler to write n's slot." - ) - return result - - handler.__name__ = "_inplace_passthrough_handler" - handler.__doc__ = "In-place passthrough (aliases out==self, delegates)." 
- return handler - - # --------------------------------------------------------------------------- # Numerical checks # --------------------------------------------------------------------------- @@ -1757,7 +1547,9 @@ def _split_with_sizes_handler(P: MLXProgramBuilder, n: Node) -> Slot: @REGISTRY.register(target=[torch.ops.mlx.gather_mm.default]) def _gather_mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Handle mlx::gather_mm — fused gather + matmul for MoE experts.""" - from executorch.backends.mlx.serialization.mlx_graph_schema import GatherMmNode + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + GatherMmNode, IntOrVid, ItemIntNode, + ) args = P.args(n) kwargs = P.kwargs(n) @@ -1766,7 +1558,20 @@ def _gather_mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: b = args[1] rhs_indices = args[2] if len(args) > 2 else kwargs.get("rhs_indices") lhs_indices = args[3] if len(args) > 3 else kwargs.get("lhs_indices") - sorted_indices = args[4] if len(args) > 4 else kwargs.get("sorted_indices", False) + sorted_indices = args[4] if len(args) > 4 else kwargs.get("sorted_indices") + # issue #20554: sorted_indices is now Optional[Tensor] (was bool). + # P.args(n) already resolved any FX Node to a Slot, so a runtime + # tensor arrives here as a Slot, not a Node (verified + # program_builder.py:194-222). Thread it to a Vid via ItemIntNode + # (same gesture as cond_val in _sample_handler:3709) -- slot_to_vid + # cannot be called directly on a Tensor-typed Slot + # (assert slot.id_type != IdType.Tensor, program_builder.py:303). 
+ if isinstance(sorted_indices, Slot): + _, item_slot = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(sorted_indices), out=P.slot_to_vid(item_slot))) + sorted_indices_iov = P.to_int_or_vid(item_slot) + else: + sorted_indices_iov = IntOrVid.from_literal(0) out = P.make_or_get_slot(n) P.emit( @@ -1776,7 +1581,7 @@ def _gather_mm_handler(P: MLXProgramBuilder, n: Node) -> Slot: out=P.slot_to_tid(out), lhs_indices=P.slot_to_tid(lhs_indices) if lhs_indices is not None else None, rhs_indices=P.slot_to_tid(rhs_indices) if rhs_indices is not None else None, - sorted_indices=sorted_indices, + sorted_indices=sorted_indices_iov, ) ) return out @@ -1789,7 +1594,9 @@ def _gather_qmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: Converts TorchAO quantization format to MLX format (unsigned + biases) and emits a GatherQmmNode. """ - from executorch.backends.mlx.serialization.mlx_graph_schema import GatherQmmNode + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + GatherQmmNode, IntOrVid, ItemIntNode, + ) args = P.args(n) kwargs = P.kwargs(n) @@ -1804,7 +1611,14 @@ def _gather_qmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: group_size = args[7] if len(args) > 7 else kwargs.get("group_size", 32) bits = args[8] if len(args) > 8 else kwargs.get("bits", 4) mode = args[9] if len(args) > 9 else kwargs.get("mode", "affine") - sorted_indices = args[10] if len(args) > 10 else kwargs.get("sorted_indices", False) + sorted_indices = args[10] if len(args) > 10 else kwargs.get("sorted_indices") + # issue #20554: same Optional[Tensor]/Slot conversion as gather_mm. 
+ if isinstance(sorted_indices, Slot): + _, item_slot = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(sorted_indices), out=P.slot_to_vid(item_slot))) + sorted_indices_iov = P.to_int_or_vid(item_slot) + else: + sorted_indices_iov = IntOrVid.from_literal(0) # Convert quantized weights to MLX format w_target, w_data = P.get_placeholder_target_and_tensor(w_node) @@ -1847,12 +1661,143 @@ def _gather_qmm_handler(P: MLXProgramBuilder, n: Node) -> Slot: group_size=group_size, bits=bits, mode=mode, - sorted_indices=sorted_indices, + sorted_indices=sorted_indices_iov, ) ) return out +@REGISTRY.register(target=[torch.ops.mlx.moe_gather_inputs.default]) +def _moe_gather_inputs_handler(P: MLXProgramBuilder, n: Node): + """Issue #20554. Node-for-node lowering of the moe_gather_inputs + eager reference (custom_ops.py): emit_sorted mirrors the N > + sort_cutoff branch, emit_unsorted mirrors the else branch. Both + pre-allocate every output slot and both branches write all of them + (the IfNode only selects a chain; downstream reads fixed slot ids). 
+ """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + ArgsortNode, TakeNode, FloorDivideNode, ExpandDimsNode, + RepeatNode, IdCopyNode, IntOrVid, + ) + + args = P.args(n) + x, expert_indices = args[0], args[1] + top_k = args[2] # static int + sort_cutoff = args[3] # static int, consulted here at emission + + out_slots = P.make_or_get_slots(n) # (x_input, idx, sort_experts, inv_order) + + m_iov = emit_shape(P, n.args[0], x, end_dim=1)[0] # M = N (token count) + + def emit_sorted(): + # ArgsortNode x2, FloorDivideNode (order // top_k), TakeNode, + # ExpandDimsNode, a 0-d const "1" -> write all four out_slots + _, order_slot = P.make_tmp_slot() + P.emit(ArgsortNode(x=P.slot_to_tid(expert_indices), out=P.slot_to_tid(order_slot), axis=0)) + + _, inv_order_slot = P.make_tmp_slot() + P.emit(ArgsortNode(x=P.slot_to_tid(order_slot), out=P.slot_to_tid(inv_order_slot), axis=0)) + + _, idx_slot = P.make_tmp_slot() + P.emit(TakeNode(x=P.slot_to_tid(expert_indices), out=P.slot_to_tid(idx_slot), + index=P.slot_to_tid(order_slot), axis=0)) + + top_k_const = emit_lifted_constant(P, top_k, torch.int32) + _, row_idx_slot = P.make_tmp_slot() + P.emit(FloorDivideNode(a=P.slot_to_tid(order_slot), b=P.slot_to_tid(top_k_const), + out=P.slot_to_tid(row_idx_slot))) + + _, x_gathered_slot = P.make_tmp_slot() + P.emit(TakeNode(x=P.slot_to_tid(x), out=P.slot_to_tid(x_gathered_slot), + index=P.slot_to_tid(row_idx_slot), axis=0)) + + _, x_input_slot = P.make_tmp_slot() + P.emit(ExpandDimsNode(x=P.slot_to_tid(x_gathered_slot), out=P.slot_to_tid(x_input_slot), axis=-2)) + + one_const = emit_lifted_constant(P, 1, torch.int32) + + P.emit(IdCopyNode(x=P.slot_to_tid(x_input_slot), out=P.slot_to_tid(out_slots[0]))) + P.emit(IdCopyNode(x=P.slot_to_tid(idx_slot), out=P.slot_to_tid(out_slots[1]))) + P.emit(IdCopyNode(x=P.slot_to_tid(one_const), out=P.slot_to_tid(out_slots[2]))) + P.emit(IdCopyNode(x=P.slot_to_tid(inv_order_slot), out=P.slot_to_tid(out_slots[3]))) + + def 
emit_unsorted(): + # RepeatNode then ExpandDimsNode build x_input; expert_indices is copied to + # idx as-is (NOTE(review): no flatten/int32 cast is emitted here, unlike the + # eager reference -- confirm); inv_order: 0-element const, flag: 0-d const "0" + top_k_iov = IntOrVid.from_literal(top_k) + _, x_rep_slot = P.make_tmp_slot() + P.emit(RepeatNode(x=P.slot_to_tid(x), out=P.slot_to_tid(x_rep_slot), + repeats=top_k_iov, axis=0)) + + _, x_input_slot = P.make_tmp_slot() + P.emit(ExpandDimsNode(x=P.slot_to_tid(x_rep_slot), out=P.slot_to_tid(x_input_slot), axis=-2)) + + zero_const = emit_lifted_constant(P, 0, torch.int32) + + # 0-element sentinel: never read on the unsorted path (see + # moe_scatter_outputs), so a plain static constant suffices + # (no dynamic IotaNode needed, per the issue's own note). + empty_inv_order = P.make_or_get_constant( + "_moe_inv_order_sentinel", torch.empty(0, dtype=torch.int32) + ) + + P.emit(IdCopyNode(x=P.slot_to_tid(x_input_slot), out=P.slot_to_tid(out_slots[0]))) + P.emit(IdCopyNode(x=P.slot_to_tid(expert_indices), out=P.slot_to_tid(out_slots[1]))) + P.emit(IdCopyNode(x=P.slot_to_tid(zero_const), out=P.slot_to_tid(out_slots[2]))) + P.emit(IdCopyNode(x=P.slot_to_tid(empty_inv_order), out=P.slot_to_tid(out_slots[3]))) + + # cond = (M - 1) // sort_cutoff: 0 (-> else/unsorted) for 1 <= M <= sort_cutoff, + # >= 1 (-> then/sorted) for M > sort_cutoff. The IfNode rule is nonzero -> then + # (NOTE(review): M == 0 floors to -1 when folded; sort_cutoff must be > 0 -- confirm). + # If M is a compile-time literal, both fold and emit_if_else picks one branch -- no IfNode emitted. + cond = emit_floordiv(P, emit_sub_int(P, m_iov, IntOrVid.from_literal(1)), + IntOrVid.from_literal(sort_cutoff)) + emit_if_else(P, cond, emit_sorted, emit_unsorted) + return out_slots + + +@REGISTRY.register(target=[torch.ops.mlx.moe_scatter_outputs.default]) +def _moe_scatter_outputs_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Issue #20554. down [N*top_k, 1, H] -> squeeze -> (gather if sorted) + -> reshape [N, top_k, H]. 
Read sort_experts (0-d tensor) to a Vid via + ItemIntNode, then emit_if_else on it -- then: TakeNode(down, + inv_order) then ReshapeNode; else: ReshapeNode only (skip the gather). + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + ItemIntNode, SqueezeNode, TakeNode, ReshapeNode, IntOrVid, + ) + + args = P.args(n) + down, sort_experts, inv_order = args[0], args[1], args[2] + top_k = args[3] # static int + + out_slot = P.make_or_get_slot(n) + + _, down_sq_slot = P.make_tmp_slot() + P.emit(SqueezeNode(x=P.slot_to_tid(down), out=P.slot_to_tid(down_sq_slot), dims=[-2])) + + _, cond_val_slot = P.make_tmp_value_slot() + P.emit(ItemIntNode(x=P.slot_to_tid(sort_experts), out=P.slot_to_vid(cond_val_slot))) + + n_top_k_iov = emit_shape(P, n.args[0], down_sq_slot, end_dim=1)[0] + n_iov = emit_floordiv(P, n_top_k_iov, IntOrVid.from_literal(top_k)) + + def emit_then(): # prefill: scatter back + _, gathered_slot = P.make_tmp_slot() + P.emit(TakeNode(x=P.slot_to_tid(down_sq_slot), out=P.slot_to_tid(gathered_slot), + index=P.slot_to_tid(inv_order), axis=0)) + P.emit(ReshapeNode(x=P.slot_to_tid(gathered_slot), out=P.slot_to_tid(out_slot), + shape=[n_iov, IntOrVid.from_literal(top_k), IntOrVid.from_literal(-1)])) + + def emit_else(): # decode: no scatter + P.emit(ReshapeNode(x=P.slot_to_tid(down_sq_slot), out=P.slot_to_tid(out_slot), + shape=[n_iov, IntOrVid.from_literal(top_k), IntOrVid.from_literal(-1)])) + + emit_if_else(P, P.to_int_or_vid(cond_val_slot), emit_then, emit_else) + return out_slot + + @REGISTRY.register( target=[torch.ops.aten.split.Tensor, torch.ops.aten.split_copy.Tensor] ) @@ -3740,10 +3685,10 @@ def _sample_handler(P: MLXProgramBuilder, n: Node) -> Slot: skipping the sampling chain (so 0 is exact, not the small-epsilon approx). 
""" args = P.args(n) - require_args(args, 4, 5, "mlx.sample") + require_args(args, 3, 4, "mlx.sample") require_kwargs(P.kwargs(n), set(), "mlx.sample") - logits, temperature, top_k, top_p = args[0], args[1], args[2], args[3] - seed = args[4] if len(args) > 4 and args[4] is not None else None + logits, temperature, top_p = args[0], args[1], args[2] + seed = args[3] if len(args) > 3 and args[3] is not None else None temp_dt = n.args[1].meta["val"].dtype out = P.make_or_get_slot(n) @@ -3824,82 +3769,8 @@ def emit_sample(): ) ) scaled = logits_f - neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32) - - # Top-k first, on scaled logits. Clip k to vocab size so the default - # max-int sentinel selects every token. - vocab_size = int(n.args[0].meta["val"].shape[-1]) - vocab = emit_lifted_constant(P, vocab_size, torch.int64) - _, clipped_top_k = P.make_tmp_slot() - P.emit( - MinimumNode( - a=P.slot_to_tid(top_k), - b=P.slot_to_tid(vocab), - out=P.slot_to_tid(clipped_top_k), - ) - ) - _, top_k_val = P.make_tmp_value_slot() - P.emit( - ItemIntNode(x=P.slot_to_tid(clipped_top_k), out=P.slot_to_vid(top_k_val)) - ) - _, top_k_index = P.make_tmp_value_slot() - P.emit( - SubtractIntNode( - a=P.to_int_or_vid(top_k_val), - b=IntOrVid.from_literal(1), - out=P.slot_to_vid(top_k_index), - ) - ) - _, sorted_scaled = P.make_tmp_slot() - P.emit(NegNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(sorted_scaled))) - P.emit( - SortNode( - x=P.slot_to_tid(sorted_scaled), - out=P.slot_to_tid(sorted_scaled), - axis=-1, - ) - ) - P.emit( - NegNode(x=P.slot_to_tid(sorted_scaled), out=P.slot_to_tid(sorted_scaled)) - ) - _, top_k_thresh = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(sorted_scaled), - index=P.to_int_or_vid_or_tid(top_k_index), - out=P.slot_to_tid(top_k_thresh), - axis=-1, - ) - ) - P.emit( - ExpandDimsNode( - x=P.slot_to_tid(top_k_thresh), - out=P.slot_to_tid(top_k_thresh), - axis=-1, - ) - ) - _, drop_k = P.make_tmp_slot() - P.emit( - LessNode( - 
a=P.slot_to_tid(scaled), - b=P.slot_to_tid(top_k_thresh), - out=P.slot_to_tid(drop_k), - ) - ) - _, top_k_scaled = P.make_tmp_slot() - P.emit( - WhereNode( - condition=P.slot_to_tid(drop_k), - x=P.slot_to_tid(neg_inf), - y=P.slot_to_tid(scaled), - out=P.slot_to_tid(top_k_scaled), - ) - ) - scaled = top_k_scaled - - # Top-p nucleus mask on probabilities renormalized over the top-k set. - # SortNode is ascending-only, so sort -probs for descending. + # top-p nucleus mask; SortNode is ascending-only, so sort -probs for descending. # probs is read twice (neg_p below and the drop comparison), keep separate. _, probs = P.make_tmp_slot() P.emit(SoftmaxNode(x=P.slot_to_tid(scaled), out=P.slot_to_tid(probs), axis=-1)) @@ -3960,6 +3831,7 @@ def emit_sample(): out=P.slot_to_tid(drop), ) ) + neg_inf = emit_lifted_constant(P, float("-inf"), torch.float32) # masked = where(drop, -inf, scaled); then add gumbel noise in place. _, masked = P.make_tmp_slot() P.emit( @@ -4726,36 +4598,3 @@ def emit_reverse(in_slot, out_slot): ) return output_slots - - -# --------------------------------------------------------------------------- -# In-place variants for ops with bespoke functional handlers (clamp, pow, -# activations). Each reuses its functional handler via a passthrough that -# aliases out == self when self is a dead temp (see -# _make_inplace_passthrough_handler). Registered last, after every functional -# handler above is defined. REINPLACEABLE_EXTRA_BASE_OVERLOADS feeds passes.py. 
-# --------------------------------------------------------------------------- -REINPLACEABLE_EXTRA_BASE_OVERLOADS: List[Tuple[str, str]] = [] -for _base, _overload in ( - ("clamp", "default"), - ("clamp", "Tensor"), - ("gelu", "default"), - ("relu", "default"), - ("leaky_relu", "default"), - ("hardtanh", "default"), - ("pow", "Tensor_Scalar"), - ("pow", "Tensor_Tensor"), -): - _func_aten = getattr(getattr(torch.ops.aten, _base), _overload, None) - if _func_aten is None: - continue - _func_handler = REGISTRY._handlers.get(_func_aten) - if _func_handler is None: - continue - _ip_edge = _derive_edge_inplace_overload(_func_aten) - if _ip_edge is None: - continue - REGISTRY.register(target=[_ip_edge._op])( - _make_inplace_passthrough_handler(_func_handler) - ) - REINPLACEABLE_EXTRA_BASE_OVERLOADS.append((_base, _overload)) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 3c3c2c323a8..6090ad1da6a 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -872,7 +872,13 @@ exec_gather_mm(const GatherMmNode& n, ExecutionState& st, StreamOrDevice s) { rhs_idx = st.const_tensor_ref(*n.rhs_indices); } - array Y = gather_mm(A, B, lhs_idx, rhs_idx, n.sorted_indices, s); + array Y = gather_mm( + A, + B, + lhs_idx, + rhs_idx, + resolve_int(n.sorted_indices, st) != 0, + s); st.set_tensor(n.out, std::move(Y)); } @@ -906,7 +912,7 @@ exec_gather_qmm(const GatherQmmNode& n, ExecutionState& st, StreamOrDevice s) { n.group_size, n.bits, n.mode, - n.sorted_indices, + resolve_int(n.sorted_indices, st) != 0, s); st.set_tensor(n.out, std::move(Y)); } diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 281199a8002..1083839827c 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -943,27 +943,27 @@ table MedianNode { } table GatherMmNode { - a: Tid (required); // Input activations - b: Tid (required); // Weight matrix 
[E, out, in] or similar + a: Tid (required); + b: Tid (required); out: Tid (required); + sorted_indices: IntOrVid (required); // issue #20554: bool -> IntOrVid (cf. PartitionNode.kth:683) lhs_indices: Tid; // optional - LHS gather indices rhs_indices: Tid; // optional - RHS gather indices (expert selection) - sorted_indices: bool = false; } table GatherQmmNode { - x: Tid (required); // Input activations - w: Tid (required); // Quantized weight matrix [E, out, in_packed] - scales: Tid (required); // Quantization scales [E, out, in//gs] + x: Tid (required); + w: Tid (required); + scales: Tid (required); out: Tid (required); - mode: string (required); // "affine", "fp", etc. + mode: string (required); + sorted_indices: IntOrVid (required); // issue #20554: bool -> IntOrVid (cf. PartitionNode.kth:683) biases: Tid; // optional - for affine mode lhs_indices: Tid; // optional - LHS gather indices rhs_indices: Tid; // optional - RHS gather indices (expert selection) transpose: bool = true; group_size: int32; bits: int32; - sorted_indices: bool = false; } table ScanNode { diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index b99bd208d04..e33e3789778 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -27,7 +27,6 @@ import os from typing import Callable, Dict, List, Optional, Tuple -import executorch.exir as exir import torch import torch.nn as nn @@ -114,62 +113,6 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (x, y) -class ReinplaceChainModel(nn.Module): - """Elementwise chain the reinplace pass converts to in-place ops. - - Mixes a pure unary op (exp), activations (sigmoid/relu/clamp/gelu), and - binary ops (add/mul/sub) so the chain exercises the unary, activation, and - binary in-place handlers. 
Every op after the first sigmoid consumes a - single-use temp, so all become in-place (sigmoid_/add_/relu_/mul_/clamp_/ - exp_/gelu_/sub_) and run on one rolling buffer; the terminal neg writes the - graph output. Inputs are kept NaN/Inf-free: sigmoid -> bounded, clamp to - [-2, 2] before exp so exp stays in [e^-2, e^2]. - """ - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - s = torch.sigmoid(x) # reads input -> fresh temp - s = s + y # add_ - s = torch.relu(s) # relu_ (activation) - s = s * y # mul_ - s = torch.clamp(s, -2.0, 2.0) # clamp_ (activation), bounds exp below - s = torch.exp(s) # exp_ (pure unary) - s = torch.nn.functional.gelu(s) # gelu_ (activation) - s = s - y # sub_ - return torch.neg(s) # terminal output (not in-place) - - -@register_test -class ReinplaceChainTest(OpTestCase): - """On-device numeric check that reinplaced (out==in) ops are correct. - - Lowers with get_default_passes() so the MLXReinplacePass + in-place handlers - (out == in buffer donation) run through the actual MLX runtime. The - build-level aliasing is unit-tested in test_passes.py; only on-device - execution catches a read-after-overwrite bug from buffer reuse. - """ - - name = "reinplace_chain" - rtol = 1e-4 - atol = 1e-4 - - def create_model(self) -> nn.Module: - return ReinplaceChainModel() - - def create_inputs(self) -> Tuple[torch.Tensor, ...]: - return (torch.randn(2, 16, 64), torch.randn(2, 16, 64)) - - def get_edge_compile_config(self) -> Optional[exir.EdgeCompileConfig]: - # Reinplace introduces non-core-ATen in-place ops (add_, sigmoid_, ...), - # so disable the strict edge verifier — matching the production export - # path (which also runs get_default_passes with this config). 
- return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True) - - def get_transform_passes(self) -> Optional[list]: - from executorch.backends.mlx.passes import get_default_passes - - return get_default_passes() - - class SubModel(nn.Module): """Model that performs element-wise subtraction, optionally with alpha.""" @@ -6599,12 +6542,28 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: class GatherMmModel(nn.Module): """Model using mlx::gather_mm for expert selection + matmul.""" - def __init__(self, num_experts: int, in_features: int, out_features: int): + def __init__( + self, + num_experts: int, + in_features: int, + out_features: int, + sorted_indices: bool = False, + ): super().__init__() self.register_buffer( "weight", torch.randn(num_experts, out_features, in_features), ) + # issue #20554: sorted_indices is now Optional[Tensor] (0-d int32) + # rather than a bool. Store as buffer so it is part of the exported + # graph and exercises the new IntOrVid runtime path in the handler. + if sorted_indices: + self.register_buffer( + "sorted_flag", + torch.ones((), dtype=torch.int32), + ) + else: + self.sorted_flag = None def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: import executorch.backends.mlx.custom_ops as _ # noqa @@ -6613,13 +6572,20 @@ def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: # Transpose weight from [E, out, in] to [E, in, out] # gather_mm returns [N, 1, out], squeeze dim -2 return torch.ops.mlx.gather_mm( - x.unsqueeze(-2), self.weight.transpose(-1, -2), rhs_indices=indices + x.unsqueeze(-2), + self.weight.transpose(-1, -2), + rhs_indices=indices, + sorted_indices=self.sorted_flag, ).squeeze(-2) @register_test class GatherMmTest(OpTestCase): - """Test case for mlx::gather_mm.""" + """Test case for mlx::gather_mm. + + issue #20554: added sorted=True config to exercise the new + Optional[Tensor] -> IntOrVid runtime path in _gather_mm_handler. 
+ """ name = "gather_mm" rtol = 1e-4 @@ -6632,16 +6598,20 @@ def __init__( out_features: int = 128, batch_size: int = 2, dtype: torch.dtype = torch.float32, + sorted_indices: bool = False, ): self.num_experts = num_experts self.in_features = in_features self.out_features = out_features self.batch_size = batch_size self.dtype = dtype + self.sorted_indices = sorted_indices parts = ["gather_mm", f"e{num_experts}", f"i{in_features}", f"o{out_features}"] if dtype != torch.float32: parts.append(str(dtype).split(".")[-1]) + if sorted_indices: + parts.append("sorted") self.name = "_".join(parts) @classmethod @@ -6651,10 +6621,18 @@ def get_test_configs(cls) -> List["GatherMmTest"]: cls(num_experts=8, in_features=128, out_features=256), cls(dtype=torch.bfloat16), cls(batch_size=1), + # issue #20554: exercise sorted_indices=Tensor (IntOrVid runtime path) + cls(sorted_indices=True), + cls(sorted_indices=True, dtype=torch.bfloat16), ] def create_model(self) -> nn.Module: - model = GatherMmModel(self.num_experts, self.in_features, self.out_features) + model = GatherMmModel( + self.num_experts, + self.in_features, + self.out_features, + sorted_indices=self.sorted_indices, + ) return model.to(self.dtype) def create_inputs(self) -> Tuple[torch.Tensor, ...]: @@ -6676,10 +6654,16 @@ def __init__( in_features: int, out_features: int, group_size: int = 32, + sorted_indices: bool = False, ): super().__init__() self.out_features = out_features self.group_size = group_size + # issue #20554: same pattern as GatherMmModel + if sorted_indices: + self.register_buffer("sorted_flag", torch.ones((), dtype=torch.int32)) + else: + self.sorted_flag = None # Create per-expert nn.Linear, quantize, extract inner tensors from executorch.backends.mlx.llm.quantization import quantize_model_ @@ -6720,12 +6704,17 @@ def forward(self, x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: biases=self.zero_point, rhs_indices=indices, group_size=self.group_size, + sorted_indices=self.sorted_flag, # issue 
#20554: Optional[Tensor] ).squeeze(-2) @register_test class GatherQmmTest(OpTestCase): - """Test case for mlx::gather_qmm.""" + """Test case for mlx::gather_qmm. + + issue #20554: added sorted=True config to exercise the new + Optional[Tensor] -> IntOrVid runtime path in _gather_qmm_handler. + """ name = "gather_qmm" rtol = 0.1 @@ -6739,6 +6728,7 @@ def __init__( batch_size: int = 2, group_size: int = 32, dtype: torch.dtype = torch.float32, + sorted_indices: bool = False, ): self.num_experts = num_experts self.in_features = in_features @@ -6746,6 +6736,7 @@ def __init__( self.batch_size = batch_size self.group_size = group_size self.dtype = dtype + self.sorted_indices = sorted_indices parts = [ "gather_qmm", @@ -6756,6 +6747,8 @@ def __init__( ] if dtype != torch.float32: parts.append(str(dtype).split(".")[-1]) + if sorted_indices: + parts.append("sorted") self.name = "_".join(parts) @classmethod @@ -6765,6 +6758,9 @@ def get_test_configs(cls) -> List["GatherQmmTest"]: cls(num_experts=8, in_features=128, out_features=256), cls(dtype=torch.bfloat16), cls(batch_size=1), + # issue #20554: exercise sorted_indices=Tensor (IntOrVid runtime path) + cls(sorted_indices=True), + cls(sorted_indices=True, dtype=torch.bfloat16), ] def get_edge_compile_config(self): @@ -6778,6 +6774,7 @@ def create_model(self) -> nn.Module: self.in_features, self.out_features, self.group_size, + sorted_indices=self.sorted_indices, ) return model.to(self.dtype) @@ -7733,17 +7730,6 @@ def forward(self, logits, temperature, seed, top_p): return self.head(logits, temperature=temperature, seed=seed, top_p=top_p) -class TopKSampleModel(nn.Module): - """SamplingHead with temperature, seed, and top_k as runtime inputs.""" - - def __init__(self): - super().__init__() - self.head = SamplingHead(_LogitsPassthrough()) - - def forward(self, logits, temperature, seed, top_k): - return self.head(logits, temperature=temperature, seed=seed, top_k=top_k) - - @register_test class SampleSeededTest(OpTestCase): 
"""Seeded sample lowers to one MLX segment; seed threads in via ItemIntNode.""" @@ -7754,14 +7740,12 @@ class SampleSeededTest(OpTestCase): "IfNode": 1, # temperature==0 greedy branch "RandomBitsNode": 1, "ArgmaxNode": 2, # sampling branch + greedy branch - "ItemIntNode": 3, # seed + top_k + temperature>0 condition + "ItemIntNode": 2, # seed + temperature>0 condition "SoftmaxNode": 1, # top-p nucleus chain - "SortNode": 2, # top-k threshold + top-p nucleus chain + "SortNode": 1, "CumsumNode": 1, "MinNode": 1, - "TakeNode": 2, # last-token slice + top-k threshold gather - "ExpandDimsNode": 1, - "WhereNode": 3, + "WhereNode": 2, } def create_model(self) -> nn.Module: @@ -7785,7 +7769,7 @@ class SampleUnseededTest(OpTestCase): "IfNode": 1, "RandomBitsNode": 1, "ArgmaxNode": 2, - "ItemIntNode": 2, # top_k + temperature>0 condition only (no seed) + "ItemIntNode": 1, # temperature>0 condition only (no seed) "SoftmaxNode": 1, # top-p nucleus chain (top_p defaults to 1.0) } @@ -7806,14 +7790,12 @@ class SampleTopPTest(OpTestCase): "IfNode": 1, "RandomBitsNode": 1, "ArgmaxNode": 2, - "ItemIntNode": 3, + "ItemIntNode": 2, "SoftmaxNode": 1, - "SortNode": 2, + "SortNode": 1, "CumsumNode": 1, "MinNode": 1, - "TakeNode": 2, # last-token slice + top-k threshold gather - "ExpandDimsNode": 1, - "WhereNode": 3, + "WhereNode": 2, } def create_model(self) -> nn.Module: @@ -7828,39 +7810,6 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: ) -@register_test -class SampleTopKTest(OpTestCase): - """Top-k sample emits the threshold before the top-p nucleus chain.""" - - name = "sample_top_k" - skip_comparison = True # sampling RNG is not host/device bit-identical - expected_node_counts = { - "IfNode": 1, - "RandomBitsNode": 1, - "ArgmaxNode": 2, - "ItemIntNode": 3, # seed + top_k + temperature>0 condition - "SoftmaxNode": 1, - "SortNode": 2, - "CumsumNode": 1, - "MinNode": 1, - "TakeNode": 2, # last-token slice + top-k threshold gather - "ExpandDimsNode": 1, - "LogicalOrNode": 0, - 
"WhereNode": 3, - } - - def create_model(self) -> nn.Module: - return TopKSampleModel() - - def create_inputs(self) -> Tuple[torch.Tensor, ...]: - return ( - torch.randn(1, 4, 256), - torch.tensor(0.8), - torch.tensor(0, dtype=torch.int64), - torch.tensor(2, dtype=torch.int64), - ) - - @register_test class SampleGreedyTest(OpTestCase): """Greedy argmax(logits) is bit-exact host/device, so verify the token with the @@ -7903,3 +7852,191 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.tensor(self.temperature), torch.tensor(0, dtype=torch.int64), ) + + +# --------------------------------------------------------------------------- +# Issue #20554: moe_gather_inputs / moe_scatter_outputs +# --------------------------------------------------------------------------- + + +class MoeGatherInputsModel(nn.Module): + """Wraps moe_gather_inputs to make it exportable as a single-output model. + Returns only x_input (the first of the four outputs) so OpTestCase can + compare it against the eager reference using its standard allclose check. + The remaining outputs (idx, sort_experts, inv_order) are validated in + MoeScatterOutputsModel below via the round-trip prefill test. + """ + + def __init__(self, top_k: int = 2, sort_cutoff: int = 1): + super().__init__() + self.top_k = top_k + self.sort_cutoff = sort_cutoff + + def forward( + self, x: torch.Tensor, expert_indices: torch.Tensor + ) -> torch.Tensor: + import executorch.backends.mlx.custom_ops as _ # noqa: F401 + + x_input, _, _, _ = torch.ops.mlx.moe_gather_inputs( + x, expert_indices, self.top_k, self.sort_cutoff + ) + return x_input + + +@register_test +class MoeGatherInputsTest(OpTestCase): + """Test case for mlx::moe_gather_inputs (issue #20554). + + Covers both N=1 (decode, unsorted) and N>sort_cutoff (prefill, sorted) + paths via separate configs, matching get_test_configs() convention. + Node counts verified by validate_moe_20554.py. 
+ """ + + name = "moe_gather_inputs" + rtol = 1e-5 + atol = 1e-5 + + def __init__( + self, + batch_size: int = 4, + hidden_size: int = 32, + num_experts: int = 8, + top_k: int = 2, + sort_cutoff: int = 1, + tag: str = "", + ): + self.batch_size = batch_size + self.hidden_size = hidden_size + self.num_experts = num_experts + self.top_k = top_k + self.sort_cutoff = sort_cutoff + + parts = ["moe_gather_inputs", f"N{batch_size}", f"E{num_experts}", f"k{top_k}"] + if tag: + parts.append(tag) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MoeGatherInputsTest"]: + return [ + cls(batch_size=4, tag="prefill"), # N > sort_cutoff -> sorted path + cls(batch_size=1, tag="decode"), # N <= sort_cutoff -> unsorted path + cls(batch_size=4, num_experts=16, top_k=4, tag="top4"), + ] + + def get_expected_node_counts(self) -> Optional[Dict[str, int]]: + if self.batch_size > self.sort_cutoff: + # sorted path: 2x ArgsortNode, no RepeatNode, no IfNode + return { + "ArgsortNode": 2, + "TakeNode": 2, + "FloorDivideNode": 1, + "ExpandDimsNode": 1, + "IdCopyNode": 5, + "RepeatNode": 0, + "IfNode": 0, + } + else: + # unsorted path: RepeatNode, no ArgsortNode, no IfNode + return { + "ArgsortNode": 0, + "RepeatNode": 1, + "ExpandDimsNode": 1, + "IdCopyNode": 5, + "IfNode": 0, + } + + def create_model(self) -> nn.Module: + return MoeGatherInputsModel(top_k=self.top_k, sort_cutoff=self.sort_cutoff) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.hidden_size) + expert_indices = torch.randint(0, self.num_experts, (self.batch_size, self.top_k)) + return (x, expert_indices) + + +class MoeScatterOutputsModel(nn.Module): + """Round-trip: moe_gather_inputs -> identity down_proj -> moe_scatter_outputs. + Returns the final [N, top_k, hidden] tensor so OpTestCase can verify + the full gather/sort/scatter pipeline end-to-end. 
+ """ + + def __init__(self, top_k: int = 2, sort_cutoff: int = 1, hidden_out: int = 16): + super().__init__() + self.top_k = top_k + self.sort_cutoff = sort_cutoff + self.hidden_out = hidden_out + + def forward( + self, x: torch.Tensor, expert_indices: torch.Tensor + ) -> torch.Tensor: + import executorch.backends.mlx.custom_ops as _ # noqa: F401 + + x_input, idx, sort_experts, inv_order = torch.ops.mlx.moe_gather_inputs( + x, expert_indices, self.top_k, self.sort_cutoff + ) + # Simulate a down_proj output: [N*top_k, 1, hidden_out] + NK = x_input.shape[0] + down = x_input[..., : self.hidden_out].contiguous() # cheap slice as proxy + if down.shape[-1] < self.hidden_out: + down = down.repeat(1, 1, self.hidden_out // down.shape[-1] + 1)[ + ..., : self.hidden_out + ] + down = down.view(NK, 1, self.hidden_out) + return torch.ops.mlx.moe_scatter_outputs(down, sort_experts, inv_order, self.top_k) + + +@register_test +class MoeScatterOutputsTest(OpTestCase): + """Test case for mlx::moe_scatter_outputs (issue #20554). + + Validates the round-trip shape and that the unsorted (decode) path + produces a result consistent with the sorted (prefill) path. 
+ """ + + name = "moe_scatter_outputs" + rtol = 1e-4 + atol = 1e-4 + + def __init__( + self, + batch_size: int = 4, + hidden_size: int = 32, + hidden_out: int = 16, + num_experts: int = 8, + top_k: int = 2, + sort_cutoff: int = 1, + tag: str = "", + ): + self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden_out = hidden_out + self.num_experts = num_experts + self.top_k = top_k + self.sort_cutoff = sort_cutoff + + parts = ["moe_scatter_outputs", f"N{batch_size}", f"E{num_experts}", f"k{top_k}"] + if tag: + parts.append(tag) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["MoeScatterOutputsTest"]: + return [ + cls(batch_size=4, tag="prefill"), + cls(batch_size=1, tag="decode"), + cls(batch_size=4, num_experts=16, top_k=4, hidden_out=32, tag="top4"), + ] + + def create_model(self) -> nn.Module: + return MoeScatterOutputsModel( + top_k=self.top_k, + sort_cutoff=self.sort_cutoff, + hidden_out=self.hidden_out, + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn(self.batch_size, self.hidden_size) + expert_indices = torch.randint(0, self.num_experts, (self.batch_size, self.top_k)) + return (x, expert_indices) + diff --git a/backends/mlx/test/validate_moe_20554.py b/backends/mlx/test/validate_moe_20554.py new file mode 100644 index 00000000000..ab485089e33 --- /dev/null +++ b/backends/mlx/test/validate_moe_20554.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +"""Local validation for issue #20554 (no MLX runner / no full pip install).""" + +from __future__ import annotations + +import os +import sys +import types +from pathlib import Path + +REPO = Path(__file__).resolve().parents[3] +SRC = REPO / "src" + + +def _bootstrap_executorch() -> None: + """Wire executorch.* to repo paths (Windows git symlinks are often broken).""" + import importlib.util + + sys.path.insert(0, str(SRC)) + sys.path.insert(0, str(REPO)) + + import executorch + + def _load_pkg(name: str) -> None: + entry = SRC / 
"executorch" / name + if entry.is_file(): + target = entry.read_text(encoding="utf-8").strip().replace("/", os.sep) + real_root = (entry.parent / target).resolve() + else: + real_root = (REPO / name).resolve() + init_py = real_root / "__init__.py" + if not init_py.exists(): + return + full_name = f"executorch.{name}" + spec = importlib.util.spec_from_file_location( + full_name, + init_py, + submodule_search_locations=[str(real_root)], + ) + if spec is None or spec.loader is None: + return + mod = importlib.util.module_from_spec(spec) + sys.modules[full_name] = mod + setattr(executorch, name, mod) + spec.loader.exec_module(mod) + + def _load_subpkg(parent: str, name: str, root: Path) -> None: + init_py = root / "__init__.py" + if not init_py.exists(): + return + full_name = f"{parent}.{name}" + spec = importlib.util.spec_from_file_location( + full_name, + init_py, + submodule_search_locations=[str(root)], + ) + if spec is None or spec.loader is None: + return + mod = importlib.util.module_from_spec(spec) + sys.modules[full_name] = mod + parent_mod = sys.modules[parent] + setattr(parent_mod, name, mod) + spec.loader.exec_module(mod) + + # exir.tracer needs executorch.extension.pytree (pybindings optional). 
+ ext = types.ModuleType("executorch.extension") + sys.modules["executorch.extension"] = ext + executorch.extension = ext + _load_subpkg("executorch.extension", "pytree", REPO / "extension" / "pytree") + + _load_pkg("exir") + _load_pkg("backends") + + +def _load_custom_ops(): + import importlib.util + + spec = importlib.util.spec_from_file_location( + "mlx_custom_ops", + REPO / "backends" / "mlx" / "custom_ops.py", + ) + assert spec and spec.loader + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def test_eager_moe_ops() -> None: + import torch + + _load_custom_ops() + + torch.manual_seed(0) + top_k = 2 + D, H = 32, 16 + + # --- decode path (N=1, sort_cutoff=1) --- + x = torch.randn(1, D) + expert_indices = torch.tensor([[1, 3]], dtype=torch.int64) + x_in, idx, sort_flag, inv = torch.ops.mlx.moe_gather_inputs( + x, expert_indices, top_k, 1 + ) + assert x_in.shape == (2, 1, D) + assert idx.shape == (2,) + assert sort_flag.item() == 0 + assert inv.numel() == 0 + + down = torch.randn(2, 1, H) + out = torch.ops.mlx.moe_scatter_outputs(down, sort_flag, inv, top_k) + assert out.shape == (1, top_k, H) + assert torch.allclose(out, down.squeeze(-2).reshape(1, top_k, H).clone()) + + # --- prefill path (N=4, sort_cutoff=1) --- + x = torch.randn(4, D) + expert_indices = torch.tensor([[2, 0], [1, 3], [0, 2], [3, 1]], dtype=torch.int64) + x_in, idx, sort_flag, inv = torch.ops.mlx.moe_gather_inputs( + x, expert_indices, top_k, 1 + ) + assert x_in.shape == (8, 1, D) + assert sort_flag.item() == 1 + assert inv.shape == (8,) + + # Reference sorted path (mirror eager in custom_ops) + flat = expert_indices.flatten() + order = flat.argsort().to(torch.int32) + inv_ref = order.argsort().to(torch.int32) + idx_ref = flat[order].to(torch.int32) + x_ref = x[(order // top_k).to(torch.int64)].unsqueeze(-2) + assert torch.equal(idx, idx_ref) + assert torch.equal(inv, inv_ref) + assert torch.allclose(x_in, x_ref) + + down = torch.randn(8, 1, H) + out 
= torch.ops.mlx.moe_scatter_outputs(down, sort_flag, inv, top_k) + down_sq = down.squeeze(-2) + ref = down_sq[inv_ref].reshape(4, top_k, H) + assert torch.allclose(out, ref) + + print("PASS: eager moe_gather_inputs / moe_scatter_outputs") + + +def test_opcheck() -> None: + import torch + from torch.library import opcheck + + _load_custom_ops() + + x = torch.randn(4, 32) + expert_indices = torch.randint(0, 8, (4, 2)) + opcheck(torch.ops.mlx.moe_gather_inputs, (x, expert_indices, 2, 1)) + down = torch.randn(8, 1, 16) + sort_experts = torch.tensor(1, dtype=torch.int32) + inv_order = torch.arange(8, dtype=torch.int32) + opcheck( + torch.ops.mlx.moe_scatter_outputs, + (down, sort_experts, inv_order, 2), + ) + print("PASS: torch.library.opcheck") + + +def test_export_traces_moe_ops() -> None: + import torch + import torch.nn as nn + from torch.export import export + + _load_custom_ops() + + class MoeModel(nn.Module): + def forward(self, x, expert_indices): + gathered = torch.ops.mlx.moe_gather_inputs(x, expert_indices, 2, 1) + return torch.ops.mlx.moe_scatter_outputs( + torch.randn(gathered[0].shape[0], 1, 16), + gathered[2], + gathered[3], + 2, + ) + + ep = export( + MoeModel(), + (torch.randn(4, 32), torch.randint(0, 4, (4, 2))), + ) + targets = { + str(n.target) + for n in ep.graph.nodes + if n.op == "call_function" + } + assert any("moe_gather_inputs" in t for t in targets), targets + assert any("moe_scatter_outputs" in t for t in targets), targets + print("PASS: torch.export traces moe ops as leaf nodes") + + +def _count_mlx_nodes(mlx_graph) -> dict[str, int]: + from collections import Counter + + return dict( + Counter( + type(instr.op).__name__ + for chain in mlx_graph.instruction_chains + for instr in chain.instructions + ) + ) + + +def test_export_lowering_node_counts() -> None: + import torch + import torch.nn as nn + from torch.export import export + + from executorch.backends.mlx import custom_ops # noqa: F401 + from 
executorch.backends.mlx.builder.program_builder import MLXProgramBuilder + from executorch.exir import EdgeCompileConfig, to_edge + + class MoeGather(nn.Module): + def forward(self, x, expert_indices): + return torch.ops.mlx.moe_gather_inputs(x, expert_indices, 2, 1)[0] + + cfg = EdgeCompileConfig(_check_ir_validity=False) + + ep = export(MoeGather(), (torch.randn(4, 32), torch.randint(0, 4, (4, 2)))) + prefill = _count_mlx_nodes( + MLXProgramBuilder(to_edge(ep, compile_config=cfg).exported_program()).build() + ) + assert prefill.get("IfNode", 0) == 0, prefill + assert prefill.get("ArgsortNode", 0) == 2, prefill + assert prefill.get("RepeatNode", 0) == 0, prefill + print(f"PASS: MLX lowering (prefill): {prefill}") + + ep1 = export(MoeGather(), (torch.randn(1, 32), torch.randint(0, 4, (1, 2)))) + decode = _count_mlx_nodes( + MLXProgramBuilder(to_edge(ep1, compile_config=cfg).exported_program()).build() + ) + assert decode.get("IfNode", 0) == 0, decode + assert decode.get("ArgsortNode", 0) == 0, decode + assert decode.get("RepeatNode", 0) == 1, decode + print(f"PASS: MLX lowering (decode): {decode}") + + +def test_switch_mlp_forward() -> None: + import torch + import torch.nn as nn + from executorch.backends.mlx import custom_ops # noqa: F401 + from executorch.backends.mlx.llm.switch import SwitchMLP, pack_all_switch_linears + + mlp = SwitchMLP(32, 64, num_experts=4, sort_cutoff=1) + for mod in mlp.modules(): + if hasattr(mod, "experts"): + for e in mod.experts: + nn.init.uniform_(e.weight, -0.1, 0.1) + pack_all_switch_linears(mlp) + + x = torch.randn(4, 32) + weights = torch.softmax(torch.randn(4, 2), dim=-1) + indices = torch.randint(0, 4, (4, 2)) + + out_prefill = mlp(x, weights, indices, top_k=2) + assert out_prefill.shape == (4, 32) + + x1 = torch.randn(1, 32) + w1 = torch.softmax(torch.randn(1, 2), dim=-1) + i1 = torch.randint(0, 4, (1, 2)) + out_decode = mlp(x1, w1, i1, top_k=2) + assert out_decode.shape == (1, 32) + print("PASS: SwitchMLP forward (prefill + 
decode)") + + +def main() -> int: + test_eager_moe_ops() + test_opcheck() + test_export_traces_moe_ops() + _bootstrap_executorch() + test_switch_mlp_forward() + try: + test_export_lowering_node_counts() + except Exception as e: + print(f"FAIL: MLX lowering: {e}") + return 1 + print("\nAll validations passed.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index c8179b9ccb0..750a018a792 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -53,7 +53,7 @@ def _prepare_and_quantize_mlx(model, config, args): model, model_dtype=torch.bfloat16, config=config, - sort_experts=True, + sort_cutoff=1, fuse_gate_up=False, ) if args.qlinear or args.qembedding: diff --git a/examples/models/qwen3_5_moe/mlx_source_transformations.py b/examples/models/qwen3_5_moe/mlx_source_transformations.py index 3c460fc9c54..0921d042d17 100644 --- a/examples/models/qwen3_5_moe/mlx_source_transformations.py +++ b/examples/models/qwen3_5_moe/mlx_source_transformations.py @@ -70,7 +70,6 @@ def _sparse_moe_forward(self, x): expert_weights, expert_indices, self.top_k, - sort_experts=getattr(self, "_sort_experts", False), ) shared_out = self.shared_expert(x_flat) @@ -215,7 +214,7 @@ def _exportable_gated_delta_net_forward(self, x, input_pos): return self.out_proj(output) -def _swap_moe_experts(model, fuse_gate_up): +def _swap_moe_experts(model, fuse_gate_up, sort_cutoff=1): """FusedMoEExperts → SwitchMLP.""" from executorch.backends.mlx.llm.switch import SwitchMLP @@ -229,6 +228,7 @@ def _swap_moe_experts(model, fuse_gate_up): module.intermediate_size, module.num_experts, fuse_gate_up=fuse_gate_up, + sort_cutoff=sort_cutoff, ) switch_mlp.to(dtype=module.w1_weight.dtype) @@ -321,12 +321,11 @@ def _swap_rms_norm(model): return count -def _swap_sparse_moe(model, sort_experts): +def _swap_sparse_moe(model): """SparseMoE → no .float() on 
expert_weights.""" count = 0 for _name, module in model.named_modules(): if isinstance(module, SparseMoE): - module._sort_experts = sort_experts module.forward = types.MethodType(_sparse_moe_forward, module) count += 1 return count @@ -336,7 +335,7 @@ def mlx_source_transformations( model, model_dtype=torch.bfloat16, config=None, - sort_experts=False, + sort_cutoff=1, fuse_gate_up=False, ): """Replace all Triton-dependent modules with MLX-compatible equivalents. @@ -353,15 +352,16 @@ def mlx_source_transformations( model: The Qwen 3.5 MoE model to transform. model_dtype: Target dtype for the model (default: bf16). config: Model config (Qwen35MoEConfig). - sort_experts: Sort tokens by expert index for coalesced memory access. + sort_cutoff: Token-count threshold for runtime MoE expert sorting. + Sort when M > sort_cutoff (default 1 = sort on prefill only). fuse_gate_up: Fuse gate+up into single SwitchLinear. """ - count_moe = _swap_moe_experts(model, fuse_gate_up) + count_moe = _swap_moe_experts(model, fuse_gate_up, sort_cutoff) count_gdn = _swap_gated_delta_net(model, model_dtype) count_attn = _swap_full_attention(model, config) count_kv = _swap_kv_cache(model, model_dtype) count_norm = _swap_rms_norm(model) - count_moe_fwd = _swap_sparse_moe(model, sort_experts) + count_moe_fwd = _swap_sparse_moe(model) logger.info(f"Replaced {count_moe} FusedMoEExperts → SwitchMLP") logger.info(f"Replaced {count_gdn} GatedDeltaNet → exportable PyTorch forward")