backends/mlx: runtime MoE expert-sort for decode (issue #20554) #20685
AxelNoun wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20685
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 86ee91c with merge base 0f3303f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @AxelNoun!

Thank you for your pull request and welcome to our community.

Action Required
In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process
In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Replace the compile-time sort_experts: bool flag in SwitchMLP with a runtime decision made inside two new custom ops (moe_gather_inputs, moe_scatter_outputs). A single exported .pte now handles both prefill (sorted, coalesced gather_mm) and decode (unsorted, no argsort overhead) without requiring separate exports.

Changes:
- schema.fbs: sorted_indices: bool -> IntOrVid (required) on GatherMmNode/GatherQmmNode; required fields moved before optionals
- MLXInterpreter.h: resolve_int(n.sorted_indices, st) != 0 (cf. kth)
- custom_ops.py: moe_gather_inputs, moe_scatter_outputs + register_fake; gather_mm/gather_qmm sorted_indices: bool -> Optional[Tensor]
- op_helpers.py: emit_floordiv helper (alongside emit_ceil_div)
- ops.py: _moe_gather_inputs_handler, _moe_scatter_outputs_handler; updated _gather_mm/_gather_qmm handlers for IntOrVid
- switch.py: SwitchMLP gains sort_cutoff; forward replaces if/else block with the two new ops; SwitchLinear sorted_indices: bool -> Optional[Tensor]
- mlx_source_transformations.py + export.py: sort_experts -> sort_cutoff
- test_ops.py: MoeGatherInputsTest, MoeScatterOutputsTest with expected_node_counts; GatherMmTest/GatherQmmTest extended for sorted_indices=Tensor configs

Test plan:
- Windows: python backends/mlx/test/validate_moe_20554.py (all passed)
- CI: test-mlx job on macos-14-xlarge (run_all_tests)

Fixes pytorch#20554

PR authored with Claude.

Co-authored-by: Cursor <cursoragent@cursor.com>
d709ef4 to 86ee91c
Summary
Replace the compile-time sort_experts: bool flag in SwitchMLP with a runtime decision inside two new custom ops (moe_gather_inputs, moe_scatter_outputs). A single exported .pte now handles both prefill (sorted, coalesced gather_mm) and decode (unsorted, no argsort overhead) without separate exports.

Key changes:
- schema.fbs: sorted_indices: bool → IntOrVid (required) on GatherMmNode/GatherQmmNode; required fields before optionals
- MLXInterpreter.h: resolve_int(n.sorted_indices, st) != 0 (cf. kth)
- custom_ops.py: moe_gather_inputs, moe_scatter_outputs; gather_mm/gather_qmm sorted_indices: Optional[Tensor]
- ops.py: new MoE handlers + updated gather handlers for IntOrVid
- switch.py: sort_cutoff replaces compile-time sort branch
- test_ops.py: MoE + GatherMm/GatherQmm tests with sorted_indices=Tensor configs

MLXLoader.{h,cpp} and FlatBuffer bindings are regenerated automatically by generate.py + flatc during the CMake build on Mac CI; they are not included in this commit, per repo convention.

Test plan
- python backends/mlx/test/validate_moe_20554.py (all passed)
- test-mlx job on macos-14-xlarge (run_all_tests, which covers gather_mm, gather_qmm, moe_gather_inputs, moe_scatter_outputs)

Fixes #20554
PR authored with Claude.
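
For reviewers who want a mental model of the runtime decision described in the summary, here is a minimal NumPy sketch of the gather/scatter pair. It is not the code in this PR: the function names, signatures, and return values are hypothetical, NumPy stands in for MLX, and it assumes that sort_cutoff is a threshold on the number of routed (token, expert) slots, which the description does not spell out.

```python
import numpy as np


def moe_gather_inputs_sketch(x, expert_ids, sort_cutoff):
    """Hypothetical reference for the runtime sort decision.

    x:          (num_tokens, hidden) activations
    expert_ids: (num_tokens, top_k) expert chosen for each routed slot
    Assumption: sorting kicks in once the slot count reaches sort_cutoff.
    """
    top_k = expert_ids.shape[1]
    flat_ids = expert_ids.reshape(-1)
    do_sort = flat_ids.size >= sort_cutoff
    if do_sort:
        # Prefill-like case: group slots by expert so the matmul reads
        # each expert's weights contiguously.
        order = np.argsort(flat_ids, kind="stable")
    else:
        # Decode-like case: keep the original order and skip the argsort.
        order = np.arange(flat_ids.size)
    # Flattened slot i belongs to token i // top_k.
    gathered = x[order // top_k]
    return gathered, flat_ids[order], order, do_sort


def moe_scatter_outputs_sketch(y, order, num_tokens, top_k):
    """Undo the permutation applied by moe_gather_inputs_sketch."""
    out = np.empty_like(y)
    out[order] = y  # row j of y goes back to original slot order[j]
    return out.reshape(num_tokens, top_k, -1)
```

Because the scatter step inverts whatever permutation the gather step chose, the sorted and unsorted paths produce the same per-token results; only the memory access pattern of the expert matmul in between differs. That property is what would let a single exported graph serve both prefill and decode.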