diff --git a/BUILD.bazel b/BUILD.bazel index 7697c16f..db2cb21b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -690,6 +690,7 @@ cc_library( ":configs", ":flash_structs", ":kv_cache", + ":kv_transcoding", ":mat", ":matmul", ":matmul_env", diff --git a/gemma/activations.h b/gemma/activations.h index 0f94e056..78d304c6 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -150,6 +150,7 @@ struct AttentionActivations { size_t qkv_dim; AlignedBF16Vector bf16_queries; std::vector> int16_queries; + hwy::AlignedVector int8_queries; AlignedFloatVector float_queries; AlignedFloatVector q_scales; @@ -202,6 +203,7 @@ struct AttentionActivationsPtrs { split_flash_params(split_flash_params), bf16_queries(nullptr), int16_queries(nullptr), + int8_queries(nullptr), float_queries(nullptr), q_scales(nullptr), div_seq_len(static_cast(seq_len)), @@ -227,6 +229,7 @@ struct AttentionActivationsPtrs { inv_timescale_global = activations.inv_timescale_global; bf16_queries = &activations.bf16_queries; int16_queries = &activations.int16_queries; + int8_queries = &activations.int8_queries; float_queries = &activations.float_queries; q_scales = &activations.q_scales; } @@ -296,6 +299,7 @@ struct AttentionActivationsPtrs { sub_task_max_logits; AlignedBF16Vector* bf16_queries; std::vector>* int16_queries; + hwy::AlignedVector* int8_queries; AlignedFloatVector* float_queries; AlignedFloatVector* q_scales; // Inverse timescales for RoPE computation. diff --git a/gemma/configs.cc b/gemma/configs.cc index df432557..96ea0453 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -737,6 +737,7 @@ constexpr std::pair kAttentionImplNameToEnum[] = { {"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16}, {"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16}, {"flash_matrix_accumulation", AttentionImpl::kFlashMatrixAccumulation}, + {"int8_matrix_accumulation", AttentionImpl::kInt8MatrixAccumulation}, }; std::string GetAttentionImplName(AttentionImpl impl) { @@ -771,6 +772,8 @@ std::string KVEncodingToString(KVEncoding encoding) { return "Int8TwoTranspositions"; case KVEncoding::kBF16MatrixAccumulation: return "BF16MatrixAccumulation"; + case KVEncoding::kInt8MatrixAccumulation: + return "Int8MatrixAccumulation"; default: return "Unknown"; } diff --git a/gemma/configs.h b/gemma/configs.h index 474cdef1..2fa69312 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -91,6 +91,7 @@ enum class KVEncoding { kInt8 = 5, kInt8TwoTranspositions = 6, kBF16MatrixAccumulation = 7, + kInt8MatrixAccumulation = 8, }; // Returns a string representation of the KVEncoding. @@ -106,6 +107,7 @@ enum class AttentionImpl { kFlashTransposedQsBF16, kFlashTransposedQsInt16, kFlashMatrixAccumulation, + kInt8MatrixAccumulation, kSentinel, }; diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index a925581b..f64af120 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -331,7 +331,7 @@ HWY_INLINE void QDotKTile148BF16NotNative( const size_t kNF = hn::Lanes(df); const float* HWY_RESTRICT q_base[kVTileSize]; for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = reinterpret_cast(q + q_offsets[i]); + q_base[i] = HWY_RCAST_ALIGNED(const float*, q + q_offsets[i]); } const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { @@ -433,7 +433,7 @@ HWY_INLINE void QDotKTile148BF16Native( const size_t kNF = hn::Lanes(df); const float* HWY_RESTRICT q_base[kVTileSize]; for (size_t i = 0; i < kVTileSize; ++i) { - q_base[i] = reinterpret_cast(q + q_offsets[i]); + q_base[i] = HWY_RCAST_ALIGNED(const float*, q + q_offsets[i]); } const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF)); for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) { @@ -1356,9 +1356,9 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( // tile base can point to same tile as previous loop iteration, hence no // HWY_RESTRICT // KVs are unaligned and we only use unaligned loads in this implementation. - const KV_T* tile_base = - reinterpret_cast(kvs[current_kv_idx].RowBytes( - (position - current_kv_start_offset) / kTileSize)); + const KV_T* tile_base = HWY_RCAST_ALIGNED( + const KV_T*, kvs[current_kv_idx].RowBytes( + (position - current_kv_start_offset) / kTileSize)); const KV_T* v_tile = tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim; @@ -1396,7 +1396,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( // After end of the tile, we have kTileSize * 2 bfloat16 for the // microscaling scales for K and V. const BF16* microscaling_scales_k = - reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + HWY_RCAST_ALIGNED(const BF16*, tile_base + qkv_dim * 2 * kTileSize) + pos_in_tile; MultiplyByScale(df, microscaling_scales_k, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, @@ -1439,7 +1439,8 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( if constexpr (IsInt16() && kUseMicroScaling) { if (query_idx == 0) { // update only when needed const BF16* microscaling_scales_v = - reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + HWY_RCAST_ALIGNED(const BF16*, + tile_base + qkv_dim * 2 * kTileSize) + kTileSize + pos_in_tile; const PackedSpan scales_span = MakeConstSpan(microscaling_scales_v, 2 * hn::Lanes(df)); @@ -1456,7 +1457,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits( q_scales_s_ptr, max_v_scale); if constexpr (kUseMicroScaling) { const BF16* microscaling_scales_v = - reinterpret_cast(tile_base + qkv_dim * 2 * kTileSize) + + HWY_RCAST_ALIGNED(const BF16*, tile_base + qkv_dim * 2 * kTileSize) + kTileSize + pos_in_tile; MultiplyByScale(df, microscaling_scales_v, x_0_p_0, x_0_p_1, x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0, @@ -1570,6 +1571,9 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16( last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits); } +template +MatT GetKVTypeHelper(const hwy::Span>&); + void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( hwy::Span kvs, size_t q_count, const BF16* HWY_RESTRICT q_base, @@ -1578,9 +1582,30 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, float* HWY_RESTRICT max_logits) { CallUpcastedKVs(kvs, [&](const auto& kv_t) { - TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( - kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query, - att_cap, att_out, exp_denominator_sums, max_logits); + using KV_T = decltype(GetKVTypeHelper(kv_t)); + if constexpr (IsBF16()) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); + } + }); +} + +void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( + hwy::Span kvs, size_t q_count, + const int8_t* HWY_RESTRICT q_base, hwy::Span q_scales, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, const float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { + CallUpcastedKVs(kvs, [&](const auto& kv_t) { + using KV_T = decltype(GetKVTypeHelper(kv_t)); + if constexpr (IsInt8()) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + kv_t, q_count, q_base, q_scales, start_pos_per_query, + last_pos_per_query, att_cap, att_out, exp_denominator_sums, + max_logits); + } }); } diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 3c9f4fbe..602f15fc 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -83,6 +83,15 @@ namespace gcpp { MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ float* HWY_RESTRICT max_logits); \ \ + void \ + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( \ + hwy::Span kvs, size_t q_count, \ + const int8_t* HWY_RESTRICT q_base, hwy::Span q_scales, \ + hwy::Span start_pos_per_query, \ + hwy::Span last_pos_per_query, const float att_cap, \ + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, \ + float* HWY_RESTRICT max_logits); \ + \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/flash_attention_arm-inl.h b/gemma/flash_attention_arm-inl.h index 96222a46..84ddf5c5 100644 --- a/gemma/flash_attention_arm-inl.h +++ b/gemma/flash_attention_arm-inl.h @@ -8,7 +8,6 @@ #include #include #include -#include #include #include "gemma/flash_attention.h" @@ -33,8 +32,12 @@ #include "ops/ops-inl.h" #include "hwy/contrib/math/fast_math-inl.h" -#ifndef BENCHMARK_BLOCK_SIZE -#define BENCHMARK_BLOCK_SIZE 128 +#ifndef BENCHMARK_BLOCK_SIZE_BF16 +#define BENCHMARK_BLOCK_SIZE_BF16 128 +#endif + +#ifndef BENCHMARK_BLOCK_SIZE_INT8 +#define BENCHMARK_BLOCK_SIZE_INT8 512 #endif HWY_BEFORE_NAMESPACE(); @@ -121,131 +124,153 @@ struct TileAttentionGroupParams { } }; -template +template HWY_INLINE hn::Vec LoadAndDuplicateQueries(D d, const T* HWY_RESTRICT q_ptr) { -#if HWY_HAVE_CONSTEXPR_LANES - constexpr size_t N = hn::Lanes(d); - if constexpr (N <= 8) { + if constexpr (kRegBytes <= 16) { return hn::LoadU(d, q_ptr); } else { return hn::LoadDup128(d, q_ptr); } -#else - return hn::LoadDup128(d, q_ptr); -#endif } -template > +template +HWY_INLINE void Accumulate4x4Grid(D_ACC d_acc, LoadA load_A, V_IN B0, V_IN B1, + V_IN B2, V_IN B3, V_ACC& acc00, V_ACC& acc01, + V_ACC& acc02, V_ACC& acc03, V_ACC& acc10, + V_ACC& acc11, V_ACC& acc12, V_ACC& acc13, + V_ACC& acc20, V_ACC& acc21, V_ACC& acc22, + V_ACC& acc23, V_ACC& acc30, V_ACC& acc31, + V_ACC& acc32, V_ACC& acc33) { + if constexpr (kNumQueries >= 1) { + const V_IN A = load_A(0); + acc00 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B0, acc00); + acc01 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B1, acc01); + acc02 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B2, acc02); + acc03 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B3, acc03); + } + if constexpr (kNumQueries >= 3) { + const V_IN A = load_A(1); + acc10 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B0, acc10); + acc11 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B1, acc11); + acc12 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B2, acc12); + acc13 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B3, acc13); + } + if constexpr (kNumQueries >= 5) { + const V_IN A = load_A(2); + acc20 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B0, acc20); + acc21 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B1, acc21); + acc22 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B2, acc22); + acc23 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B3, acc23); + } + if constexpr (kNumQueries >= 7) { + const V_IN A = load_A(3); + acc30 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B0, acc30); + acc31 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B1, acc31); + acc32 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B2, acc32); + acc33 = PerBlock2x2MatMulMaybeEmulate(d_acc, A, B3, acc33); + } +} + +template HWY_INLINE void QDotKTilexUpTo8MatrixAccumulation( DF df, const T_IN* HWY_RESTRICT q_group, const T_IN* tile_base, - size_t current_pos, size_t qkv_dim, VF& C00, VF& C01, VF& C02, VF& C03, - VF& C10, VF& C11, VF& C12, VF& C13, VF& C20, VF& C21, VF& C22, VF& C23, - VF& C30, VF& C31, VF& C32, VF& C33) { - using D_ACC = hwy::If(), hn::FixedTag, DF>; + size_t current_pos, size_t qkv_dim, size_t stride_q, VF& C00, VF& C01, + VF& C02, VF& C03, VF& C10, VF& C11, VF& C12, VF& C13, VF& C20, VF& C21, + VF& C22, VF& C23, VF& C30, VF& C31, VF& C32, VF& C33) { + using D_ACC = hwy::If(), hn::Repartition, DF>; const D_ACC d_acc; - using VecAcc = hn::Vec; - - VecAcc acc00 = hn::Zero(d_acc), acc01 = hn::Zero(d_acc), - acc02 = hn::Zero(d_acc), acc03 = hn::Zero(d_acc); - VecAcc acc10 = hn::Zero(d_acc), acc11 = hn::Zero(d_acc), - acc12 = hn::Zero(d_acc), acc13 = hn::Zero(d_acc); - VecAcc acc20 = hn::Zero(d_acc), acc21 = hn::Zero(d_acc), - acc22 = hn::Zero(d_acc), acc23 = hn::Zero(d_acc); - VecAcc acc30 = hn::Zero(d_acc), acc31 = hn::Zero(d_acc), - acc32 = hn::Zero(d_acc), acc33 = hn::Zero(d_acc); - - using D_INPUT = hn::Repartition; + using V_ACC = hn::Vec; + + V_ACC acc00 = hn::Zero(d_acc), acc01 = hn::Zero(d_acc), + acc02 = hn::Zero(d_acc), acc03 = hn::Zero(d_acc); + V_ACC acc10 = hn::Zero(d_acc), acc11 = hn::Zero(d_acc), + acc12 = hn::Zero(d_acc), acc13 = hn::Zero(d_acc); + V_ACC acc20 = hn::Zero(d_acc), acc21 = hn::Zero(d_acc), + acc22 = hn::Zero(d_acc), acc23 = hn::Zero(d_acc); + V_ACC acc30 = hn::Zero(d_acc), acc31 = hn::Zero(d_acc), + acc32 = hn::Zero(d_acc), acc33 = hn::Zero(d_acc); + + using D_INPUT = hn::ScalableTag; const D_INPUT d_input; using VecInput = hn::Vec; - HWY_LANES_CONSTEXPR size_t step_size = hn::Lanes(d_input); - HWY_LANES_CONSTEXPR size_t ch_step = 4; // Always 4 for 8x4 layout - - size_t g0 = (current_pos / 8) % 4; - const T_IN* k_ptr0; - const T_IN* k_ptr1; - const T_IN* k_ptr2; - const T_IN* k_ptr3; - - if (step_size == 32) { - k_ptr0 = tile_base + 0; - k_ptr1 = tile_base + 32; - k_ptr2 = tile_base + 64; - k_ptr3 = tile_base + 96; - } else if (step_size == 16) { - k_ptr0 = tile_base + g0 * 32; - k_ptr1 = tile_base + (g0 + 1) * 32; - k_ptr2 = nullptr; - k_ptr3 = nullptr; - } else { // step_size == 8 - k_ptr0 = tile_base + g0 * 32; - k_ptr1 = nullptr; - k_ptr2 = nullptr; - k_ptr3 = nullptr; - } + constexpr size_t kStepSize = kRegBytes / sizeof(T_IN); + constexpr size_t step_size = kStepSize; + HWY_LANES_CONSTEXPR size_t ch_step = IsInt8() ? 8 : 4; + // Variables for BF16 path + const T_IN* k_ptr0_bf16 = nullptr; + const T_IN* k_ptr1_bf16 = nullptr; + const T_IN* k_ptr2_bf16 = nullptr; + const T_IN* k_ptr3_bf16 = nullptr; size_t ch_base_k = 0; size_t ch_base_q = 0; + + if constexpr (!IsInt8()) { + if constexpr (kStepSize == 32) { + k_ptr0_bf16 = tile_base + 0; + k_ptr1_bf16 = tile_base + 32; + k_ptr2_bf16 = tile_base + 64; + k_ptr3_bf16 = tile_base + 96; + } else if constexpr (kStepSize == 16) { + const size_t g0 = (current_pos / 8) % 4; + k_ptr0_bf16 = tile_base + g0 * 32; + k_ptr1_bf16 = tile_base + (g0 + 1) * 32; + } else { // kStepSize == 8 + const size_t g0 = (current_pos / 8) % 4; + k_ptr0_bf16 = tile_base + g0 * 32; + } + } + for (size_t ch_base = 0; ch_base < qkv_dim; ch_base += ch_step) { VecInput B0, B1, B2, B3; - if (step_size == 32) { // 512-bit native path - B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); - B1 = hn::LoadU(d_input, k_ptr1 + ch_base_k); - B2 = hn::LoadU(d_input, k_ptr2 + ch_base_k); - B3 = hn::LoadU(d_input, k_ptr3 + ch_base_k); - } else if (step_size == 16) { // 256-bit fallback path - B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); - B1 = hn::LoadU(d_input, k_ptr0 + ch_base_k + 16); // Group 0, second half - B2 = hn::LoadU(d_input, k_ptr1 + ch_base_k); - B3 = hn::LoadU(d_input, k_ptr1 + ch_base_k + 16); // Group 1, second half - } else if (step_size == 8) { // 128-bit fallback path - B0 = hn::LoadU(d_input, k_ptr0 + ch_base_k); - B1 = hn::LoadU(d_input, - k_ptr0 + ch_base_k + 8); // Group 0, second quarter - B2 = hn::LoadU(d_input, - k_ptr0 + ch_base_k + 16); // Group 0, third quarter - B3 = hn::LoadU(d_input, - k_ptr0 + ch_base_k + 24); // Group 0, fourth quarter - } - if constexpr (kNumQueries >= 1) { - const auto A0 = LoadAndDuplicateQueries(d_input, q_group + ch_base_q); - acc00 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B0, acc00); - acc01 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B1, acc01); - acc02 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B2, acc02); - acc03 = PerBlock2x2MatMulMaybeEmulate(d_acc, A0, B3, acc03); - } - if constexpr (kNumQueries >= 3) { - const auto A1 = - LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 2 + ch_base_q); - acc10 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B0, acc10); - acc11 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B1, acc11); - acc12 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B2, acc12); - acc13 = PerBlock2x2MatMulMaybeEmulate(d_acc, A1, B3, acc13); - } - if constexpr (kNumQueries >= 5) { - const auto A2 = - LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 4 + ch_base_q); - acc20 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B0, acc20); - acc21 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B1, acc21); - acc22 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B2, acc22); - acc23 = PerBlock2x2MatMulMaybeEmulate(d_acc, A2, B3, acc23); - } - if constexpr (kNumQueries >= 7) { - const auto A3 = - LoadAndDuplicateQueries(d_input, q_group + qkv_dim * 6 + ch_base_q); - acc30 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B0, acc30); - acc31 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B1, acc31); - acc32 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B2, acc32); - acc33 = PerBlock2x2MatMulMaybeEmulate(d_acc, A3, B3, acc33); + if constexpr (IsInt8()) { + // INT8 Path: 8x8 block layout + const size_t token_in_tile = current_pos % 32; + const T_IN* k_ptr = tile_base + ch_base * 32 + token_in_tile * 8; + + B0 = hn::LoadU(d_input, k_ptr + 0 * step_size); + B1 = hn::LoadU(d_input, k_ptr + 1 * step_size); + B2 = hn::LoadU(d_input, k_ptr + 2 * step_size); + B3 = hn::LoadU(d_input, k_ptr + 3 * step_size); + } else { + // BF16 Path: original loading logic + if constexpr (kStepSize == 32) { // 512-bit native path + B0 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k); + B1 = hn::LoadU(d_input, k_ptr1_bf16 + ch_base_k); + B2 = hn::LoadU(d_input, k_ptr2_bf16 + ch_base_k); + B3 = hn::LoadU(d_input, k_ptr3_bf16 + ch_base_k); + } else if constexpr (kStepSize == 16) { // 256-bit fallback path + B0 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k); + B1 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k + 16); + B2 = hn::LoadU(d_input, k_ptr1_bf16 + ch_base_k); + B3 = hn::LoadU(d_input, k_ptr1_bf16 + ch_base_k + 16); + } else { // kStepSize == 8 (128-bit fallback path) + B0 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k); + B1 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k + 8); + B2 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k + 16); + B3 = hn::LoadU(d_input, k_ptr0_bf16 + ch_base_k + 24); + } + + ch_base_k += ch_step * 32; } - ch_base_k += ch_step * 32; // Stride 128 elements for 8x4 layout + + auto load_A = [&](size_t idx) HWY_ATTR { + return LoadAndDuplicateQueries( + d_input, q_group + idx * 2 * stride_q + ch_base_q); + }; + + Accumulate4x4Grid( + d_acc, load_A, B0, B1, B2, B3, acc00, acc01, acc02, acc03, acc10, acc11, + acc12, acc13, acc20, acc21, acc22, acc23, acc30, acc31, acc32, acc33); + ch_base_q += ch_step * 2; } - auto convert_and_reduce = [&](VF& C, VecAcc acc) HWY_ATTR { + auto convert_and_reduce = [&](VF& C, V_ACC acc) HWY_ATTR { if constexpr (!IsInt8()) { C = acc; } else { @@ -292,10 +317,255 @@ HWY_INLINE V ConcatUpperUpper_VLA(DF df, V q1, V q0) { return hn::BitCast(df, interleaved); } -template +HWY_INLINE void QuantizeAndPackSoftmaxProbs( + DF_T df, DBF_T dbf, size_t q_base_idx, size_t actual_q_count, + size_t actual_block_size, size_t qkv_dim, + const float* HWY_RESTRICT q_scales_new, + const float* HWY_RESTRICT softmax_buf, const GroupInfo* group_infos, + int8_t* HWY_RESTRICT q_weights_buf, float* HWY_RESTRICT w_scales_buf) { + namespace hn = hwy::HWY_NAMESPACE; + const hn::Full128 df_4; + const hn::Full128 dbf8; + const hn::Full64 dbf4; + const hn::Full64 di16_4; + const hn::Full128 di16_8; + const hn::Full64 di8_8; + const hn::Full128 di8_16; + + using VF4 = hn::Vec; + using VBF8 = hn::Vec; + using VI16_4 = hn::Vec; + using VI16_8 = hn::Vec; + using VI8_8 = hn::Vec; + using VI8_16 = hn::Vec; + using VI32 = hn::Vec>; + + const size_t num_groups = actual_block_size / 8; + const size_t num_qp = 4; + constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE_INT8; + + // Process 2 queries at a time to reduce L1 working set and avoid large macros + HWY_ALIGN float w_buf[2 * kBlockSize]; + + for (size_t qp = 0; qp < num_qp; ++qp) { + size_t q0 = qp * 2; + size_t q1 = q0 + 1; + + VF4 max_val0 = hn::Zero(df_4); + VF4 max_val1 = hn::Zero(df_4); + + float q0_scale = (q0 < actual_q_count) ? q_scales_new[q0] : 0.0f; + float q1_scale = (q1 < actual_q_count) ? q_scales_new[q1] : 0.0f; + const VF4 q0_scale_vec = hn::Set(df_4, q0_scale); + const VF4 q1_scale_vec = hn::Set(df_4, q1_scale); + + // Fallback to safe memory block (multiplied by 0.0 later) to avoid + // out-of-bounds branches + const float* softmax_q0 = (q0 < actual_q_count) + ? softmax_buf + (q_base_idx + q0) * kBlockSize + : softmax_buf; + const float* softmax_q1 = (q1 < actual_q_count) + ? softmax_buf + (q_base_idx + q1) * kBlockSize + : softmax_buf; + + // Pass 1: Compute final unquantized weights w = softmax * q_scale * + // v_scale, and find max + for (size_t g = 0; g < num_groups; ++g) { + const auto& gi = group_infos[g]; + const VBF8 v_scale_bf16 = hn::LoadU(dbf8, gi.v_scales + gi.token_in_tile); + const VF4 v_scale_f_lo = hn::PromoteLowerTo(df_4, v_scale_bf16); + const VF4 v_scale_f_hi = hn::PromoteUpperTo(df_4, v_scale_bf16); + + // Query 0 + const VF4 s0_lo = hn::LoadU(df_4, softmax_q0 + g * 8); + const VF4 s0_hi = hn::LoadU(df_4, softmax_q0 + g * 8 + 4); + const VF4 qv0_lo = hn::Mul(q0_scale_vec, v_scale_f_lo); + const VF4 qv0_hi = hn::Mul(q0_scale_vec, v_scale_f_hi); + const VF4 w0_lo = hn::Mul(s0_lo, qv0_lo); + const VF4 w0_hi = hn::Mul(s0_hi, qv0_hi); + max_val0 = hn::Max(max_val0, hn::Max(w0_lo, w0_hi)); + hn::StoreU(w0_lo, df_4, w_buf + g * 8); + hn::StoreU(w0_hi, df_4, w_buf + g * 8 + 4); + + // Query 1 + const VF4 s1_lo = hn::LoadU(df_4, softmax_q1 + g * 8); + const VF4 s1_hi = hn::LoadU(df_4, softmax_q1 + g * 8 + 4); + const VF4 qv1_lo = hn::Mul(q1_scale_vec, v_scale_f_lo); + const VF4 qv1_hi = hn::Mul(q1_scale_vec, v_scale_f_hi); + const VF4 w1_lo = hn::Mul(s1_lo, qv1_lo); + const VF4 w1_hi = hn::Mul(s1_hi, qv1_hi); + max_val1 = hn::Max(max_val1, hn::Max(w1_lo, w1_hi)); + hn::StoreU(w1_lo, df_4, w_buf + kBlockSize + g * 8); + hn::StoreU(w1_hi, df_4, w_buf + kBlockSize + g * 8 + 4); + } + + float global_max0 = hn::ReduceMax(df_4, max_val0); + float global_max1 = hn::ReduceMax(df_4, max_val1); + + w_scales_buf[q0] = global_max0 / 127.0f; + w_scales_buf[q1] = global_max1 / 127.0f; + + float inv_scale0 = (global_max0 > 1e-30f) ? 127.0f / global_max0 : 0.0f; + float inv_scale1 = (global_max1 > 1e-30f) ? 127.0f / global_max1 : 0.0f; + + const VF4 inv_vec0 = hn::Set(df_4, inv_scale0); + const VF4 inv_vec1 = hn::Set(df_4, inv_scale1); + + // Pass 2: Finalize quantization & packed formatting + for (size_t g = 0; g < num_groups; ++g) { + // Query 0 + const VF4 w0_lo = hn::LoadU(df_4, w_buf + g * 8); + const VF4 w0_hi = hn::LoadU(df_4, w_buf + g * 8 + 4); + const VI32 w0_lo_i32 = hn::NearestInt(hn::Mul(w0_lo, inv_vec0)); + const VI32 w0_hi_i32 = hn::NearestInt(hn::Mul(w0_hi, inv_vec0)); + const VI16_8 w0_i16 = hn::OrderedDemote2To(di16_8, w0_lo_i32, w0_hi_i32); + + // Query 1 + const VF4 w1_lo = hn::LoadU(df_4, w_buf + kBlockSize + g * 8); + const VF4 w1_hi = hn::LoadU(df_4, w_buf + kBlockSize + g * 8 + 4); + const VI32 w1_lo_i32 = hn::NearestInt(hn::Mul(w1_lo, inv_vec1)); + const VI32 w1_hi_i32 = hn::NearestInt(hn::Mul(w1_hi, inv_vec1)); + const VI16_8 w1_i16 = hn::OrderedDemote2To(di16_8, w1_lo_i32, w1_hi_i32); + + // Write tightly grouped chunk for the native HW path to ingest directly + const VI8_16 w_i8_16 = hn::OrderedDemote2To(di8_16, w0_i16, w1_i16); + int8_t* dst = q_weights_buf + g * (num_qp * 16) + qp * 16; + hn::StoreU(w_i8_16, di8_16, dst); + } + } +} + +template +HWY_INLINE void TileFlashAttentionSVBlockInt8( + size_t q_base_idx, size_t qkv_dim, const float* HWY_RESTRICT scales_old, + float* HWY_RESTRICT C_accumulators, const GroupInfo* group_infos, + size_t num_groups, const int8_t* HWY_RESTRICT q_weights_pre, + const float* HWY_RESTRICT w_scales_pre) { + namespace hn = hwy::HWY_NAMESPACE; + using DI8 = hn::Full128; + const DI8 di8; + using VI8 = hn::Vec; + + using DI32 = hn::Full128; + const DI32 di32; + using VI32 = hn::Vec; + + using DF4 = hn::Full128; + const DF4 df4; + using VF4 = hn::Vec; + + // 1. Scale the old accumulator values in C_accumulators + for (size_t q = 0; q < kNumQueries; ++q) { + float* HWY_RESTRICT out = C_accumulators + (q_base_idx + q) * qkv_dim; + const VF4 s = hn::Set(df4, scales_old[q]); + for (size_t d = 0; d < qkv_dim; d += 4) { + const VF4 old = hn::LoadU(df4, out + d); + hn::StoreU(hn::Mul(old, s), df4, out + d); + } + } + + // 2. Loop over channel groups (each group has 8 channels) + for (size_t ch_base = 0; ch_base < qkv_dim; ch_base += 8) { + size_t ch_g = ch_base / 8; + + // Initialize accumulators: up to 4 query pairs x 4 channel pairs + VI32 acc00 = hn::Zero(di32), acc01 = hn::Zero(di32), + acc02 = hn::Zero(di32), acc03 = hn::Zero(di32); + VI32 acc10 = hn::Zero(di32), acc11 = hn::Zero(di32), + acc12 = hn::Zero(di32), acc13 = hn::Zero(di32); + VI32 acc20 = hn::Zero(di32), acc21 = hn::Zero(di32), + acc22 = hn::Zero(di32), acc23 = hn::Zero(di32); + VI32 acc30 = hn::Zero(di32), acc31 = hn::Zero(di32), + acc32 = hn::Zero(di32), acc33 = hn::Zero(di32); + + // 3. Accumulate over steps in int32 + for (size_t g = 0; g < num_groups; ++g) { + const auto& gi = group_infos[g]; + size_t tg8 = gi.token_in_tile / 8; + const int8_t* v_ptr = gi.v_tile_base + (ch_g * 4 + tg8) * 64; + + // Pre-load V values for the 4 channel pairs + VI8 B0 = hn::LoadU(di8, v_ptr + 0 * 16); + VI8 B1 = hn::LoadU(di8, v_ptr + 1 * 16); + VI8 B2 = hn::LoadU(di8, v_ptr + 2 * 16); + VI8 B3 = hn::LoadU(di8, v_ptr + 3 * 16); + + // Load Q weights and accumulate (reused across channel groups) + // The writer (QuantizeAndPackSoftmaxProbs) always uses a stride of 64 + // bytes (4 query pairs) regardless of kNumQueries. + const int8_t* q_w_ptr = q_weights_pre + g * 64; + + auto load_A = [&](size_t idx) HWY_ATTR { + return hn::LoadU(di8, q_w_ptr + idx * 16); + }; + + Accumulate4x4Grid(di32, load_A, B0, B1, B2, B3, acc00, + acc01, acc02, acc03, acc10, acc11, acc12, + acc13, acc20, acc21, acc22, acc23, acc30, + acc31, acc32, acc33); + } + + // 4. Dequantize and write back + auto dequant_and_store = [&](size_t qp, VI32 acc0, VI32 acc1, VI32 acc2, + VI32 acc3) HWY_ATTR { + size_t q0 = qp * 2; + size_t q1 = q0 + 1; + + float sq0 = w_scales_pre[q0]; + float sq1 = (q1 < kNumQueries) ? w_scales_pre[q1] : 0.0f; + + alignas(16) float s_arr[4] = {sq0, sq0, sq1, sq1}; + VF4 scale_vec = hn::LoadU(df4, s_arr); + + VF4 C0_f = hn::Mul(hn::ConvertTo(df4, acc0), scale_vec); + VF4 C1_f = hn::Mul(hn::ConvertTo(df4, acc1), scale_vec); + VF4 C2_f = hn::Mul(hn::ConvertTo(df4, acc2), scale_vec); + VF4 C3_f = hn::Mul(hn::ConvertTo(df4, acc3), scale_vec); + + // Reconstruct contiguous float channel vectors + VF4 Q0_lo = ConcatLowerLower_VLA(df4, C1_f, C0_f); + VF4 Q0_hi = ConcatLowerLower_VLA(df4, C3_f, C2_f); + VF4 Q1_lo = ConcatUpperUpper_VLA(df4, C1_f, C0_f); + VF4 Q1_hi = ConcatUpperUpper_VLA(df4, C3_f, C2_f); + + if (q0 < kNumQueries) { + float* out = C_accumulators + (q_base_idx + q0) * qkv_dim + ch_base; + hn::StoreU(hn::Add(hn::LoadU(df4, out + 0), Q0_lo), df4, out + 0); + hn::StoreU(hn::Add(hn::LoadU(df4, out + 4), Q0_hi), df4, out + 4); + } + if (q1 < kNumQueries) { + float* out = C_accumulators + (q_base_idx + q1) * qkv_dim + ch_base; + hn::StoreU(hn::Add(hn::LoadU(df4, out + 0), Q1_lo), df4, out + 0); + hn::StoreU(hn::Add(hn::LoadU(df4, out + 4), Q1_hi), df4, out + 4); + } + }; + + if constexpr (kNumQueries >= 1) { + dequant_and_store(0, acc00, acc01, acc02, acc03); + } + if constexpr (kNumQueries >= 3) { + dequant_and_store(1, acc10, acc11, acc12, acc13); + } + if constexpr (kNumQueries >= 5) { + dequant_and_store(2, acc20, acc21, acc22, acc23); + } + if constexpr (kNumQueries >= 7) { + dequant_and_store(3, acc30, acc31, acc32, acc33); + } + } +} + +template , typename DBF_T = hn::ScalableTag, typename KV_T> -HWY_INLINE void TileFlashAttentionSVBlock( +HWY_INLINE void TileFlashAttentionSVBlockBF16( size_t q_base_idx, size_t position, size_t current_kv_start_offset, size_t current_kv_idx, size_t actual_steps, size_t qkv_dim, const float* HWY_RESTRICT scales_old, @@ -312,16 +582,36 @@ HWY_INLINE void TileFlashAttentionSVBlock( using VF = hn::Vec; using VF4 = hn::Vec; - constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE; + constexpr size_t kBf16Lanes = kRegBytes / sizeof(BF16); + constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE_BF16; + + using KV_PTR_T = const KV_T*; + // Dynamically scale the tile pointer buffer with the block size. + constexpr size_t kMaxTiles = + (kBlockSize + KVCache::kTileSize - 1) / KVCache::kTileSize; + KV_PTR_T v_ptrs[kMaxTiles]; + size_t start_tile_idx = + (position - current_kv_start_offset) / KVCache::kTileSize; + + // Optimized tile pointer loading + size_t num_tiles_in_block = + (actual_steps * kBf16Lanes + KVCache::kTileSize - 1) / KVCache::kTileSize; + for (size_t t = 0; t < num_tiles_in_block; ++t) { + const KV_T* tile_base = HWY_RCAST_ALIGNED( + const KV_T*, kvs[current_kv_idx].RowBytes(start_tile_idx + t)); + v_ptrs[t] = tile_base + qkv_dim * 32; + } // Pre-pack Q into BF16 to avoid scaling and demoting in the inner loop. - const size_t kNumPackedVectors = ((kNumQueries + 1) / 2) * 2; - HWY_ALIGN BF16 q_packed[kBlockSize * 8 * hn::MaxLanes(dbf)]; + constexpr size_t kNumPackedVectors = ((kNumQueries + 1) / 2) * 2; + // Remove the redundant MaxLanes multiplier to reduce stack usage by up to + // 32x. + HWY_ALIGN BF16 q_packed[kBlockSize * kNumPackedVectors]; for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { auto load_scaled_q = [&](size_t qp, VF& v0, VF& v1) HWY_ATTR { const float* ptr = - softmax_buf + (q_base_idx + qp) * kBlockSize + step_idx * kStepSize; + softmax_buf + (q_base_idx + qp) * kBlockSize + step_idx * kBf16Lanes; VF qs = hn::Set(df, q_scales_new[qp]); v0 = hn::Mul(hn::LoadU(df, ptr + 0), qs); v1 = hn::Mul(hn::LoadU(df, ptr + hn::Lanes(df)), qs); @@ -347,8 +637,8 @@ HWY_INLINE void TileFlashAttentionSVBlock( auto pack_and_store_pair = [&](size_t pair_idx, VF ql0, VF qh0, VF ql1, VF qh1) HWY_ATTR { - BF16* dst = q_packed + step_idx * kNumPackedVectors * kStepSize + - pair_idx * 2 * kStepSize; + BF16* dst = q_packed + step_idx * kNumPackedVectors * kBf16Lanes + + pair_idx * 2 * kBf16Lanes; using D64 = hn::Repartition; const D64 d64; using D64_half = hn::Half; @@ -363,7 +653,7 @@ HWY_INLINE void TileFlashAttentionSVBlock( auto qh1_bf = hn::DemoteTo(dbf_half, qh1); hn::Vec A0, A1; - if constexpr (kStepSize > 8) { + if constexpr (kRegBytes > 16) { auto ql0_64 = hn::BitCast(d64_half, ql0_bf); auto ql1_64 = hn::BitCast(d64_half, ql1_bf); // This interleaves within 128-bit block so it's fast. @@ -382,7 +672,7 @@ HWY_INLINE void TileFlashAttentionSVBlock( } hn::StoreU(A0, dbf, dst + 0); - hn::StoreU(A1, dbf, dst + kStepSize); + hn::StoreU(A1, dbf, dst + kBf16Lanes); }; if constexpr (kNumQueries >= 1) @@ -395,34 +685,21 @@ HWY_INLINE void TileFlashAttentionSVBlock( pack_and_store_pair(3, q6_l, q6_h, q7_l, q7_h); } - // Pre-compute V tile pointers to avoid row lookups in the inner loop. - const BF16* v_ptrs[8]; - size_t start_tile_idx = - (position - current_kv_start_offset) / KVCache::kTileSize; - - // Optimized tile pointer loading - size_t num_tiles_in_block = - (actual_steps * kStepSize + KVCache::kTileSize - 1) / KVCache::kTileSize; - for (size_t t = 0; t < num_tiles_in_block; ++t) { - const BF16* tile_base = reinterpret_cast( - kvs[current_kv_idx].RowBytes(start_tile_idx + t)); - v_ptrs[t] = tile_base + qkv_dim * 32; - } - // Step-Dependent Pre-computation - const BF16* step_q_ptrs[BENCHMARK_BLOCK_SIZE / 8]; - const BF16* step_v_tiles[BENCHMARK_BLOCK_SIZE / 8]; - size_t step_offsets_even[BENCHMARK_BLOCK_SIZE / 8]; - size_t step_offsets_odd[BENCHMARK_BLOCK_SIZE / 8]; + const BF16* step_q_ptrs[kBlockSize / 8]; + const BF16* step_v_tiles[kBlockSize / 8]; + size_t step_offsets_even[kBlockSize / 8]; + size_t step_offsets_odd[kBlockSize / 8]; for (size_t step_idx = 0; step_idx < actual_steps; ++step_idx) { - step_q_ptrs[step_idx] = q_packed + step_idx * kNumPackedVectors * kStepSize; - size_t step_pos = position + step_idx * kStepSize; + step_q_ptrs[step_idx] = + q_packed + step_idx * kNumPackedVectors * kBf16Lanes; + size_t step_pos = position + step_idx * kBf16Lanes; size_t global_token_pos = step_pos - current_kv_start_offset; size_t tile_idx = global_token_pos / 32 - start_tile_idx; - step_v_tiles[step_idx] = v_ptrs[tile_idx]; + step_v_tiles[step_idx] = HWY_RCAST_ALIGNED(const BF16*, v_ptrs[tile_idx]); size_t t = step_pos % 32; - size_t t_odd = t + kStepSize / 2; + size_t t_odd = t + kBf16Lanes / 2; step_offsets_even[step_idx] = (t / 16) * 64 + ((t % 16) / 4) * 8; step_offsets_odd[step_idx] = (t_odd / 16) * 64 + ((t_odd % 16) / 4) * 8; } @@ -466,64 +743,20 @@ HWY_INLINE void TileFlashAttentionSVBlock( B_odd3 = hn::LoadU(dbf, v_ptr_next + offset_odd + 32); // Even halves first (A0, A2, A4, A6) - if constexpr (kNumQueries >= 1) { - const auto A0 = hn::LoadU(dbf, q_ptr + 0); - C00 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even0, C00); - C01 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even1, C01); - C02 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even2, C02); - C03 = PerBlock2x2MatMulMaybeEmulate(df, A0, B_even3, C03); - } - if constexpr (kNumQueries >= 3) { - const auto A2 = hn::LoadU(dbf, q_ptr + 2 * kStepSize); - C10 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even0, C10); - C11 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even1, C11); - C12 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even2, C12); - C13 = PerBlock2x2MatMulMaybeEmulate(df, A2, B_even3, C13); - } - if constexpr (kNumQueries >= 5) { - const auto A4 = hn::LoadU(dbf, q_ptr + 4 * kStepSize); - C20 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even0, C20); - C21 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even1, C21); - C22 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even2, C22); - C23 = PerBlock2x2MatMulMaybeEmulate(df, A4, B_even3, C23); - } - if constexpr (kNumQueries >= 7) { - const auto A6 = hn::LoadU(dbf, q_ptr + 6 * kStepSize); - C30 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even0, C30); - C31 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even1, C31); - C32 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even2, C32); - C33 = PerBlock2x2MatMulMaybeEmulate(df, A6, B_even3, C33); - } + auto load_A_even = [&](size_t idx) HWY_ATTR { + return hn::LoadU(dbf, q_ptr + 2 * idx * kBf16Lanes); + }; + Accumulate4x4Grid( + df, load_A_even, B_even0, B_even1, B_even2, B_even3, C00, C01, C02, + C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33); // Odd halves second (A1, A3, A5, A7) - if constexpr (kNumQueries >= 1) { - const auto A1 = hn::LoadU(dbf, q_ptr + kStepSize); - C00 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd0, C00); - C01 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd1, C01); - C02 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd2, C02); - C03 = PerBlock2x2MatMulMaybeEmulate(df, A1, B_odd3, C03); - } - if constexpr (kNumQueries >= 3) { - const auto A3 = hn::LoadU(dbf, q_ptr + 3 * kStepSize); - C10 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd0, C10); - C11 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd1, C11); - C12 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd2, C12); - C13 = PerBlock2x2MatMulMaybeEmulate(df, A3, B_odd3, C13); - } - if constexpr (kNumQueries >= 5) { - const auto A5 = hn::LoadU(dbf, q_ptr + 5 * kStepSize); - C20 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd0, C20); - C21 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd1, C21); - C22 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd2, C22); - C23 = PerBlock2x2MatMulMaybeEmulate(df, A5, B_odd3, C23); - } - if constexpr (kNumQueries >= 7) { - const auto A7 = hn::LoadU(dbf, q_ptr + 7 * kStepSize); - C30 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd0, C30); - C31 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd1, C31); - C32 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd2, C32); - C33 = PerBlock2x2MatMulMaybeEmulate(df, A7, B_odd3, C33); - } + auto load_A_odd = [&](size_t idx) HWY_ATTR { + return hn::LoadU(dbf, q_ptr + (2 * idx + 1) * kBf16Lanes); + }; + Accumulate4x4Grid( + df, load_A_odd, B_odd0, B_odd1, B_odd2, B_odd3, C00, C01, C02, C03, + C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33); } // Reduce accumulators to 128-bit and add scaled old values @@ -605,6 +838,31 @@ HWY_INLINE void TileFlashAttentionSVBlock( } } +template , + typename DBF_T = hn::ScalableTag, typename KV_T> +HWY_INLINE void TileFlashAttentionSVBlock( + size_t q_base_idx, size_t position, size_t current_kv_start_offset, + size_t current_kv_idx, size_t actual_steps, size_t qkv_dim, + const float* HWY_RESTRICT scales_old, + const float* HWY_RESTRICT q_scales_new, + const float* HWY_RESTRICT softmax_buf, + const hwy::Span>& kvs, + float* HWY_RESTRICT C_accumulators, const GroupInfo* group_infos = nullptr, + size_t num_groups = 0, const int8_t* HWY_RESTRICT q_weights_pre = nullptr, + const float* HWY_RESTRICT w_scales_pre = nullptr) { + if constexpr (IsInt8()) { + TileFlashAttentionSVBlockInt8( + q_base_idx, qkv_dim, scales_old, C_accumulators, group_infos, + num_groups, q_weights_pre, w_scales_pre); + } else { + TileFlashAttentionSVBlockBF16( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, softmax_buf, kvs, + C_accumulators); + } +} + template > HWY_INLINE void UpdateOnlineSoftmaxSingleQuery( DF df, float* HWY_RESTRICT q_logits, size_t actual_block_size, @@ -709,8 +967,8 @@ HWY_INLINE void UpdateOnlineSoftmaxSingleQuery( exp_denominator_sums[q] = new_sum; } -template -HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( +template +HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Impl( const hwy::Span> kvs, size_t q_count, const Q_T* HWY_RESTRICT q_base, hwy::Span q_scales, hwy::Span start_pos_per_query, @@ -727,22 +985,17 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( const DBF dbf; const DU du; using VF = hn::Vec; - const size_t step_size = hn::Lanes(dbf); - if (step_size > 32) { - HWY_ABORT( - "Unsupported step size (vector width) %zu. Only up to 512-bit (32 " - "lanes) is supported.", - step_size); - } + constexpr size_t kBf16Lanes = kRegBytes / sizeof(BF16); const float one_over_cap = 1.0f / att_cap; constexpr int kNumQueriesPerLoop = 8; constexpr size_t kTileSize = 32; - constexpr size_t kBlockSize = BENCHMARK_BLOCK_SIZE; + constexpr size_t kBlockSize = + IsInt8() ? BENCHMARK_BLOCK_SIZE_INT8 : BENCHMARK_BLOCK_SIZE_BF16; const size_t kStepsPerTile = - std::max(size_t(1), KVCache::kTileSize / step_size); + std::max(size_t(1), KVCache::kTileSize / kBf16Lanes); - TileAttentionGroupParams preamble(q_count, kNumQueriesPerLoop, step_size, + TileAttentionGroupParams preamble(q_count, kNumQueriesPerLoop, kBf16Lanes, start_pos_per_query, last_pos_per_query, att_cap, att_out); @@ -756,6 +1009,13 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( hwy::AlignedVector C_accumulators(hwy::RoundUpTo(q_count, 8) * qkv_dim, 0.0f); hwy::AlignedVector softmax_buf(q_count * kBlockSize, kMaskedLogitVal); + hwy::AlignedVector q_weights_buf; + hwy::AlignedVector w_scales_buf; + if constexpr (IsInt8()) { + const size_t num_qp = 4; + q_weights_buf.resize((kBlockSize / 8) * num_qp * 16, 0); + w_scales_buf.resize(8, 0.0f); + } size_t current_kv_idx = 0; size_t current_kv_start_offset = 0; @@ -775,25 +1035,48 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( size_t kv_remaining = kv_rows - (position - current_kv_start_offset); size_t actual_block_size = - std::min(kBlockSize, hwy::RoundUpTo(remaining_tokens, step_size)); + std::min(kBlockSize, hwy::RoundUpTo(remaining_tokens, kBf16Lanes)); actual_block_size = std::min(actual_block_size, kv_remaining); - size_t actual_steps = actual_block_size / step_size; - [[maybe_unused]] size_t actual_M = - hwy::DivCeil(actual_block_size, kTileSize); + size_t actual_steps = actual_block_size / kBf16Lanes; + const size_t num_groups = actual_block_size / 8; + HWY_ALIGN GroupInfo group_infos[kBlockSize / 8]; + if constexpr (IsInt8()) { + for (size_t g = 0; g < num_groups; ++g) { + size_t global_token_pos = position + g * 8; + size_t kv_ptr_idx = current_kv_idx; + size_t kv_start_offset = current_kv_start_offset; + while (global_token_pos - kv_start_offset >= + kvs[kv_ptr_idx].Rows() * kTileSize) { + kv_start_offset += kvs[kv_ptr_idx].Rows() * kTileSize; + kv_ptr_idx++; + } + size_t tile_idx = (global_token_pos - kv_start_offset) / kTileSize; + size_t token_in_tile = (global_token_pos - kv_start_offset) % kTileSize; + + const int8_t* tile_base = HWY_RCAST_ALIGNED( + const int8_t*, kvs[kv_ptr_idx].RowBytes(tile_idx)); + const int8_t* v_tile_base = tile_base + qkv_dim * kTileSize; + const BF16* scales = + HWY_RCAST_ALIGNED(const BF16*, v_tile_base + qkv_dim * kTileSize); + const BF16* v_scales = scales + kTileSize; + + group_infos[g] = {v_scales, v_tile_base, token_in_tile}; + } + } size_t macro_tile_start_pos = position; auto inner_loop_qk = [&](size_t query_idx, size_t step_idx) HWY_ATTR { size_t loop_idx = query_idx / kNumQueriesPerLoop; - size_t step_pos = position + step_idx * step_size; - if (step_pos + step_size <= min_start_pos_per_group[loop_idx] || + size_t step_pos = position + step_idx * kBf16Lanes; + if (step_pos + kBf16Lanes <= min_start_pos_per_group[loop_idx] || step_pos > max_last_pos_per_group[loop_idx]) { float* softmax_buf_ptr = - softmax_buf.data() + query_idx * kBlockSize + step_idx * step_size; + softmax_buf.data() + query_idx * kBlockSize + step_idx * kBf16Lanes; for (int q = 0; q < kNumQueries; ++q) { - for (int t = 0; t < step_size; ++t) { + for (int t = 0; t < kBf16Lanes; ++t) { softmax_buf_ptr[q * kBlockSize + t] = kMaskedLogitVal; } } @@ -804,11 +1087,24 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( size_t s_idx = (position - current_kv_start_offset) / KVCache::kTileSize + tile; - const size_t current_pos = macro_tile_start_pos + step_idx * step_size; - const BF16* tile_base = - reinterpret_cast(kvs[current_kv_idx].RowBytes(s_idx)); + const size_t current_pos = macro_tile_start_pos + step_idx * kBf16Lanes; + const KV_T* tile_base = + HWY_RCAST_ALIGNED(const KV_T*, kvs[current_kv_idx].RowBytes(s_idx)); + + const Q_T* q_group = q_base + query_idx * qkv_dim; - const BF16* q_group = q_base + query_idx * qkv_dim; + VF C00 = hn::Zero(df), C01 = hn::Zero(df), C02 = hn::Zero(df), + C03 = hn::Zero(df); + VF C10 = hn::Zero(df), C11 = hn::Zero(df), C12 = hn::Zero(df), + C13 = hn::Zero(df); + VF C20 = hn::Zero(df), C21 = hn::Zero(df), C22 = hn::Zero(df), + C23 = hn::Zero(df); + VF C30 = hn::Zero(df), C31 = hn::Zero(df), C32 = hn::Zero(df), + C33 = hn::Zero(df); + + QDotKTilexUpTo8MatrixAccumulation( + df, q_group, tile_base, current_pos, qkv_dim, qkv_dim, C00, C01, C02, + C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33); VF x_0_p_0 = hn::Zero(df), x_0_p_1 = hn::Zero(df), x_1_p_0 = hn::Zero(df), x_1_p_1 = hn::Zero(df); @@ -819,11 +1115,6 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( VF x_6_p_0 = hn::Zero(df), x_6_p_1 = hn::Zero(df), x_7_p_0 = hn::Zero(df), x_7_p_1 = hn::Zero(df); - VF C00, C01, C02, C03, C10, C11, C12, C13, C20, C21, C22, C23, C30, C31, - C32, C33; - QDotKTilexUpTo8MatrixAccumulation( - df, q_group, tile_base, current_pos, qkv_dim, C00, C01, C02, C03, C10, - C11, C12, C13, C20, C21, C22, C23, C30, C31, C32, C33); auto pack_queries = [&](VF c_left, VF c_right, VF& x_even, VF& x_odd) HWY_ATTR { using D64 = hn::Repartition; @@ -838,14 +1129,62 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( x_even = hn::BitCast(df, even_scrambled); x_odd = hn::BitCast(df, odd_scrambled); }; - if constexpr (kNumQueries >= 1) pack_queries(C00, C01, x_0_p_0, x_1_p_0); - if constexpr (kNumQueries >= 1) pack_queries(C02, C03, x_0_p_1, x_1_p_1); - if constexpr (kNumQueries >= 3) pack_queries(C10, C11, x_2_p_0, x_3_p_0); - if constexpr (kNumQueries >= 3) pack_queries(C12, C13, x_2_p_1, x_3_p_1); - if constexpr (kNumQueries >= 5) pack_queries(C20, C21, x_4_p_0, x_5_p_0); - if constexpr (kNumQueries >= 5) pack_queries(C22, C23, x_4_p_1, x_5_p_1); - if constexpr (kNumQueries >= 7) pack_queries(C30, C31, x_6_p_0, x_7_p_0); - if constexpr (kNumQueries >= 7) pack_queries(C32, C33, x_6_p_1, x_7_p_1); + + if constexpr (kNumQueries >= 1) { + pack_queries(C00, C01, x_0_p_0, x_1_p_0); + pack_queries(C02, C03, x_0_p_1, x_1_p_1); + } + if constexpr (kNumQueries >= 3) { + pack_queries(C10, C11, x_2_p_0, x_3_p_0); + pack_queries(C12, C13, x_2_p_1, x_3_p_1); + } + if constexpr (kNumQueries >= 5) { + pack_queries(C20, C21, x_4_p_0, x_5_p_0); + pack_queries(C22, C23, x_4_p_1, x_5_p_1); + } + if constexpr (kNumQueries >= 7) { + pack_queries(C30, C31, x_6_p_0, x_7_p_0); + pack_queries(C32, C33, x_6_p_1, x_7_p_1); + } + + if constexpr (IsInt8()) { + const size_t token_in_tile = current_pos % 32; + const BF16* microscaling_scales_k = + HWY_RCAST_ALIGNED(const BF16*, + tile_base + 2 * qkv_dim * kTileSize) + + token_in_tile; + + const hn::Vec v_sk_bf16 = hn::LoadU(dbf, microscaling_scales_k); + const VF v_sk_f_lo = hn::PromoteLowerTo(df, v_sk_bf16); + const VF v_sk_f_hi = hn::PromoteUpperTo(df, v_sk_bf16); + + auto apply_scales = [&](size_t r, VF& p_even_lo, VF& p_odd_lo, + VF& p_even_hi, VF& p_odd_hi) HWY_ATTR { + float sq_ev = q_scales[query_idx + 2 * r]; + float sq_od = (2 * r + 1 < kNumQueries) + ? q_scales[query_idx + 2 * r + 1] + : 0.0f; + + const VF scale_ev_lo = hn::Mul(hn::Set(df, sq_ev), v_sk_f_lo); + const VF scale_od_lo = hn::Mul(hn::Set(df, sq_od), v_sk_f_lo); + const VF scale_ev_hi = hn::Mul(hn::Set(df, sq_ev), v_sk_f_hi); + const VF scale_od_hi = hn::Mul(hn::Set(df, sq_od), v_sk_f_hi); + + p_even_lo = hn::Mul(p_even_lo, scale_ev_lo); + p_odd_lo = hn::Mul(p_odd_lo, scale_od_lo); + p_even_hi = hn::Mul(p_even_hi, scale_ev_hi); + p_odd_hi = hn::Mul(p_odd_hi, scale_od_hi); + }; + + if constexpr (kNumQueries >= 1) + apply_scales(0, x_0_p_0, x_1_p_0, x_0_p_1, x_1_p_1); + if constexpr (kNumQueries >= 3) + apply_scales(1, x_2_p_0, x_3_p_0, x_2_p_1, x_3_p_1); + if constexpr (kNumQueries >= 5) + apply_scales(2, x_4_p_0, x_5_p_0, x_4_p_1, x_5_p_1); + if constexpr (kNumQueries >= 7) + apply_scales(3, x_6_p_0, x_7_p_0, x_6_p_1, x_7_p_1); + } constexpr int kFirstHalfAmountOfQueries = std::min(kNumQueries, 4); constexpr int kSecondHalfAmountOfQueries = @@ -860,7 +1199,7 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( } if (current_pos < max_start_pos_per_group[loop_idx] || - current_pos + step_size - 1 > min_last_pos_per_group[loop_idx]) { + current_pos + kBf16Lanes - 1 > min_last_pos_per_group[loop_idx]) { ApplyMasking( df, du, current_pos, start_pos_per_query.data() + query_idx, last_pos_per_query.data() + query_idx, x_0_p_0, x_0_p_1, x_1_p_0, @@ -869,7 +1208,7 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( } float* softmax_buf_ptr = - softmax_buf.data() + query_idx * kBlockSize + step_idx * step_size; + softmax_buf.data() + query_idx * kBlockSize + step_idx * kBf16Lanes; auto store_logits = [&](const VF& x_p0, const VF& x_p1, size_t q) HWY_ATTR { hn::StoreU(x_p0, df, softmax_buf_ptr + q * kBlockSize + 0); @@ -933,9 +1272,6 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( size_t q = q_base_idx + q_offset; if (position + actual_block_size <= start_pos_per_query[q] || position > last_pos_per_query[q]) { - // Skip update for completely masked query in this block. - // scales_old[q_offset] remains 1.0f, q_scales_new[q_offset] remains - // 0.0f. continue; } float* q_logits = softmax_buf.data() + q * kBlockSize; @@ -943,32 +1279,28 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( df, q_logits, actual_block_size, q, max_logits, exp_denominator_sums, q_offset, scales_old, q_scales_new); } + if constexpr (IsInt8()) { + QuantizeAndPackSoftmaxProbs(df, dbf, q_base_idx, actual_q_count, + actual_block_size, qkv_dim, q_scales_new, + softmax_buf.data(), group_infos, + q_weights_buf.data(), w_scales_buf.data()); + } auto call_sv_block = [&]() HWY_ATTR { - HWY_LANES_CONSTEXPR size_t step_size = hn::Lanes(dbf); - if constexpr (HWY_HAVE_CONSTEXPR_LANES) { - TileFlashAttentionSVBlock( - q_base_idx, position, current_kv_start_offset, current_kv_idx, - actual_steps, qkv_dim, scales_old, q_scales_new, - softmax_buf.data(), kvs, C_accumulators.data()); - } else { - if (step_size == 32) { - TileFlashAttentionSVBlock( - q_base_idx, position, current_kv_start_offset, current_kv_idx, - actual_steps, qkv_dim, scales_old, q_scales_new, - softmax_buf.data(), kvs, C_accumulators.data()); - } else if (step_size == 16) { - TileFlashAttentionSVBlock( - q_base_idx, position, current_kv_start_offset, current_kv_idx, - actual_steps, qkv_dim, scales_old, q_scales_new, - softmax_buf.data(), kvs, C_accumulators.data()); - } else { // step_size == 8 (guaranteed by top validation) - TileFlashAttentionSVBlock( - q_base_idx, position, current_kv_start_offset, current_kv_idx, - actual_steps, qkv_dim, scales_old, q_scales_new, - softmax_buf.data(), kvs, C_accumulators.data()); - } + const GroupInfo* gi_ptr = nullptr; + size_t n_groups = 0; + const int8_t* qw_ptr = nullptr; + const float* ws_ptr = nullptr; + if constexpr (IsInt8()) { + gi_ptr = group_infos; + n_groups = num_groups; + qw_ptr = q_weights_buf.data(); + ws_ptr = w_scales_buf.data(); } + TileFlashAttentionSVBlock( + q_base_idx, position, current_kv_start_offset, current_kv_idx, + actual_steps, qkv_dim, scales_old, q_scales_new, softmax_buf.data(), + kvs, C_accumulators.data(), gi_ptr, n_groups, qw_ptr, ws_ptr); }; if (actual_q_count >= 8) { @@ -1003,6 +1335,39 @@ HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro( } } +template +HWY_ATTR void TileFlashAttentionReturnExpSumsAndMaxLogitsBF16( + const hwy::Span> kvs, size_t q_count, + const Q_T* HWY_RESTRICT q_base, hwy::Span q_scales, + hwy::Span start_pos_per_query, + hwy::Span last_pos_per_query, float att_cap, + MatPtrT& att_out, float* HWY_RESTRICT exp_denominator_sums, + float* HWY_RESTRICT max_logits) { +#if HWY_HAVE_CONSTEXPR_LANES + constexpr size_t kRegBytes = hn::Lanes(hn::ScalableTag()); + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Impl( + kvs, q_count, q_base, q_scales, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); +#else + const size_t reg_bytes = hn::Lanes(hn::ScalableTag()); + if (reg_bytes == 64) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Impl<64, KV_T, Q_T>( + kvs, q_count, q_base, q_scales, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); + } else if (reg_bytes == 32) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Impl<32, KV_T, Q_T>( + kvs, q_count, q_base, q_scales, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); + } else if (reg_bytes == 16) { + TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Impl<16, KV_T, Q_T>( + kvs, q_count, q_base, q_scales, start_pos_per_query, last_pos_per_query, + att_cap, att_out, exp_denominator_sums, max_logits); + } else { + HWY_ABORT("Unsupported register size %zu bytes", reg_bytes); + } +#endif +} + } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 75092b63..36e2aeb6 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -257,9 +257,9 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { if (rel_abs_delta > 0.0f) { rel_abs_delta /= std::max(std::abs(a_row[c]), std::abs(b_row[c])); } - if (rel_abs_delta >= 1e-3) { + if (rel_abs_delta >= 1.5e-3) { if (failures < 5) { - EXPECT_LT(rel_abs_delta, 1e-3) + EXPECT_LT(rel_abs_delta, 1.5e-3) << "a[" << r << "," << c << "]=" << a_row[c] << ", b[" << r << "," << c << "]=" << b_row[c]; } @@ -574,6 +574,19 @@ void RunTiledFlashAttentionTest(gcpp::KVEncoding kv_encoding, hwy::Span(start_pos_per_query), hwy::Span(last_pos_per_query), att_cap, att_out, exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kInt8MatrixAccumulation) { + size_t num_queries_rounded = hwy::RoundUpTo(num_queries, 2); + hwy::AlignedVector int8_queries(num_queries_rounded * qkv_dim); + hwy::AlignedVector q_scales(num_queries_rounded); + CompressAndQuantizeQueriesMatrixAccumulationInt8( + q_all.data(), int8_queries.data(), q_scales.data(), num_queries, + qkv_dim); + + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( + kvs, num_queries, int8_queries.data(), q_scales, + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); } else if (attention_impl == AttentionImpl::kFlashMatrixAccumulation) { size_t num_queries_rounded = hwy::RoundUpTo(num_queries, 2); hwy::AlignedVector bf16_queries(num_queries_rounded * qkv_dim); @@ -688,7 +701,30 @@ void TestTiledFlashAttentionBF16MatrixAccumulation() { tol_exp, tol_max, 0.0f); } -void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { +void TestTiledFlashAttentionInt8MatrixAccumulation() { + const hn::ScalableTag dbf; + if (hn::Lanes(dbf) > 32) { + GTEST_SKIP() << "Skipping MatrixAccumulation test for target with register " + "size > 512-bit."; + return; + } + + // INT8 has slightly larger error due to quantization, so we use slightly + // larger tolerances + const float tol = 1.5e-1f; + const float tol_exp = 2e-1f; + const float tol_max = 5e-2f; + + RunTiledFlashAttentionTest(gcpp::KVEncoding::kInt8MatrixAccumulation, + AttentionImpl::kInt8MatrixAccumulation, + tol, tol_exp, tol_max, 0.0f); +} + +template +void RunTiledFlashAttentionDifferentialTest( + size_t kv_seq_len, float tol, float tol_exp, float tol_max, + gcpp::KVEncoding ref_encoding, gcpp::KVEncoding opt_encoding, + AttentionImpl opt_impl, const char* type_name) { const hn::ScalableTag dbf; if (hn::Lanes(dbf) > 32) { GTEST_SKIP() << "Skipping MatrixAccumulation test for target with register " @@ -697,7 +733,6 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { } size_t qkv_dim = 64; - size_t kv_seq_len = 2048; // number of tokens we will attend to. size_t padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); float att_cap = 0.0f; @@ -709,14 +744,15 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - // Set up reference BF16 - MatStorageT kv_ref( + // Set up reference cache + size_t ref_tile_size_bytes = *gcpp::GetTileSizeBytes(ref_encoding, qkv_dim); + size_t ref_tile_size_in_elements = ref_tile_size_bytes / sizeof(KV_T); + MatStorageT kv_ref( "kv_ref", Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, - 2 * qkv_dim * gcpp::KVCache::kTileSize), + ref_tile_size_in_elements), ctx.allocator, MatPadding::kPacked); - PopulateTestKVCache(kv_ref, gcpp::KVEncoding::kBF16TwoTranspositions, - qkv_dim); + PopulateTestKVCache(kv_ref, ref_encoding, qkv_dim); AlignedFloatVector q_all_ref = PopulateTestQueries(num_queries, qkv_dim); std::vector> bf16_queries_ref(num_queries * @@ -738,22 +774,35 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { max_logits_ref[i] = -std::numeric_limits::max() / 2.0f; } - // Set up Matrix Accumulation - size_t tile_size_bytes = *gcpp::GetTileSizeBytes( - gcpp::KVEncoding::kBF16MatrixAccumulation, qkv_dim); - size_t tile_size_in_elements = tile_size_bytes / sizeof(BF16); - MatStorageT kv("kv", - Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, - tile_size_in_elements), - ctx.allocator, MatPadding::kPacked); - PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16MatrixAccumulation, qkv_dim); + // Set up Optimized cache (Matrix Accumulation) + size_t opt_tile_size_bytes = *gcpp::GetTileSizeBytes(opt_encoding, qkv_dim); + size_t opt_tile_size_in_elements = opt_tile_size_bytes / sizeof(KV_T); + MatStorageT kv( + "kv", + Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, + opt_tile_size_in_elements), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv, opt_encoding, qkv_dim); AlignedFloatVector q_all = PopulateTestQueries(num_queries, qkv_dim); + + hwy::AlignedVector int8_queries; + std::vector> bf16_queries; + hwy::AlignedVector q_scales; + size_t num_queries_rounded = hwy::RoundUpTo(num_queries, 2); - std::vector> bf16_queries( - num_queries_rounded * qkv_dim); - CompressAndTransposeQueriesMatrixAccumulation( - q_all.data(), bf16_queries.data(), num_queries, qkv_dim); + + if (opt_impl == AttentionImpl::kInt8MatrixAccumulation) { + int8_queries.resize(num_queries_rounded * qkv_dim); + q_scales.resize(num_queries_rounded); + CompressAndQuantizeQueriesMatrixAccumulationInt8( + q_all.data(), int8_queries.data(), q_scales.data(), num_queries, + qkv_dim); + } else { + bf16_queries.resize(num_queries_rounded * qkv_dim); + CompressAndTransposeQueriesMatrixAccumulation( + q_all.data(), bf16_queries.data(), num_queries, qkv_dim); + } MatStorageT att_out("att_out", Extents2D(num_queries, qkv_dim), ctx.allocator, MatPadding::kPacked); @@ -782,6 +831,7 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { } } + // Run reference hwy::Span kvs_ref(&kv_ref, 1); DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsBF16( kvs_ref, num_queries, bf16_queries_ref.data(), @@ -789,22 +839,29 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { hwy::Span(last_pos_per_query), att_cap, att_out_ref, exp_denominator_sums_ref.data(), max_logits_ref.data()); + // Run optimized hwy::Span kvs(&kv, 1); - DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( - kvs, num_queries, bf16_queries.data(), - hwy::Span(start_pos_per_query), - hwy::Span(last_pos_per_query), att_cap, att_out, - exp_denominator_sums.data(), max_logits.data()); + if (opt_impl == AttentionImpl::kInt8MatrixAccumulation) { + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( + kvs, num_queries, int8_queries.data(), q_scales, + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + } else { + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation( + kvs, num_queries, bf16_queries.data(), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), att_cap, att_out, + exp_denominator_sums.data(), max_logits.data()); + } + // Verification float max_abs_err = 0.0f; float mse_sum = 0.0f; float dot_prod = 0.0f; float norm_ref = 0.0f; float norm_out = 0.0f; - const float tol_exp = 1e-1f; - const float tol_max = 1e-4f; - size_t failures = 0; for (size_t i = 0; i < num_queries; ++i) { float diff_exp = @@ -813,7 +870,7 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { if (failures < 5) { EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_ref[i], tol_exp) - << "i=" << i; + << "i=" << i << " (Type: " << type_name << ", SeqLen: " << kv_seq_len << ")"; } failures++; } @@ -821,7 +878,8 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { float diff_max = std::abs(max_logits[i] - max_logits_ref[i]); if (diff_max >= tol_max) { if (failures < 5) { - EXPECT_NEAR(max_logits[i], max_logits_ref[i], tol_max) << "i=" << i; + EXPECT_NEAR(max_logits[i], max_logits_ref[i], tol_max) + << "i=" << i << " (Type: " << type_name << ", SeqLen: " << kv_seq_len << ")"; } failures++; } @@ -835,9 +893,10 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { dot_prod += v_ref * v_out; norm_ref += v_ref * v_ref; norm_out += v_out * v_out; - if (diff >= 1e-1f) { + if (diff >= tol) { if (failures < 5) { - EXPECT_NEAR(v_out, v_ref, 1e-1f) << "i=" << i << " j=" << j; + EXPECT_NEAR(v_out, v_ref, tol) + << "i=" << i << " j=" << j << " (Type: " << type_name << ", SeqLen: " << kv_seq_len << ")"; } failures++; } @@ -845,13 +904,27 @@ void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { } float cosine_sim = dot_prod / (std::sqrt(norm_ref) * std::sqrt(norm_out)); float mse = mse_sum / (num_queries * qkv_dim); - std::cerr << "=== Numerical Verification Results (Q:32, KV:2048) ===\n" + std::cerr << "=== Numerical Verification Results (" << type_name + << ", SeqLen: " << kv_seq_len << ") ===\n" << " Cosine Similarity: " << cosine_sim << "\n" << " Max Absolute Error: " << max_abs_err << "\n" << " Mean Squared Error: " << mse << "\n"; - if (HWY_NATIVE_PER_BLOCK_2X2_MATMUL_BF16) { - std::cerr << " Using native PerBlock2x2MatMul\n"; - } +} + +void TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification() { + RunTiledFlashAttentionDifferentialTest( + 2048, 1.0e-1f, 1.1e-1f, 1e-4f, + gcpp::KVEncoding::kBF16TwoTranspositions, + gcpp::KVEncoding::kBF16MatrixAccumulation, + AttentionImpl::kFlashMatrixAccumulation, "BF16"); +} + +void TestTiledFlashAttentionInt8MatrixAccumulationLargeVerification() { + RunTiledFlashAttentionDifferentialTest( + 1024, 1.5e-1f, 5.0, 8.0e-2f, + gcpp::KVEncoding::kInt8TwoTranspositions, + gcpp::KVEncoding::kInt8MatrixAccumulation, + AttentionImpl::kInt8MatrixAccumulation, "Int8"); } // NOLINTNEXTLINE(google-readability-namespace-comments) @@ -873,9 +946,15 @@ HWY_EXPORT_AND_TEST_P(FlashAttentionTest, HWY_EXPORT_AND_TEST_P( FlashAttentionTest, TestTiledFlashAttentionBF16MatrixAccumulationLargeVerification); +HWY_EXPORT_AND_TEST_P( + FlashAttentionTest, + TestTiledFlashAttentionInt8MatrixAccumulationLargeVerification); + HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8BF16); HWY_EXPORT_AND_TEST_P(FlashAttentionTest, TestTiledFlashAttentionInt8Int16); +HWY_EXPORT_AND_TEST_P(FlashAttentionTest, + TestTiledFlashAttentionInt8MatrixAccumulation); HWY_AFTER_TEST(); diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index e0f9bdc8..5f3e20f8 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -80,7 +80,8 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, if (runtime_config.attention_impl == AttentionImpl::kFlashTransposedQs || runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsInt16 || runtime_config.attention_impl == AttentionImpl::kFlashTransposedQsBF16 || - runtime_config.attention_impl == AttentionImpl::kFlashMatrixAccumulation + runtime_config.attention_impl == AttentionImpl::kFlashMatrixAccumulation || + runtime_config.attention_impl == AttentionImpl::kInt8MatrixAccumulation ) { // clang-format on const size_t num_tiles = @@ -95,12 +96,14 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, ) { kv_cache_type = runtime_config.kv_cache_type.value_or(Type::kBF16); } else if (runtime_config.attention_impl == - AttentionImpl::kFlashTransposedQsInt16) { + AttentionImpl::kFlashTransposedQsInt16 || + runtime_config.attention_impl == + AttentionImpl::kInt8MatrixAccumulation) { if (runtime_config.kv_cache_type.has_value() && runtime_config.kv_cache_type.value() != Type::kInt8) { HWY_WARN( "You are have set kv_cache_type to %s, but you are using " - "FlashTransposedQsInt16 attention implementation which only " + "an attention implementation which only " "supports Int8. kv_cache_type will be set to Int8.", runtime_config.kv_cache_type.value()); } @@ -132,6 +135,9 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, if (runtime_config.attention_impl == AttentionImpl::kFlashMatrixAccumulation) { compact_kv_cache_ptr.SetLayout(MatPtr::Layout::kBF16MatrixAccumulation); + } else if (runtime_config.attention_impl == + AttentionImpl::kInt8MatrixAccumulation) { + compact_kv_cache_ptr.SetLayout(MatPtr::Layout::kInt8MatrixAccumulation); } compact_kv_cache.AllocateFor(compact_kv_cache_ptr, allocator, MatPadding::kPacked); @@ -150,6 +156,9 @@ KVCache::KVCache(const ModelConfig& config, const InferenceArgs& inference_args, if (runtime_config.attention_impl == AttentionImpl::kFlashMatrixAccumulation) { kv_ptr.SetLayout(MatPtr::Layout::kBF16MatrixAccumulation); + } else if (runtime_config.attention_impl == + AttentionImpl::kInt8MatrixAccumulation) { + kv_ptr.SetLayout(MatPtr::Layout::kInt8MatrixAccumulation); } kv_head_ptrs.emplace_back(std::move(kv_ptr)); total_num_tiles += num_tiles_per_kv_head; diff --git a/gemma/kv_transcoding.cc b/gemma/kv_transcoding.cc index da635e5d..865c5998 100644 --- a/gemma/kv_transcoding.cc +++ b/gemma/kv_transcoding.cc @@ -24,6 +24,7 @@ std::optional GetTileSizeBytes(gcpp::KVEncoding encoding, switch (encoding) { case gcpp::KVEncoding::kInt8: case gcpp::KVEncoding::kInt8TwoTranspositions: + case gcpp::KVEncoding::kInt8MatrixAccumulation: return qkv_dim * kTileSize * 2 * sizeof(int8_t) + kTileSize * 2 * sizeof(gcpp::KV_microscale_t); case gcpp::KVEncoding::kBF16: @@ -169,6 +170,58 @@ void EncodeTileBF16MatrixAccumulation(size_t qkv_dim, } } +void EncodeTileInt8MatrixAccumulation(size_t qkv_dim, + const DecodedTile& decoded, + hwy::Span out_encoded_tile_data) { + HWY_DASSERT(qkv_dim % 8 == 0); + int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data()); + int8_t* v_data = k_data + qkv_dim * kTileSize; + gcpp::KV_microscale_t* scales = + HWY_RCAST_ALIGNED(gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim); + gcpp::KV_microscale_t* k_scales = scales; + gcpp::KV_microscale_t* v_scales = scales + kTileSize; + + AlignedFloatVector k_max_abs(kTileSize, 0.0f); + AlignedFloatVector v_max_abs(kTileSize, 0.0f); + + for (size_t token = 0; token < kTileSize; ++token) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + k_max_abs[token] = + std::max(k_max_abs[token], std::abs(decoded.k_elem(token, dim))); + v_max_abs[token] = + std::max(v_max_abs[token], std::abs(decoded.v_elem(token, dim))); + } + } + + AlignedFloatVector inv_scales_k(kTileSize); + AlignedFloatVector inv_scales_v(kTileSize); + for (size_t token = 0; token < kTileSize; ++token) { + float scale_k = k_max_abs[token] == 0.0f ? 1.0f : k_max_abs[token] / 127.0f; + k_scales[token] = hwy::ConvertScalarTo(scale_k); + float decoded_k = hwy::ConvertScalarTo(k_scales[token]); + inv_scales_k[token] = decoded_k == 0.0f ? 0.0f : 1.0f / decoded_k; + + float scale_v = v_max_abs[token] == 0.0f ? 1.0f : v_max_abs[token] / 127.0f; + v_scales[token] = hwy::ConvertScalarTo(scale_v); + float decoded_v = hwy::ConvertScalarTo(v_scales[token]); + inv_scales_v[token] = decoded_v == 0.0f ? 0.0f : 1.0f / decoded_v; + } + + // 2. Quantize and pack K and V + for (size_t token = 0; token < kTileSize; ++token) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + size_t k_offset = MatrixAccumulationOffset_Int8(qkv_dim, dim, token); + k_data[k_offset] = + Quantize(decoded.k_elem(token, dim), inv_scales_k[token]); + + // V transposed layout (channel-major within 2-token blocks) + size_t v_offset = VMatrixAccumulationOffset_Int8(qkv_dim, token, dim); + v_data[v_offset] = + Quantize(decoded.v_elem(token, dim), inv_scales_v[token]); + } + } +} + void EncodeTileInt8(bool transposed, size_t qkv_dim, const DecodedTile& decoded, hwy::Span out_encoded_tile_data) { int8_t* k_data = HWY_RCAST_ALIGNED(int8_t*, out_encoded_tile_data.data()); @@ -292,6 +345,33 @@ void DecodeTileBF16MatrixAccumulation(size_t qkv_dim, } } +void DecodeTileInt8MatrixAccumulation(size_t qkv_dim, + hwy::Span encoded_tile_data, + DecodedTile* out) { + HWY_DASSERT(qkv_dim % 8 == 0); + const int8_t* k_data = + HWY_RCAST_ALIGNED(const int8_t*, encoded_tile_data.data()); + const int8_t* v_data = k_data + qkv_dim * kTileSize; + const gcpp::KV_microscale_t* scales = HWY_RCAST_ALIGNED( + const gcpp::KV_microscale_t*, v_data + kTileSize * qkv_dim); + const gcpp::KV_microscale_t* k_scales = scales; + const gcpp::KV_microscale_t* v_scales = scales + kTileSize; + + for (size_t token = 0; token < kTileSize; ++token) { + float scale_k = hwy::ConvertScalarTo(k_scales[token]); + float scale_v = hwy::ConvertScalarTo(v_scales[token]); + + for (size_t dim = 0; dim < qkv_dim; ++dim) { + size_t k_offset = MatrixAccumulationOffset_Int8(qkv_dim, dim, token); + out->k_elem(token, dim) = k_data[k_offset] * scale_k; + + // V transposed layout (channel-major within 2-token blocks) + size_t v_offset = VMatrixAccumulationOffset_Int8(qkv_dim, token, dim); + out->v_elem(token, dim) = v_data[v_offset] * scale_v; + } + } +} + void DecodeTileInt8(bool transposed, size_t qkv_dim, hwy::Span encoded_tile_data, DecodedTile* out) { const int8_t* k_data = @@ -363,6 +443,10 @@ bool DecodeTile(KVEncoding encoding, hwy::Span encoded_tile_data, DecodeTileInt8(transposed, qkv_dim, encoded_tile_data, out); return true; } + case gcpp::KVEncoding::kInt8MatrixAccumulation: { + DecodeTileInt8MatrixAccumulation(qkv_dim, encoded_tile_data, out); + return true; + } default: return false; } @@ -397,6 +481,10 @@ bool EncodeTile(gcpp::KVEncoding encoding, const DecodedTile& decoded, EncodeTileInt8(transposed, qkv_dim, decoded, out_encoded_tile_data); return true; } + case gcpp::KVEncoding::kInt8MatrixAccumulation: { + EncodeTileInt8MatrixAccumulation(qkv_dim, decoded, out_encoded_tile_data); + return true; + } default: return false; } diff --git a/gemma/kv_transcoding.h b/gemma/kv_transcoding.h index 7abe4065..4fef0018 100644 --- a/gemma/kv_transcoding.h +++ b/gemma/kv_transcoding.h @@ -90,6 +90,27 @@ inline size_t VMatrixAccumulationOffset_BF16(size_t qkv_dim, size_t token, return sub_block * 32 + block_offset; } +inline size_t MatrixAccumulationOffset_Int8(size_t qkv_dim, size_t dim, size_t token) { + const size_t tg8 = token / 8; + const size_t t_in_g = token % 8; + const size_t ch_g = dim / 8; + const size_t ch_in_g = dim % 8; + + const size_t block_start = (ch_g * 4 + tg8) * 64; + return block_start + t_in_g * 8 + ch_in_g; +} + +inline size_t VMatrixAccumulationOffset_Int8(size_t qkv_dim, size_t token, + size_t dim) { + const size_t tg8 = token / 8; + const size_t t_in_g = token % 8; + const size_t ch_g = dim / 8; + const size_t ch_in_g = dim % 8; + + const size_t block_start = (ch_g * 4 + tg8) * 64; + return block_start + ch_in_g * 8 + t_in_g; +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_KV_TRANSCODING_H_ diff --git a/gemma/kv_transcoding_test.cc b/gemma/kv_transcoding_test.cc index 72f114e8..c5c93cd2 100644 --- a/gemma/kv_transcoding_test.cc +++ b/gemma/kv_transcoding_test.cc @@ -1,5 +1,6 @@ #include "gemma/kv_transcoding.h" +#include #include #include #include @@ -28,6 +29,13 @@ struct EncodingTestCase { class KVEncodingTest : public TestWithParam {}; +int8_t Quantize(float v, float inv_scale) { + float scaled = std::nearbyint(v * inv_scale); + if (scaled > 127.0f) return 127; + if (scaled < -127.0f) return -127; + return hwy::ConvertScalarTo(scaled); +} + TEST_P(KVEncodingTest, EncodeDecodeRoundTrip) { const auto& param = GetParam(); constexpr size_t kTileSize = 32; @@ -47,9 +55,9 @@ TEST_P(KVEncodingTest, EncodeDecodeRoundTrip) { std::optional tile_size_bytes = GetTileSizeBytes(param.encoding, qkv_dim); - ASSERT_TRUE(tile_size_bytes.has_value()); + HWY_ASSERT(tile_size_bytes.has_value()); - std::vector encoded(*tile_size_bytes, 0); + hwy::AlignedVector encoded(*tile_size_bytes, 0); EXPECT_TRUE(EncodeTile(param.encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); @@ -70,11 +78,11 @@ TEST_P(KVEncodingTest, SizeChecks) { DecodedTile decoded(qkv_dim, kTileSize); std::optional required_size_or = GetTileSizeBytes(param.encoding, qkv_dim); - ASSERT_TRUE(required_size_or.has_value()); + HWY_ASSERT(required_size_or.has_value()); size_t required_size = *required_size_or; if (required_size > 0) { - std::vector too_small_encoded(required_size - 1, 0); + hwy::AlignedVector too_small_encoded(required_size - 1, 0); EXPECT_FALSE(EncodeTile( param.encoding, decoded, qkv_dim, hwy::Span(too_small_encoded.data(), too_small_encoded.size()))); @@ -93,7 +101,8 @@ INSTANTIATE_TEST_SUITE_P( EncodingTestCase{gcpp::KVEncoding::kBF16TwoTranspositions, 0.05f}, EncodingTestCase{gcpp::KVEncoding::kBF16MatrixAccumulation, 0.05f}, EncodingTestCase{gcpp::KVEncoding::kInt8, 0.1f}, - EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f})); + EncodingTestCase{gcpp::KVEncoding::kInt8TwoTranspositions, 0.1f}, + EncodingTestCase{gcpp::KVEncoding::kInt8MatrixAccumulation, 0.02f})); TEST(KVEncodingTest, ConvertTileFloat32ToBfloat16) { constexpr size_t kTileSize = 32; @@ -113,8 +122,8 @@ TEST(KVEncodingTest, ConvertTileFloat32ToBfloat16) { size_t src_size = GetTileSizeBytes(src_encoding, qkv_dim).value(); size_t dst_size = GetTileSizeBytes(dst_encoding, qkv_dim).value(); - std::vector src_data(src_size); - std::vector dst_data(dst_size); + hwy::AlignedVector src_data(src_size); + hwy::AlignedVector dst_data(dst_size); EXPECT_TRUE(EncodeTile(src_encoding, original, qkv_dim, hwy::Span(src_data.data(), src_data.size()))); @@ -158,8 +167,8 @@ TEST(KVEncodingTest, PairwiseConversion) { size_t src_size = GetTileSizeBytes(src, qkv_dim).value(); size_t dst_size = GetTileSizeBytes(dst, qkv_dim).value(); - std::vector src_data(src_size); - std::vector dst_data(dst_size); + hwy::AlignedVector src_data(src_size); + hwy::AlignedVector dst_data(dst_size); ASSERT_TRUE(EncodeTile(src, original, qkv_dim, hwy::Span(src_data.data(), src_data.size()))) @@ -207,12 +216,12 @@ TEST(KVEncodingTest, LayoutValidationF32) { } size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); - std::vector encoded(size); + hwy::AlignedVector encoded(size); ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); - const float* data = reinterpret_cast(encoded.data()); + const float* data = HWY_RCAST_ALIGNED(const float*, encoded.data()); // K should be row-major [qkv_dim, tile_size] EXPECT_EQ(data[0], 1.0f); // d=0, t=0 @@ -245,12 +254,12 @@ TEST(KVEncodingTest, LayoutValidationF32TwoTranspositions) { } size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); - std::vector encoded(size); + hwy::AlignedVector encoded(size); ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); - const float* data = reinterpret_cast(encoded.data()); + const float* data = HWY_RCAST_ALIGNED(const float*, encoded.data()); // K transposed: [qkv_dim/2, tile_size, 2] EXPECT_EQ(data[0], 1.0f); // d=0, t=0 @@ -287,12 +296,12 @@ TEST(KVEncodingTest, LayoutValidationInt8) { } size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); - std::vector encoded(size); + hwy::AlignedVector encoded(size); ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); - const int8_t* data = reinterpret_cast(encoded.data()); + const int8_t* data = HWY_RCAST_ALIGNED(const int8_t*, encoded.data()); // K should be row-major [qkv_dim, tile_size] // K[3,0] = 97. Max for t=0 is 97. Scale = 97/127. @@ -327,12 +336,12 @@ TEST(KVEncodingTest, LayoutValidationInt8TwoTranspositions) { } size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); - std::vector encoded(size); + hwy::AlignedVector encoded(size); ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); - const int8_t* data = reinterpret_cast(encoded.data()); + const int8_t* data = HWY_RCAST_ALIGNED(const int8_t*, encoded.data()); // K transposed: [qkv_dim/2, tile_size, 2] // K[0,0] = 1. Max for t=0 is 97. Scale = 97/127. @@ -373,12 +382,12 @@ TEST(KVEncodingTest, LayoutValidationBF16MatrixAccumulation) { } size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); - std::vector encoded(size); + hwy::AlignedVector encoded(size); ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, hwy::Span(encoded.data(), encoded.size()))); - const gcpp::BF16* data = reinterpret_cast(encoded.data()); + const gcpp::BF16* data = HWY_RCAST_ALIGNED(const gcpp::BF16*, encoded.data()); // K Layout (8x4 block, token-major) // base_offset = ch_g * 128 + g * 32. @@ -428,5 +437,100 @@ TEST(KVEncodingTest, LayoutValidationBF16MatrixAccumulation) { EXPECT_NEAR(hwy::ConvertScalarTo(data[v_start + 47]), 160.0f, 0.05f); } +TEST(KVEncodingTest, LayoutValidationInt8MatrixAccumulation) { + constexpr size_t kTileSize = 32; + constexpr size_t qkv_dim = 16; + gcpp::KVEncoding encoding = gcpp::KVEncoding::kInt8MatrixAccumulation; + + DecodedTile original(qkv_dim, kTileSize); + for (size_t token = 0; token < kTileSize; ++token) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + original.k_elem(token, dim) = (dim + 1) * (token + 1) * 0.1f; + original.v_elem(token, dim) = (dim + 1) * (token + 1) * 0.2f; + } + } + + size_t size = GetTileSizeBytes(encoding, qkv_dim).value(); + hwy::AlignedVector encoded(size); + + ASSERT_TRUE(EncodeTile(encoding, original, qkv_dim, + hwy::Span(encoded.data(), encoded.size()))); + + const int8_t* k_data = HWY_RCAST_ALIGNED(const int8_t*, encoded.data()); + const int8_t* v_data = k_data + qkv_dim * kTileSize; + const gcpp::BF16* scales = + HWY_RCAST_ALIGNED(const gcpp::BF16*, v_data + kTileSize * qkv_dim); + const gcpp::BF16* k_scales = scales; + const gcpp::BF16* v_scales = scales + kTileSize; + + // 1. Verify quantized values and layout offsets + for (size_t token = 0; token < kTileSize; ++token) { + // Compute expected scale for K (across all channels) + float max_abs_k = 0.0f; + for (size_t dim = 0; dim < qkv_dim; ++dim) { + max_abs_k = std::max(max_abs_k, std::abs(original.k_elem(token, dim))); + } + float scale_k_raw = max_abs_k == 0.0f ? 1.0f : max_abs_k / 127.0f; + gcpp::BF16 scale_k_bf16 = hwy::ConvertScalarTo(scale_k_raw); + float scale_k = hwy::ConvertScalarTo(scale_k_bf16); + float inv_scale_k = scale_k == 0.0f ? 0.0f : 1.0f / scale_k; + + // Compute expected scale for V (across all channels) + float max_abs_v = 0.0f; + for (size_t dim = 0; dim < qkv_dim; ++dim) { + max_abs_v = std::max(max_abs_v, std::abs(original.v_elem(token, dim))); + } + float scale_v_raw = max_abs_v == 0.0f ? 1.0f : max_abs_v / 127.0f; + gcpp::BF16 scale_v_bf16 = hwy::ConvertScalarTo(scale_v_raw); + float scale_v = hwy::ConvertScalarTo(scale_v_bf16); + float inv_scale_v = scale_v == 0.0f ? 0.0f : 1.0f / scale_v; + + // Verify scale storage (flat token-major) + EXPECT_NEAR(hwy::ConvertScalarTo(k_scales[token]), scale_k, 1e-5f) + << "K scale mismatch at token=" << token; + EXPECT_NEAR(hwy::ConvertScalarTo(v_scales[token]), scale_v, 1e-5f) + << "V scale mismatch at token=" << token; + + for (size_t dim = 0; dim < qkv_dim; ++dim) { + size_t expected_k_offset = + MatrixAccumulationOffset_Int8(qkv_dim, dim, token); + size_t expected_v_offset = + VMatrixAccumulationOffset_Int8(qkv_dim, token, dim); + + int8_t expected_k = Quantize(original.k_elem(token, dim), inv_scale_k); + int8_t expected_v = Quantize(original.v_elem(token, dim), inv_scale_v); + + EXPECT_EQ(k_data[expected_k_offset], expected_k) + << "K quantized value mismatch at token=" << token << ", dim=" << dim + << ", expected_k_offset=" << expected_k_offset; + EXPECT_EQ(v_data[expected_v_offset], expected_v) + << "V quantized value mismatch at token=" << token << ", dim=" << dim + << ", expected_v_offset=" << expected_v_offset; + } + } + + // 2. Verify round-trip decoding + DecodedTile decoded(qkv_dim, kTileSize); + ASSERT_TRUE(DecodeTile(encoding, + hwy::Span(encoded.data(), encoded.size()), + qkv_dim, &decoded)); + + for (size_t token = 0; token < kTileSize; ++token) { + float scale_k = hwy::ConvertScalarTo(k_scales[token]); + float scale_v = hwy::ConvertScalarTo(v_scales[token]); + + // Max absolute quantization error is scale * 0.5 (plus epsilon for float + // precision) + for (size_t dim = 0; dim < qkv_dim; ++dim) { + EXPECT_NEAR(decoded.k_elem(token, dim), original.k_elem(token, dim), + scale_k * 0.501f) + << "Decoded K mismatch at token=" << token << ", dim=" << dim; + EXPECT_NEAR(decoded.v_elem(token, dim), original.v_elem(token, dim), + scale_v * 0.501f) + << "Decoded V mismatch at token=" << token << ", dim=" << dim; + } + } +} + } // namespace } // namespace gcpp diff --git a/gemma/tiled_attention.cc b/gemma/tiled_attention.cc index d3b39174..59fab29e 100644 --- a/gemma/tiled_attention.cc +++ b/gemma/tiled_attention.cc @@ -421,8 +421,8 @@ HWY_INLINE void CompressAndTransposeQueriesMatrixAccumulationImpl( HWY_DASSERT(qkv_dim % 4 == 0); namespace hn = hwy::HWY_NAMESPACE; - const hn::FixedTag df; - const hn::FixedTag dbf16; + const hn::Full128 df; + const hn::Full128 dbf16; constexpr size_t kL = 4; size_t p = 0; @@ -468,6 +468,121 @@ void CompressAndTransposeQueriesMatrixAccumulationNonContiguous( qkv_dim); } +template +HWY_INLINE void CompressAndQuantizeQueriesMatrixAccumulationInt8Impl( + QueryProvider query_provider, int8_t* HWY_RESTRICT packed_queries, + float* HWY_RESTRICT packed_scales, size_t num_queries, size_t qkv_dim) { + HWY_DASSERT(qkv_dim % 8 == 0); + + namespace hn = hwy::HWY_NAMESPACE; + const hn::Full128 df; + const hn::Full128 di16; + const hn::Full128 di8; + + using V_F32 = hn::Vec; + using V_I32 = hn::Vec>; + using V_I16 = hn::Vec; + using V_I8 = hn::Vec; + + size_t p = 0; + for (; p < num_queries / 2; ++p) { + const float* q0 = query_provider(2 * p); + const float* q1 = query_provider(2 * p + 1); + int8_t* out = packed_queries + 2 * p * qkv_dim; + float* out_scale0 = packed_scales + 2 * p; + float* out_scale1 = packed_scales + (2 * p + 1); + + // 1. Compute single scale per query over the entire qkv_dim + float max_abs_q0 = AbsMaxOfSpan(hwy::Span(q0, qkv_dim)); + float max_abs_q1 = AbsMaxOfSpan(hwy::Span(q1, qkv_dim)); + + float scale0_raw = max_abs_q0 == 0.0f ? 1.0f : max_abs_q0 / 127.0f; + float scale1_raw = max_abs_q1 == 0.0f ? 1.0f : max_abs_q1 / 127.0f; + + gcpp::KV_microscale_t scale0_bf16 = + hwy::ConvertScalarTo(scale0_raw); + gcpp::KV_microscale_t scale1_bf16 = + hwy::ConvertScalarTo(scale1_raw); + + float scale0 = hwy::ConvertScalarTo(scale0_bf16); + float scale1 = hwy::ConvertScalarTo(scale1_bf16); + + *out_scale0 = scale0; + *out_scale1 = scale1; + + V_F32 inv_scale0 = hn::Set(df, 1.0f / scale0); + V_F32 inv_scale1 = hn::Set(df, 1.0f / scale1); + + for (size_t d = 0; d < qkv_dim; d += 8) { + // 2. Load and quantize Q0 (8 channels) + V_F32 q0_L = hn::LoadU(df, q0 + d); + V_F32 q0_H = hn::LoadU(df, q0 + d + 4); + V_I32 q0_L_scaled = hn::NearestInt(hn::Mul(q0_L, inv_scale0)); + V_I32 q0_H_scaled = hn::NearestInt(hn::Mul(q0_H, inv_scale0)); + V_I16 q0_i16 = hn::OrderedDemote2To(di16, q0_L_scaled, q0_H_scaled); + + // 3. Load and quantize Q1 (8 channels) + V_F32 q1_L = hn::LoadU(df, q1 + d); + V_F32 q1_H = hn::LoadU(df, q1 + d + 4); + V_I32 q1_L_scaled = hn::NearestInt(hn::Mul(q1_L, inv_scale1)); + V_I32 q1_H_scaled = hn::NearestInt(hn::Mul(q1_H, inv_scale1)); + V_I16 q1_i16 = hn::OrderedDemote2To(di16, q1_L_scaled, q1_H_scaled); + + // 4. Pack in pairs at 128-bit boundary: 8 elements of Q0, then 8 elements + // of Q1 + V_I8 packed = hn::OrderedDemote2To(di8, q0_i16, q1_i16); + hn::StoreU(packed, di8, out + d * 2); + } + } + + if (num_queries % 2 != 0) { + const float* q0 = query_provider(2 * p); + int8_t* out = packed_queries + 2 * p * qkv_dim; + float* out_scale0 = packed_scales + 2 * p; + V_I16 zero_i16 = hn::Zero(di16); + + float max_abs_q0 = AbsMaxOfSpan(hwy::Span(q0, qkv_dim)); + + float scale0_raw = max_abs_q0 == 0.0f ? 1.0f : max_abs_q0 / 127.0f; + gcpp::KV_microscale_t scale0_bf16 = + hwy::ConvertScalarTo(scale0_raw); + float scale0 = hwy::ConvertScalarTo(scale0_bf16); + + *out_scale0 = scale0; + + V_F32 inv_scale0 = hn::Set(df, 1.0f / scale0); + + for (size_t d = 0; d < qkv_dim; d += 8) { + V_F32 q0_L = hn::LoadU(df, q0 + d); + V_F32 q0_H = hn::LoadU(df, q0 + d + 4); + V_I32 q0_L_scaled = hn::NearestInt(hn::Mul(q0_L, inv_scale0)); + V_I32 q0_H_scaled = hn::NearestInt(hn::Mul(q0_H, inv_scale0)); + V_I16 q0_i16 = hn::OrderedDemote2To(di16, q0_L_scaled, q0_H_scaled); + + V_I8 packed = hn::OrderedDemote2To(di8, q0_i16, zero_i16); + hn::StoreU(packed, di8, out + d * 2); + } + } +} + +void CompressAndQuantizeQueriesMatrixAccumulationInt8(const float* raw_queries, + int8_t* packed_queries, + float* packed_scales, + size_t num_queries, + size_t qkv_dim) { + CompressAndQuantizeQueriesMatrixAccumulationInt8Impl( + [&](size_t idx) { return raw_queries + idx * qkv_dim; }, packed_queries, + packed_scales, num_queries, qkv_dim); +} + +void CompressAndQuantizeQueriesMatrixAccumulationInt8NonContiguous( + hwy::Span input, int8_t* packed_queries, + float* packed_scales, size_t qkv_dim) { + CompressAndQuantizeQueriesMatrixAccumulationInt8Impl( + [&](size_t idx) { return input[idx]; }, packed_queries, packed_scales, + input.size(), qkv_dim); +} + // clang-format off // Schedules TiledFlashAttention for all heads, tokens and batch. // Returns partial results in the same order as queries in `activations.q`. @@ -539,6 +654,11 @@ void LocalAttentionForAllHeadsTokensAndBatch( activations.int16_queries->size()) { activations.int16_queries->resize(num_sub_tasks * num_queries * qkv_dim); } + if (activations.int8_queries != nullptr && + num_sub_tasks * num_queries * qkv_dim > + activations.int8_queries->size()) { + activations.int8_queries->resize(num_sub_tasks * num_queries * qkv_dim); + } if (activations.float_queries != nullptr && num_sub_tasks * num_queries * qkv_dim > activations.float_queries->size()) { @@ -714,6 +834,24 @@ void LocalAttentionForAllHeadsTokensAndBatch( hwy::Span(last_pos_per_query), activations.config.att_cap, att_out, exp_denominator_sums.data(), max_logits.data()); + } else if (attention_impl == AttentionImpl::kInt8MatrixAccumulation) { + HWY_DASSERT(activations.int8_queries != nullptr); + HWY_DASSERT(activations.q_scales != nullptr); + int8_t* int8_queries_ptr = activations.int8_queries->data() + + task_idx * num_queries * qkv_dim; + float* q_scales_ptr = + activations.q_scales->data() + task_idx * num_queries; + + CompressAndQuantizeQueriesMatrixAccumulationInt8NonContiguous( + queries_ptrs_span, int8_queries_ptr, q_scales_ptr, qkv_dim); + + DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( + kv_ptrs, num_queries, int8_queries_ptr, + hwy::Span(q_scales_ptr, num_queries), + hwy::Span(start_pos_per_query), + hwy::Span(last_pos_per_query), + activations.config.att_cap, att_out, exp_denominator_sums.data(), + max_logits.data()); } else { HWY_DASSERT(activations.float_queries != nullptr); float* contiguous_queries_ptr = activations.float_queries->data() + diff --git a/gemma/tiled_attention.h b/gemma/tiled_attention.h index e8524a72..5a08efba 100644 --- a/gemma/tiled_attention.h +++ b/gemma/tiled_attention.h @@ -47,6 +47,9 @@ namespace gcpp { BF16* packed_queries, \ size_t num_queries, \ size_t qkv_dim); \ + void CompressAndQuantizeQueriesMatrixAccumulationInt8( \ + const float* raw_queries, int8_t* packed_queries, float* packed_scales, \ + size_t num_queries, size_t qkv_dim); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 753570ee..2898f528 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -1794,16 +1794,21 @@ HWY_API VI32 PerBlock2x2MatMulMaybeEmulate(DN dn, VI8 a, VI8 b, VI32 c) { #if HWY_NATIVE_PER_BLOCK_2X2_MATMUL_INT8 return hn::PerBlock2x2MatMul(dn, a, b, c); #else - const hn::Repartition di8; + using DA = hn::DFromV; + const DA da; + constexpr size_t kMaxA = hn::MaxLanes(da); constexpr size_t kMaxN = hn::MaxLanes(dn); - HWY_LANES_CONSTEXPR size_t N = hn::Lanes(dn); - HWY_ALIGN int8_t in_a[kMaxN * 4]; - HWY_ALIGN int8_t in_b[kMaxN * 4]; + constexpr size_t kBufSize = (kMaxA > kMaxN * 4) ? kMaxA : (kMaxN * 4); + using T_IN = hn::TFromD; + HWY_ALIGN T_IN in_a[kBufSize]; + HWY_ALIGN T_IN in_b[kBufSize]; HWY_ALIGN int32_t expected[kMaxN]; - hn::Store(a, di8, in_a); - hn::Store(b, di8, in_b); + hn::Store(a, da, in_a); + hn::Store(b, da, in_b); hn::Store(c, dn, expected); + HWY_LANES_CONSTEXPR size_t N = hn::Lanes(dn); + for (size_t block = 0; block < N; block += 4) { const size_t block_i8 = block * 4; for (int i = 0; i < 2; ++i) { diff --git a/util/mat.h b/util/mat.h index b9bfb72c..9f022f05 100644 --- a/util/mat.h +++ b/util/mat.h @@ -70,6 +70,7 @@ class MatPtr : public IFields { enum class Layout { kFlat, kBF16MatrixAccumulation, + kInt8MatrixAccumulation, }; Layout GetLayout() const { return layout_; }