Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions backends/mlx/builder/op_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,28 @@ def emit_ceil_div(
return P.to_int_or_vid(out_slot)


def emit_floordiv(
    P: "MLXProgramBuilder",
    a: "IntOrVid",
    b: "IntOrVid",
) -> "IntOrVid":
    """Emit the integer floor division ``a // b`` into the program.

    When neither operand is a runtime value id the quotient is folded right
    away and returned as a literal, so no node is emitted. Otherwise a
    ``FloorDivideIntNode`` that writes to a fresh temporary slot is emitted
    and that slot is returned.

    Added for issue #20554, where it computes
    ``cond = (M - 1) // sort_cutoff``: 0 when ``M <= sort_cutoff`` (unsorted)
    and >= 1 when ``M > sort_cutoff`` (sorted).
    """
    from executorch.backends.mlx.serialization.mlx_graph_schema import (
        FloorDivideIntNode,
        IntOrVid,
    )

    if a.is_vid or b.is_vid:
        # At least one operand is only known at runtime: emit a real node.
        _, tmp_slot = P.make_tmp_value_slot()
        node = FloorDivideIntNode(a=a, b=b, out=P.slot_to_vid(tmp_slot))
        P.emit(node)
        return P.to_int_or_vid(tmp_slot)

    # Both operands are static literals: constant-fold at build time.
    return IntOrVid.from_literal(a.literal // b.literal)


def emit_if_else(
P: "MLXProgramBuilder",
cond: "IntOrVid",
Expand Down
104 changes: 85 additions & 19 deletions backends/mlx/custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
can execute efficiently but may not have direct PyTorch equivalents.
"""

from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
Expand Down Expand Up @@ -285,7 +285,7 @@ def gather_mm(
b: Tensor, # [E, K, N] or [..., K, N]
rhs_indices: Optional[Tensor] = None, # Expert selection indices
lhs_indices: Optional[Tensor] = None, # Optional LHS gather indices
sorted_indices: bool = False,
sorted_indices: Optional[Tensor] = None, # issue #20554: 0-d int; None/0 = unsorted
) -> Tensor:
"""
Gather matrix multiply — matches mlx::core::gather_mm semantics exactly.
Expand All @@ -295,6 +295,10 @@ def gather_mm(

For MoE: a=[N_tokens, 1, K], b=[E, K, out], rhs_indices=[N_tokens]
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.

sorted_indices is layout-only (correctness contract for the MLX
kernel at runtime); numerics are identical either way, so the eager
reference ignores it (issue #20554, section 2).
"""
if rhs_indices is not None:
b_sel = b[rhs_indices]
Expand All @@ -309,7 +313,7 @@ def gather_mm_fake(
b: Tensor,
rhs_indices: Optional[Tensor] = None,
lhs_indices: Optional[Tensor] = None,
sorted_indices: bool = False,
sorted_indices: Optional[Tensor] = None,
) -> Tensor:
# Matches MLX: output = indices.shape + [M, N]
# For simplicity, use matmul shape rules after gather
Expand All @@ -334,7 +338,7 @@ def gather_qmm(
group_size: int = 32,
bits: int = 4,
mode: str = "affine",
sorted_indices: bool = False,
sorted_indices: Optional[Tensor] = None, # issue #20554: 0-d int; None/0 = unsorted
) -> Tensor:
"""
Gather quantized matrix multiply — matches mlx::core::gather_qmm semantics.
Expand All @@ -343,6 +347,8 @@ def gather_qmm(

For MoE: x=[N_tokens, 1, K], w=[E, out, K_packed], rhs_indices=[N_tokens]
→ output=[N_tokens, 1, out]. Caller squeezes dim -2.

sorted_indices is layout-only; ignored here (see gather_mm docstring).
"""
# Eager fallback: gather, dequantize, matmul
if rhs_indices is not None:
Expand Down Expand Up @@ -381,7 +387,7 @@ def gather_qmm_fake(
group_size: int = 32,
bits: int = 4,
mode: str = "affine",
sorted_indices: bool = False,
sorted_indices: Optional[Tensor] = None,
) -> Tensor:
# Matches MLX: output = indices.shape + [M, N]
M = x.shape[-2]
Expand All @@ -397,19 +403,15 @@ def gather_qmm_fake(
def sample(
logits: Tensor,
temperature: Tensor,
top_k: Tensor,
top_p: Tensor,
seed: Optional[Tensor] = None,
) -> Tensor:
"""
Gumbel-max sampling from softmax(logits / temperature), with top-k and
top-p (nucleus) filtering.
Gumbel-max sampling from softmax(logits / temperature), with top-p (nucleus).
logits: [B, vocab]
temperature: scalar float tensor (runtime input). temperature <= 0 is
greedy: return argmax(logits) directly (matches the device,
which branches on temperature > 0).
top_k: scalar int tensor. It is clipped to the vocab size; using the
max int default keeps every token.
top_p: scalar float tensor in (0, 1]. top_p=1.0 keeps every
token, i.e. it is off.
seed: scalar int tensor or None
Expand All @@ -426,14 +428,6 @@ def sample(
return torch.argmax(logits, dim=-1)
# whole chain in fp32 to match the lowered graph (bf16 sums mis-rank ties).
scaled = logits.float() / temperature

k = min(int(top_k.item()), scaled.shape[-1])
s_scaled, _ = torch.sort(scaled, dim=-1, descending=True)
kth = s_scaled[..., k - 1 : k]
scaled = torch.where(scaled >= kth, scaled, scaled.new_tensor(float("-inf")))

# Apply top-p after top-k so the probabilities are renormalized over the
# top-k subset.
probs = torch.softmax(scaled, dim=-1)
s_probs, _ = torch.sort(probs, dim=-1, descending=True)
cum = torch.cumsum(s_probs, dim=-1)
Expand All @@ -452,5 +446,77 @@ def sample(


@torch.library.register_fake("mlx::sample")
def sample_fake(logits, temperature, top_k, top_p, seed=None):
def sample_fake(logits, temperature, top_p, seed=None):
return logits.new_empty(logits.shape[:-1], dtype=torch.long)


# ---------------------------------------------------------------------
# Issue #20554: runtime MoE expert-sort for decode (MLX backend)
# ---------------------------------------------------------------------


@torch.library.custom_op("mlx::moe_gather_inputs", mutates_args=())
def moe_gather_inputs(
x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Branch on M on purpose — this is the executable spec the lowering
handler (ops.py) mirrors branch-for-branch. Sorting is an invertible
permutation (identical numerics either way); the two paths exist for
the lowering's sake, not the math's."""
N = x.shape[0]
if N > sort_cutoff: # SORTED path (handler: emit_sorted)
flat = expert_indices.flatten()
order = flat.argsort().to(torch.int32)
inv_order = order.argsort().to(torch.int32)
idx = flat[order].to(torch.int32) # [N*top_k]
x_input = x[(order // top_k).to(torch.int64)].unsqueeze(-2) # [N*top_k, 1, D]
sort_experts = torch.ones((), dtype=torch.int32)
else: # UNSORTED path (handler: emit_unsorted)
x_input = x.repeat_interleave(top_k, dim=0).unsqueeze(-2) # [N*top_k, 1, D]
idx = expert_indices.flatten().to(torch.int32) # [N*top_k]
sort_experts = torch.zeros((), dtype=torch.int32)
inv_order = torch.empty(0, dtype=torch.int32) # sentinel: never read
return x_input, idx, sort_experts, inv_order


@torch.library.register_fake("mlx::moe_gather_inputs")
def moe_gather_inputs_fake(
    x: Tensor, expert_indices: Tensor, top_k: int, sort_cutoff: int
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Shape-only (fake) kernel for ``mlx::moe_gather_inputs``.

    Must NOT branch on the row count: under export it is a symbolic SymInt
    and data-dependent control flow on it is illegal. A single shape is
    therefore reported for every row count -- the sorted-path shape for
    x_input/idx/inv_order (issue #20554, section 3). ``sort_cutoff`` is
    accepted for signature parity and is not used.
    """
    num_rows = x.shape[0] * top_k  # one row per (token, expert) pair
    hidden = x.shape[-1]
    int32 = torch.int32
    return (
        x.new_empty((num_rows, 1, hidden)),  # x_input
        expert_indices.new_empty((num_rows,), dtype=int32),  # idx
        x.new_empty((), dtype=int32),  # sort_experts (0-d flag)
        x.new_empty((num_rows,), dtype=int32),  # inv_order
    )


@torch.library.custom_op("mlx::moe_scatter_outputs", mutates_args=())
def moe_scatter_outputs(
down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int
) -> Tensor:
down = down.squeeze(-2) # [N*top_k, H]
if sort_experts.item(): # prefill: scatter back (handler: emit_then)
down = down[inv_order]
# decode: no scatter (inv_order is the unread sentinel) (handler: emit_else)
return down.reshape(down.shape[0] // top_k, top_k, -1).clone() # [N, top_k, H]
# .clone(): avoids the aliasing-on-leaf-op issue opcheck flags for
# the no-op (unsorted) reshape path -- not in the issue's pseudo-code
# verbatim, added per torch.library.opcheck's aliasing requirement
# for custom ops (mutates_args=() means outputs must not alias inputs).


@torch.library.register_fake("mlx::moe_scatter_outputs")
def moe_scatter_outputs_fake(
    down: Tensor, sort_experts: Tensor, inv_order: Tensor, top_k: int
) -> Tensor:
    """Shape-only (fake) kernel for ``mlx::moe_scatter_outputs``.

    The result shape comes only from ``down`` and ``top_k`` -- no branching
    is needed and there is no dependency on ``inv_order``'s shape
    (issue #20554, section 3).
    """
    total_rows, hidden = down.shape[0], down.shape[-1]
    return down.new_empty((total_rows // top_k, top_k, hidden))
42 changes: 20 additions & 22 deletions backends/mlx/llm/switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"""

import logging
from typing import Optional

import torch
import torch.nn as nn
Expand Down Expand Up @@ -131,12 +132,15 @@ def forward(
self,
x: torch.Tensor,
indices: torch.Tensor,
sorted_indices: bool = False,
sorted_indices: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward without unsqueeze/squeeze — caller manages dimensions.

Used by UnfusedMoEExperts which passes x as [N, 1, 1, D]
and indices as [N, top_k] to handle all experts at once.

issue #20554: sorted_indices: bool -> Optional[Tensor] (same swap
as gather_mm/gather_qmm). Passed straight through to those ops.
"""
if not self._packed:
raise RuntimeError("SwitchLinear.pack() must be called before forward_raw.")
Expand Down Expand Up @@ -193,6 +197,7 @@ def __init__(
activation=None,
bias: bool = False,
fuse_gate_up: bool = False,
sort_cutoff: int = 1,
):
super().__init__()
if activation is None:
Expand All @@ -201,6 +206,9 @@ def __init__(
self.num_experts = num_experts
self.intermediate_size = intermediate_size
self.fuse_gate_up = fuse_gate_up
# issue #20554: static export-time threshold, compared against
# M=N inside moe_gather_inputs to decide sort/no-sort at runtime.
self.sort_cutoff = sort_cutoff

if fuse_gate_up:
self.gate_up_proj = SwitchLinear(
Expand All @@ -223,7 +231,6 @@ def forward(
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
top_k: int,
sort_experts: bool = False,
) -> torch.Tensor:
"""Forward pass through the gated MoE MLP.

Expand All @@ -232,25 +239,20 @@ def forward(
expert_weights: Routing weights [N, top_k] (already softmaxed)
expert_indices: Expert assignments [N, top_k]
top_k: Number of experts per token
sort_experts: Sort tokens by expert index for coalesced memory
access during prefill. No effect on decode (single token).

Returns:
Output tensor [N, D]

issue #20554: sort/no-sort is now a RUNTIME decision (M vs
self.sort_cutoff) made inside moe_gather_inputs, rather than a
compile-time bool baked into the exported .pte. The `sort_experts`
arg is therefore gone from this signature (was: bool = False) --
callers (e.g. mlx_source_transformations.py) configure the
threshold once via SwitchMLP(..., sort_cutoff=...) instead.
"""
N = x.shape[0]

if sort_experts:
flat_indices = expert_indices.flatten()
order = flat_indices.argsort().to(torch.int32)
inv_order = order.argsort().to(torch.int32)
sorted_idx = flat_indices[order].to(torch.int32)
x_sorted = x[(order // top_k).to(torch.int64)]
x_input = x_sorted.unsqueeze(-2)
idx = sorted_idx
else:
x_input = x.unsqueeze(-2).unsqueeze(-2)
idx = expert_indices
x_input, idx, sort_experts, inv_order = torch.ops.mlx.moe_gather_inputs(
x, expert_indices, top_k, self.sort_cutoff
)

if self.fuse_gate_up:
gate_up = self.gate_up_proj(x_input, idx, sorted_indices=sort_experts)
Expand All @@ -262,11 +264,7 @@ def forward(
h = self.activation(gate) * up
down = self.down_proj(h, idx, sorted_indices=sort_experts)

if sort_experts:
down = down.squeeze(-2)
down = down[inv_order].reshape(N, top_k, -1)
else:
down = down.squeeze(-2)
down = torch.ops.mlx.moe_scatter_outputs(down, sort_experts, inv_order, top_k)

return (down * expert_weights.unsqueeze(-1)).sum(dim=-2)

Expand Down
Loading
Loading