diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 769a8c1745..2db568bb2c 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -728,6 +728,20 @@ class SplashAttention(BaseModel): local_use_splash_scheduler: bool | None = Field(None, description="Use experimental local splash attention scheduler.") local_sa_fuse_reciprocal: bool | None = Field(None, description="Maps to local fuse_reciprocal in SplashConfig.") local_sa_use_base2_exp: bool | None = Field(None, description="Maps to local use_base2_exp in SplashConfig.") + experimental_sa_quant_q_fp8: bool | None = Field( + None, + description=( + "Experimental flag: If enabled, the Q tensor in splash attention is" + " quantized to jnp.float8_e4m3fn." + ), + ) + experimental_sa_quant_k_fp8: bool | None = Field( + None, + description=( + "Experimental flag: If enabled, the K tensor in splash attention is" + " quantized to jnp.float8_e4m3fn." + ), + ) use_max_logit_estimate: int = Field( -1, description="-1 means no estimate, any > 0 value will be used as max logit estimate", diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index d5f79a1a4b..1380d38be3 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -876,6 +876,10 @@ def mla_query_projection( # Query projection is scaled by self.softmax_scale to be consistent MaxText implementation. # DeepSeek v3 was doing it in attention score computation. query = jnp.concatenate([q_nope, q_pe], axis=-1) * self.softmax_scale + + if self.config.experimental_sa_quant_q_fp8: + query = query.astype(jnp.float8_e4m3fn) + query = self._maybe_shard_with_logical(query, query_logical_name) return query, low_rank_q @@ -899,6 +903,9 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode): key = jnp.concatenate([key_nope, key_rope], axis=-1) + if self.config.experimental_sa_quant_k_fp8: + key = key.astype(jnp.float8_e4m3fn) + key = self._maybe_shard_with_logical(key, key_logical_name) value = self._maybe_shard_with_logical(value, value_logical_name) return key, value