Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,20 @@ class SplashAttention(BaseModel):
local_use_splash_scheduler: bool | None = Field(None, description="Use experimental local splash attention scheduler.")
local_sa_fuse_reciprocal: bool | None = Field(None, description="Maps to local fuse_reciprocal in SplashConfig.")
local_sa_use_base2_exp: bool | None = Field(None, description="Maps to local use_base2_exp in SplashConfig.")
experimental_sa_quant_q_fp8: bool | None = Field(
None,
description=(
"Experimental flag: If enabled, the Q tensor in splash attention is"
" quantized to jnp.float8_e4m3fn."
),
)
experimental_sa_quant_k_fp8: bool | None = Field(
None,
description=(
"Experimental flag: If enabled, the K tensor in splash attention is"
" quantized to jnp.float8_e4m3fn."
),
)
use_max_logit_estimate: int = Field(
-1,
description="-1 means no estimate, any > 0 value will be used as max logit estimate",
Expand Down
7 changes: 7 additions & 0 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,10 @@ def mla_query_projection(
# Query projection is scaled by self.softmax_scale to be consistent MaxText implementation.
# DeepSeek v3 was doing it in attention score computation.
query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale

if self.config.experimental_sa_quant_q_fp8:
query = query.astype(jnp.float8_e4m3fn)

query = self._maybe_shard_with_logical(query, query_logical_name)
return query, low_rank_q

Expand All @@ -899,6 +903,9 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode):

key = jnp.concatenate([key_nope, key_rope], axis=-1)

if self.config.experimental_sa_quant_k_fp8:
key = key.astype(jnp.float8_e4m3fn)

key = self._maybe_shard_with_logical(key, key_logical_name)
value = self._maybe_shard_with_logical(value, value_logical_name)
return key, value
Expand Down
Loading