feat: add fused moe shared-expert and add-rmsnorm optimization #1353

Open
blueswhen wants to merge 1 commit into main from opt_fused_moe

Conversation

@blueswhen

Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces performance optimizations for Qwen3.5 and Qwen3Next models, including fused residual RMSNorm, shared expert gating, and specialized Triton kernels. The review feedback identifies a critical race condition in moe_align_fused_kernel requiring a tl.debug_barrier(), and a sequence parallelism bug in Qwen3Next where a hardcoded all_reduce should be replaced with _tpsp_reduce. Additionally, the feedback suggests using Triton's built-in tl.sigmoid for cleaner code, asserting tensor contiguity before calling .view() to prevent runtime crashes, and removing unused dead code in gdn_decode_pack.py.
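
For readers unfamiliar with the term, a fused residual add + RMSNorm computes the residual sum and its normalization in a single pass. The PR's Triton kernel is not quoted in this review, so the following is only a plain-Python reference of the usual definition. The function name add_rmsnorm_reference, the eps default, and the list-based layout are assumptions made for illustration.

```python
import math
from typing import List, Tuple


def add_rmsnorm_reference(
    x: List[float],
    residual: List[float],
    weight: List[float],
    eps: float = 1e-6,  # assumed default, not taken from the PR
) -> Tuple[List[float], List[float]]:
    """Return (normalized, new_residual) for a single token row."""
    # Step 1: residual add, kept so the next layer can reuse it.
    new_residual = [a + b for a, b in zip(x, residual)]
    # Step 2: RMSNorm over the hidden dimension, scaled by the learned weight.
    mean_sq = sum(v * v for v in new_residual) / len(new_residual)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normalized = [v * inv_rms * w for v, w in zip(new_residual, weight)]
    return normalized, new_residual


out, res = add_rmsnorm_reference([1.0, 2.0], [2.0, 2.0], [1.0, 1.0])
```

A fused kernel would produce both outputs without writing the intermediate sum to memory and reading it back, which is presumably where the saving comes from.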

Comment on lines +231 to +233
if ZERO_EXPERT_TOKEN_NUM:
    expert_offs = tl.arange(0, BLOCK_EXPERT)
    tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)

Contributor

critical

There is a critical write-after-write / read-after-write race condition here. moe_align_fused_kernel is launched with multiple warps (default num_warps=4), so every warp in the thread block executes the if ZERO_EXPERT_TOKEN_NUM: block and writes 0 to expert_token_num_ptr concurrently. Because warps are scheduled independently, some warps may execute the tl.store (writing 0) after other warps have already executed tl.atomic_add (incrementing the count). The incremented values are then overwritten with 0, which produces incorrect token counts and can cause silent data corruption or crashes in the subsequent grouped GEMMs. Adding a barrier (tl.debug_barrier()) immediately after the tl.store ensures that all zeroing writes are complete and visible before any warp proceeds to the atomic additions.

Suggested change
- if ZERO_EXPERT_TOKEN_NUM:
-     expert_offs = tl.arange(0, BLOCK_EXPERT)
-     tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)
+ if ZERO_EXPERT_TOKEN_NUM:
+     expert_offs = tl.arange(0, BLOCK_EXPERT)
+     tl.store(expert_token_num_ptr + expert_offs, 0, mask=expert_offs < expert_num)
+     tl.debug_barrier()
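
The ordering hazard described in this comment can be reproduced with a toy model. The sketch below is plain Python, not Triton: it assumes a single shared counter, treats each warp's zeroing store and atomic add as one step, and replays a fixed global order. The names run_schedule, racy, and with_barrier are hypothetical.

```python
from typing import List, Tuple

# Each op is (warp_id, kind): "zero" models the tl.store of 0,
# "add" models one tl.atomic_add of 1 to the same counter.
Op = Tuple[int, str]


def run_schedule(schedule: List[Op]) -> int:
    """Apply ops to one shared counter in the given global order."""
    counter = 7  # stale value left over from a previous launch
    for _warp, kind in schedule:
        if kind == "zero":
            counter = 0
        elif kind == "add":
            counter += 1
    return counter


# No barrier: warp 0 finishes zero + add before warp 1 runs its zeroing store.
racy = [(0, "zero"), (0, "add"), (1, "zero"), (1, "add")]

# With a block-wide barrier, every zeroing store precedes every add.
with_barrier = [(0, "zero"), (1, "zero"), (0, "add"), (1, "add")]

lost = run_schedule(racy)             # warp 0's increment is overwritten
correct = run_schedule(with_barrier)  # both increments survive
```

Note that a block-level barrier orders only the warps of one program instance. Whether other program instances touch the same counters is not visible from the quoted lines.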

Comment thread lightllm/models/qwen3next/model.py Outdated
Comment on lines +133 to +135
o = layer._get_o_local(o, infer_state, layer_weight)
if layer.tp_world_size_ > 1:
    all_reduce(o, group=infer_state.dist_group)

Contributor

high

Hardcoding all_reduce here breaks sequence parallelism (SP) support for full attention layers. When SP is enabled, the reduction should be a reduce_scatter (which is handled automatically by layer._tpsp_reduce), not an all_reduce. Hardcoding all_reduce causes shape/value mismatches and incorrect results because each rank will get the sum of the sequence-parallel slices instead of its own sequence slice. Please use layer._tpsp_reduce to dynamically handle both TP and SP reductions correctly.

                o = layer._get_o_local(o, infer_state, layer_weight)
                o = layer._tpsp_reduce(o, infer_state)
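
The difference between the two collectives can be shown with a small single-process model. This is a sketch of the collective semantics only: the internals of layer._tpsp_reduce are not quoted in the review, and the helper names all_reduce_sum and reduce_scatter_sum are hypothetical stand-ins, not the distributed API.

```python
from typing import List

Vector = List[float]


def all_reduce_sum(per_rank: List[Vector]) -> List[Vector]:
    """Every rank receives the full element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter_sum(per_rank: List[Vector]) -> List[Vector]:
    """Rank r receives only the r-th equal chunk of the element-wise sum."""
    world = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // world  # assumes the length divides evenly
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


# Two ranks, each holding a partial result for four token positions.
partials = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
full = all_reduce_sum(partials)        # each rank: all four positions
sliced = reduce_scatter_sum(partials)  # each rank: its own two positions
```

Under sequence parallelism the downstream layers expect the sliced shape, so a hardcoded all_reduce hands each rank a tensor that is world-size times too long along the token axis.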

    mask=offs_dim < dim_end,
    other=0.0,
).to(tl.float32)
gate = 1.0 / (1.0 + tl.exp(-gate))

Contributor

medium

Triton provides a built-in tl.sigmoid function which is cleaner, more readable, and potentially better optimized than the manual sigmoid formula.

Suggested change
- gate = 1.0 / (1.0 + tl.exp(-gate))
+ gate = tl.sigmoid(gate)
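
A swap like this is behavior-preserving only if both forms compute the same function. The sketch below checks the hand-written formula against an independent identity in plain Python. It does not exercise tl.sigmoid itself, and the helper names are hypothetical.

```python
import math


def sigmoid_manual(x: float) -> float:
    """The hand-written form used in the kernel: 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_tanh(x: float) -> float:
    """Independent reference via sigmoid(x) = (1 + tanh(x / 2)) / 2."""
    return 0.5 * (1.0 + math.tanh(0.5 * x))


samples = [-8.0, -1.0, 0.0, 0.5, 3.0, 8.0]
max_abs_diff = max(abs(sigmoid_manual(v) - sigmoid_tanh(v)) for v in samples)
```

A comparable check against the real kernel output, before and after the change, would be the stronger test on a GPU.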

Comment on lines +113 to +116
shared = shared.view(token_num, hidden_dim)
gate = gate.view(token_num, gate.shape[-1])
assert shared.is_contiguous()
assert gate.is_contiguous()

Contributor

medium

Calling .view() on potentially non-contiguous tensors can raise a RuntimeError. It is safer to assert contiguity before calling .view() to prevent unexpected runtime crashes.

Suggested change
- shared = shared.view(token_num, hidden_dim)
- gate = gate.view(token_num, gate.shape[-1])
- assert shared.is_contiguous()
- assert gate.is_contiguous()
+ assert shared.is_contiguous()
+ assert gate.is_contiguous()
+ shared = shared.view(token_num, hidden_dim)
+ gate = gate.view(token_num, gate.shape[-1])
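
What the contiguity assertion checks can be modelled from shapes and strides alone. The following is a simplified pure-Python model, not PyTorch's implementation; the helper names are hypothetical and strides are counted in elements.

```python
from typing import Sequence, Tuple


def row_major_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides (in elements) of a densely packed row-major array."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_row_major_contiguous(shape: Sequence[int], strides: Sequence[int]) -> bool:
    """True when the strides match the dense row-major layout.

    Dimensions of size 1 are ignored because their stride is never used.
    """
    expected = row_major_strides(shape)
    return all(
        size == 1 or actual == want
        for size, actual, want in zip(shape, strides, expected)
    )


dense = is_row_major_contiguous((4, 8), (8, 1))       # freshly allocated 4x8
transposed = is_row_major_contiguous((8, 4), (1, 8))  # same memory, transposed
```

PyTorch's .view() can also succeed on some non-contiguous tensors whose strides happen to be compatible, so asserting contiguity first is stricter than required but fails with a clearer message.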

else:
    gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32)
hidden_vals = tl.load(hidden_ptrs, mask=mask, other=0.0).to(tl.float32)
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))

Contributor

medium

Triton provides a built-in tl.sigmoid function which is cleaner, more readable, and potentially better optimized than the manual sigmoid formula.

Suggested change
- gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))
+ gate_vals = tl.sigmoid(gate_vals)

    gate_vals = tl.load(gate + row * stride_g_m).to(tl.float32)
else:
    gate_vals = tl.load(gate + row * stride_g_m + offs * stride_g_n, mask=mask, other=0.0).to(tl.float32)
gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))

Contributor

medium

Triton provides a built-in tl.sigmoid function which is cleaner, more readable, and potentially better optimized than the manual sigmoid formula.

Suggested change
- gate_vals = 1.0 / (1.0 + tl.exp(-gate_vals))
+ gate_vals = tl.sigmoid(gate_vals)

Comment on lines +63 to +66
def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    hidden_arg = hidden.view(-1, hidden.shape[-1])
    shared_arg = shared.view(-1, hidden.shape[-1])
    gate_arg = gate.view(-1, gate.shape[-1])

Contributor

medium

Calling .view() on potentially non-contiguous tensors can raise a RuntimeError. It is safer to assert contiguity before calling .view() to prevent unexpected runtime crashes.

Suggested change
- def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
-     hidden_arg = hidden.view(-1, hidden.shape[-1])
-     shared_arg = shared.view(-1, hidden.shape[-1])
-     gate_arg = gate.view(-1, gate.shape[-1])
+ def add_shared_expert_gate_(hidden: torch.Tensor, shared: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+     assert hidden.is_contiguous()
+     assert shared.is_contiguous()
+     assert gate.is_contiguous()
+     hidden_arg = hidden.view(-1, hidden.shape[-1])
+     shared_arg = shared.view(-1, hidden.shape[-1])
+     gate_arg = gate.view(-1, gate.shape[-1])

Comment on lines +90 to +93
def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    x_arg = x.view(-1, x.shape[-1])
    gate_arg = gate.view(-1, gate.shape[-1])
    assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1])

Contributor

medium

Calling .view() on potentially non-contiguous tensors can raise a RuntimeError. It is safer to assert contiguity before calling .view() to prevent unexpected runtime crashes.

Suggested change
- def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
-     x_arg = x.view(-1, x.shape[-1])
-     gate_arg = gate.view(-1, gate.shape[-1])
-     assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1])
+ def sigmoid_mul_(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+     assert x.is_contiguous()
+     assert gate.is_contiguous()
+     x_arg = x.view(-1, x.shape[-1])
+     gate_arg = gate.view(-1, gate.shape[-1])
+     assert gate_arg.shape[0] == x_arg.shape[0] and gate_arg.shape[1] in (1, x_arg.shape[1])
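
The shape assertion in the quoted function allows the gate to be either one value per row or one value per element. A plain-Python reference makes that contract concrete. It assumes, from the function name only, that sigmoid_mul_ computes x * sigmoid(gate); the kernel body is not quoted here, and this sketch returns a new list rather than updating x in place.

```python
import math
from typing import List

Matrix = List[List[float]]


def sigmoid_mul_reference(x: Matrix, gate: Matrix) -> Matrix:
    """Return x * sigmoid(gate), broadcasting a single-column gate per row."""
    assert len(gate) == len(x), "gate and x need the same number of rows"
    out = []
    for x_row, g_row in zip(x, gate):
        assert len(g_row) in (1, len(x_row)), "gate width must be 1 or match x"
        if len(g_row) == 1:
            g_row = g_row * len(x_row)  # broadcast the per-row scalar gate
        out.append([v / (1.0 + math.exp(-g)) for v, g in zip(x_row, g_row)])
    return out


x = [[2.0, 4.0], [6.0, 8.0]]
per_row = sigmoid_mul_reference(x, [[0.0], [0.0]])             # gate width 1
per_elem = sigmoid_mul_reference(x, [[0.0, 0.0], [0.0, 0.0]])  # gate width == x width
```

With a gate of 0, sigmoid is 0.5, so both calls halve every element of x.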

Comment on lines +68 to +122
@torch.no_grad()
def pack_gdn_decode_inputs(
    mixed_qkv: torch.Tensor,
    z_raw: torch.Tensor,
    a_raw: torch.Tensor,
    b_raw: torch.Tensor,
    num_k_heads: int,
    head_k_dim: int,
    num_v_heads: int,
    head_v_dim: int,
):
    batch = mixed_qkv.shape[0]
    q_dim = num_k_heads * head_k_dim
    k_dim = q_dim
    v_dim = num_v_heads * head_v_dim
    gate_dim = num_v_heads

    q = torch.empty((batch, 1, num_k_heads, head_k_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device)
    k = torch.empty_like(q)
    v = torch.empty((batch, 1, num_v_heads, head_v_dim), dtype=mixed_qkv.dtype, device=mixed_qkv.device)
    z = torch.empty((batch, num_v_heads, head_v_dim), dtype=z_raw.dtype, device=z_raw.device)
    a = torch.empty((batch, gate_dim), dtype=a_raw.dtype, device=a_raw.device)
    b = torch.empty((batch, gate_dim), dtype=b_raw.dtype, device=b_raw.device)

    block_qkv = triton.next_power_of_2(max(q_dim, k_dim, v_dim))
    block_gate = triton.next_power_of_2(gate_dim)
    _pack_gdn_decode_kernel[(batch,)](
        mixed_qkv,
        z_raw,
        a_raw,
        b_raw,
        q,
        k,
        v,
        z,
        a,
        b,
        mixed_qkv.stride(0),
        mixed_qkv.stride(1),
        z_raw.stride(0),
        z_raw.stride(1),
        z_raw.stride(2),
        a_raw.stride(0),
        a_raw.stride(1),
        b_raw.stride(0),
        b_raw.stride(1),
        q_dim,
        k_dim,
        v_dim,
        gate_dim,
        BLOCK_QKV=block_qkv,
        BLOCK_GATE=block_gate,
        num_warps=4,
    )
    return q, k, v, z, a, b

Contributor

medium

The function pack_gdn_decode_inputs and its corresponding kernel _pack_gdn_decode_kernel are defined but never imported or used anywhere in the codebase. Removing this dead code will keep the codebase clean and maintainable.
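
A claim that a function is never referenced can be checked mechanically. The sketch below is one possible approach using only the standard-library ast module. The helper unreferenced_functions and the two-file demo project are hypothetical and are not part of the PR.

```python
import ast
from typing import Dict, Set


def unreferenced_functions(sources: Dict[str, str]) -> Set[str]:
    """Return names defined with `def` that no file loads, calls, or imports."""
    defined: Set[str] = set()
    referenced: Set[str] = set()
    for code in sources.values():
        for node in ast.walk(ast.parse(code)):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                defined.add(node.name)
            elif isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
                referenced.add(node.id)
            elif isinstance(node, ast.Attribute):
                referenced.add(node.attr)
            elif isinstance(node, ast.ImportFrom):
                referenced.update(alias.name for alias in node.names)
    return defined - referenced


# Hypothetical two-file project used only to exercise the helper.
demo = {
    "pack.py": "def pack_inputs(x):\n    return x\n\ndef helper(x):\n    return x\n",
    "model.py": "from pack import helper\n\ndef forward(x):\n    return helper(x)\n",
}
unused = unreferenced_functions(demo)
```

The result also contains entry points such as forward, which nothing inside the scanned files calls, and the scan misses dynamic lookups through strings or getattr, so each hit still needs a manual look before deletion.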
