feat: add fused moe shared-expert and add-rmsnorm optimization #1353
feat: add fused moe shared-expert and add-rmsnorm optimization #1353 · blueswhen wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces performance optimizations for Qwen3.5 and Qwen3Next models, including fused residual RMSNorm, shared expert gating, and specialized Triton kernels. The review feedback identifies a critical race condition in moe_align_fused_kernel requiring a tl.debug_barrier(), and a sequence parallelism bug in Qwen3Next where a hardcoded all_reduce should be replaced with _tpsp_reduce. Additionally, the feedback suggests using Triton's built-in tl.sigmoid for cleaner code, asserting tensor contiguity before calling .view() to prevent runtime crashes, and removing unused dead code in gdn_decode_pack.py.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if ZERO_EXPERT_TOKEN_NUM: | ||
| expert_offs = tl.arange(0, BLOCK_EXPERT) | ||
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) |
There was a problem hiding this comment.
There is a critical Write-After-Write / Read-After-Write race condition here. Since moe_align_fused_kernel is launched with multiple warps (default num_warps=4), all warps in the thread block will execute the if ZERO_EXPERT_TOKEN_NUM: block and write 0 to expert_token_num_ptr concurrently. Because warps are scheduled independently, some warps might execute the tl.store (writing 0) after other warps have already executed tl.atomic_add (incrementing the count). This results in the incremented values being overwritten by 0, leading to incorrect token counts and silent data corruption or crashes in subsequent grouped GEMMs. Adding a barrier (tl.debug_barrier()) immediately after the tl.store ensures all zeroing writes are completed and visible before any warp proceeds to the atomic additions.
| if ZERO_EXPERT_TOKEN_NUM: | |
| expert_offs = tl.arange(0, BLOCK_EXPERT) | |
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) | |
| if ZERO_EXPERT_TOKEN_NUM: | |
| expert_offs = tl.arange(0, BLOCK_EXPERT) | |
| tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num) | |
| tl.debug_barrier() |
| o = layer._get_o_local(o, infer_state, layer_weight) | ||
| if layer.tp_world_size_ > 1: | ||
| all_reduce(o, group=infer_state.dist_group) |
There was a problem hiding this comment.
Hardcoding all_reduce here breaks sequence parallelism (SP) support for full attention layers. When SP is enabled, the reduction should be a reduce_scatter (which is handled automatically by layer._tpsp_reduce), not an all_reduce. Hardcoding all_reduce causes shape/value mismatches and incorrect results because each rank will get the sum of the sequence-parallel slices instead of its own sequence slice. Please use layer._tpsp_reduce to dynamically handle both TP and SP reductions correctly.
o = layer._get_o_local(o, infer_state, layer_weight)
o = layer._tpsp_reduce(o, infer_state)
| mask=offs_dim < dim_end, | ||
| other=0.0, | ||
| ).to(tl.float32) | ||
| gate = 1.0 / (1.0 + tl.exp(-gate)) |
| shared = shared.view(token_num, hidden_dim) | ||
| gate = gate.view(token_num, gate.shape[-1]) | ||
| assert shared.is_contiguous() | ||
| assert gate.is_contiguous() |
There was a problem hiding this comment.
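Calling .view() on a non-contiguous tensor can raise a RuntimeError when its strides are incompatible with the requested shape. In the current code the contiguity asserts run after the .view() calls, so they cannot guard them; moving the asserts before .view() makes the failure explicit and prevents an unexpected runtime crash.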
| shared = shared.view(token_num, hidden_dim) | |
| gate = gate.view(token_num, gate.shape[-1]) | |
| assert shared.is_contiguous() | |
| assert gate.is_contiguous() | |
| assert shared.is_contiguous() | |
| assert gate.is_contiguous() | |
| shared = shared.view(token_num, hidden_dim) | |
| gate = gate.view(token_num, gate.shape[-1]) |
| else: | ||
| gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) | ||
| hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32) | ||
| gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) |
| gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32) | ||
| else: | ||
| gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32) | ||
| gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals)) |
| def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | ||
| hidden_arg = hidden.view(-1, hidden.shape[-1]) | ||
| shared_arg = shared.view(-1, hidden.shape[-1]) | ||
| gate_arg = gate.view(-1, gate.shape[-1]) |
There was a problem hiding this comment.
Calling .view() on potentially non-contiguous tensors can raise a RuntimeError. It is safer to assert contiguity before calling .view() to prevent unexpected runtime crashes.
| def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | |
| hidden_arg = hidden.view(-1, hidden.shape[-1]) | |
| shared_arg = shared.view(-1, hidden.shape[-1]) | |
| gate_arg = gate.view(-1, gate.shape[-1]) | |
| def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | |
| assert hidden.is_contiguous() | |
| assert shared.is_contiguous() | |
| assert gate.is_contiguous() | |
| hidden_arg = hidden.view(-1, hidden.shape[-1]) | |
| shared_arg = shared.view(-1, hidden.shape[-1]) | |
| gate_arg = gate.view(-1, gate.shape[-1]) |
| def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | ||
| x_arg = x.view(-1, x.shape[-1]) | ||
| gate_arg = gate.view(-1, gate.shape[-1]) | ||
| assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) |
There was a problem hiding this comment.
Calling .view() on potentially non-contiguous tensors can raise a RuntimeError. It is safer to assert contiguity before calling .view() to prevent unexpected runtime crashes.
| def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | |
| x_arg = x.view(-1, x.shape[-1]) | |
| gate_arg = gate.view(-1, gate.shape[-1]) | |
| assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1]) | |
| def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor: | |
| assert x.is_contiguous() | |
| assert gate.is_contiguous() | |
| x_arg = x.view(-1, x.shape[-1]) | |
| gate_arg = gate.view(-1, gate.shape[-1]) |
| @torch.no_grad() | ||
| def pack_gdn_decode_inputs( | ||
| mixed_qkv: torch.Tensor, | ||
| z_raw: torch.Tensor, | ||
| a_raw: torch.Tensor, | ||
| b_raw: torch.Tensor, | ||
| num_k_heads: int, | ||
| head_k_dim: int, | ||
| num_v_heads: int, | ||
| head_v_dim: int, | ||
| ): | ||
| batch = mixed_qkv.shape[0] | ||
| q_dim = num_k_heads * head_k_dim | ||
| k_dim = q_dim | ||
| v_dim = num_v_heads * head_v_dim | ||
| gate_dim = num_v_heads | ||
|
|
||
| q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) | ||
| k = torch.empty_like(q) | ||
| v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device) | ||
| z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device) | ||
| a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device) | ||
| b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device) | ||
|
|
||
| block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim)) | ||
| block_gate = triton.next_power_of_2(gate_dim) | ||
| _pack_gdn_decode_kernel[(batch,)]( | ||
| mixed_qkv, | ||
| z_raw, | ||
| a_raw, | ||
| b_raw, | ||
| q, | ||
| k, | ||
| v, | ||
| z, | ||
| a, | ||
| b, | ||
| mixed_qkv.stride(0), | ||
| mixed_qkv.stride(1), | ||
| z_raw.stride(0), | ||
| z_raw.stride(1), | ||
| z_raw.stride(2), | ||
| a_raw.stride(0), | ||
| a_raw.stride(1), | ||
| b_raw.stride(0), | ||
| b_raw.stride(1), | ||
| q_dim, | ||
| k_dim, | ||
| v_dim, | ||
| gate_dim, | ||
| BLOCK_QKV=block_qkv, | ||
| BLOCK_GATE=block_gate, | ||
| num_warps=4, | ||
| ) | ||
| return q, k, v, z, a, b |
No description provided.