Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ cc_library(
":configs",
":flash_structs",
":kv_cache",
":kv_transcoding",
":mat",
":matmul",
":matmul_env",
Expand Down
4 changes: 4 additions & 0 deletions gemma/activations.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ struct AttentionActivations {
size_t qkv_dim;
AlignedBF16Vector bf16_queries;
std::vector<int16_t, hwy::AlignedAllocator<int16_t>> int16_queries;
hwy::AlignedVector<int8_t> int8_queries;
AlignedFloatVector float_queries;
AlignedFloatVector q_scales;

Expand Down Expand Up @@ -202,6 +203,7 @@ struct AttentionActivationsPtrs {
split_flash_params(split_flash_params),
bf16_queries(nullptr),
int16_queries(nullptr),
int8_queries(nullptr),
float_queries(nullptr),
q_scales(nullptr),
div_seq_len(static_cast<uint32_t>(seq_len)),
Expand All @@ -227,6 +229,7 @@ struct AttentionActivationsPtrs {
inv_timescale_global = activations.inv_timescale_global;
bf16_queries = &activations.bf16_queries;
int16_queries = &activations.int16_queries;
int8_queries = &activations.int8_queries;
float_queries = &activations.float_queries;
q_scales = &activations.q_scales;
}
Expand Down Expand Up @@ -296,6 +299,7 @@ struct AttentionActivationsPtrs {
sub_task_max_logits;
AlignedBF16Vector* bf16_queries;
std::vector<int16_t, hwy::AlignedAllocator<int16_t>>* int16_queries;
hwy::AlignedVector<int8_t>* int8_queries;
AlignedFloatVector* float_queries;
AlignedFloatVector* q_scales;
// Inverse timescales for RoPE computation.
Expand Down
3 changes: 3 additions & 0 deletions gemma/configs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ constexpr std::pair<const char*, AttentionImpl> kAttentionImplNameToEnum[] = {
{"flash_transposed_qs_bf16", AttentionImpl::kFlashTransposedQsBF16},
{"flash_transposed_qs_int16", AttentionImpl::kFlashTransposedQsInt16},
{"flash_matrix_accumulation", AttentionImpl::kFlashMatrixAccumulation},
{"int8_matrix_accumulation", AttentionImpl::kInt8MatrixAccumulation},
};

std::string GetAttentionImplName(AttentionImpl impl) {
Expand Down Expand Up @@ -771,6 +772,8 @@ std::string KVEncodingToString(KVEncoding encoding) {
return "Int8TwoTranspositions";
case KVEncoding::kBF16MatrixAccumulation:
return "BF16MatrixAccumulation";
case KVEncoding::kInt8MatrixAccumulation:
return "Int8MatrixAccumulation";
default:
return "Unknown";
}
Expand Down
2 changes: 2 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ enum class KVEncoding {
kInt8 = 5,
kInt8TwoTranspositions = 6,
kBF16MatrixAccumulation = 7,
kInt8MatrixAccumulation = 8,
};

// Returns a string representation of the KVEncoding.
Expand All @@ -106,6 +107,7 @@ enum class AttentionImpl {
kFlashTransposedQsBF16,
kFlashTransposedQsInt16,
kFlashMatrixAccumulation,
kInt8MatrixAccumulation,
kSentinel,
};

Expand Down
47 changes: 36 additions & 11 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ HWY_INLINE void QDotKTile148BF16NotNative(
const size_t kNF = hn::Lanes(df);
const float* HWY_RESTRICT q_base[kVTileSize];
for (size_t i = 0; i < kVTileSize; ++i) {
q_base[i] = reinterpret_cast<const float*>(q + q_offsets[i]);
q_base[i] = HWY_RCAST_ALIGNED(const float*, q + q_offsets[i]);
}
const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF));
for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) {
Expand Down Expand Up @@ -433,7 +433,7 @@ HWY_INLINE void QDotKTile148BF16Native(
const size_t kNF = hn::Lanes(df);
const float* HWY_RESTRICT q_base[kVTileSize];
for (size_t i = 0; i < kVTileSize; ++i) {
q_base[i] = reinterpret_cast<const float*>(q + q_offsets[i]);
q_base[i] = HWY_RCAST_ALIGNED(const float*, q + q_offsets[i]);
}
const BF16* HWY_RESTRICT k_base = k.Row(pos / (2 * kNF));
for (size_t i = 0; i < half_cols; ++i, k_base += kNF * 4) {
Expand Down Expand Up @@ -1356,9 +1356,9 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
// tile base can point to same tile as previous loop iteration, hence no
// HWY_RESTRICT
// KVs are unaligned and we only use unaligned loads in this implementation.
const KV_T* tile_base =
reinterpret_cast<const KV_T*>(kvs[current_kv_idx].RowBytes(
(position - current_kv_start_offset) / kTileSize));
const KV_T* tile_base = HWY_RCAST_ALIGNED(
const KV_T*, kvs[current_kv_idx].RowBytes(
(position - current_kv_start_offset) / kTileSize));

const KV_T* v_tile =
tile_base + qkv_dim * kTileSize + (pos_in_tile)*qkv_dim;
Expand Down Expand Up @@ -1396,7 +1396,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
// After end of the tile, we have kTileSize * 2 bfloat16 for the
// microscaling scales for K and V.
const BF16* microscaling_scales_k =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
HWY_RCAST_ALIGNED(const BF16*, tile_base + qkv_dim * 2 * kTileSize) +
pos_in_tile;
MultiplyByScale<kNumQueries>(df, microscaling_scales_k, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
Expand Down Expand Up @@ -1439,7 +1439,8 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
if constexpr (IsInt16<Q_T>() && kUseMicroScaling) {
if (query_idx == 0) { // update only when needed
const BF16* microscaling_scales_v =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
HWY_RCAST_ALIGNED(const BF16*,
tile_base + qkv_dim * 2 * kTileSize) +
kTileSize + pos_in_tile;
const PackedSpan<const BF16> scales_span =
MakeConstSpan(microscaling_scales_v, 2 * hn::Lanes(df));
Expand All @@ -1456,7 +1457,7 @@ HWY_NOINLINE void TileFlashAttentionReturnExpSumsAndMaxLogits(
q_scales_s_ptr, max_v_scale);
if constexpr (kUseMicroScaling) {
const BF16* microscaling_scales_v =
reinterpret_cast<const BF16*>(tile_base + qkv_dim * 2 * kTileSize) +
HWY_RCAST_ALIGNED(const BF16*, tile_base + qkv_dim * 2 * kTileSize) +
kTileSize + pos_in_tile;
MultiplyByScale<kNumQueries>(df, microscaling_scales_v, x_0_p_0, x_0_p_1,
x_1_p_0, x_1_p_1, x_2_p_0, x_2_p_1, x_3_p_0,
Expand Down Expand Up @@ -1570,6 +1571,9 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsInt16(
last_pos_per_query, att_cap, att_out, exp_denominator_sums, max_logits);
}

// Type-extraction helper: given a span of `MatPtrT<MatT>`, its return type is
// the element type `MatT`. The dispatchers in this file use it only inside
// `decltype(...)` to recover the KV element type of the span handed to their
// generic lambdas. Only this declaration is visible here.
// NOTE(review): presumably intentionally left undefined, since it is only
// needed in unevaluated context; confirm no definition is required.
template <typename MatT>
MatT GetKVTypeHelper(const hwy::Span<const MatPtrT<MatT>>&);

void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation(
hwy::Span<const MatPtr> kvs, size_t q_count,
const BF16* HWY_RESTRICT q_base,
Expand All @@ -1578,9 +1582,30 @@ void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulation(
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
float* HWY_RESTRICT max_logits) {
CallUpcastedKVs(kvs, [&](const auto& kv_t) {
TileFlashAttentionReturnExpSumsAndMaxLogitsBF16_Macro(
kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query,
att_cap, att_out, exp_denominator_sums, max_logits);
using KV_T = decltype(GetKVTypeHelper(kv_t));
if constexpr (IsBF16<KV_T>()) {
TileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
kv_t, q_count, q_base, {}, start_pos_per_query, last_pos_per_query,
att_cap, att_out, exp_denominator_sums, max_logits);
}
});
}

// Dispatch entry point for the matrix-accumulation tile kernel with int8
// queries. `kvs` is handed to `CallUpcastedKVs`, which invokes the generic
// lambda with a typed span; the kernel is only called (and only instantiated)
// when `IsInt8<ElemT>()` holds for that span's element type.
//
// `q_base` and `q_scales` are forwarded unchanged to the kernel, together with
// the per-query start/last positions, `att_cap`, and the output buffers
// `att_out`, `exp_denominator_sums` and `max_logits`.
//
// NOTE(review): for any other KV element type the lambda returns without doing
// anything, so this function writes nothing to the outputs in that case.
// Confirm callers only reach this with int8 KV caches.
// NOTE(review): the kernel name carries a `BF16` suffix although it is called
// here with int8 queries; presumably it is overloaded or templated on the
// query type. Confirm against its definition.
void DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8(
    hwy::Span<const MatPtr> kvs, size_t q_count,
    const int8_t* HWY_RESTRICT q_base, hwy::Span<const float> q_scales,
    hwy::Span<const size_t> start_pos_per_query,
    hwy::Span<const size_t> last_pos_per_query, const float att_cap,
    MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums,
    float* HWY_RESTRICT max_logits) {
  const auto run_if_int8 = [&](const auto& typed_kvs) {
    using ElemT = decltype(GetKVTypeHelper(typed_kvs));
    if constexpr (!IsInt8<ElemT>()) {
      // Not an int8 KV cache: nothing to do for this instantiation.
      return;
    } else {
      TileFlashAttentionReturnExpSumsAndMaxLogitsBF16(
          typed_kvs, q_count, q_base, q_scales, start_pos_per_query,
          last_pos_per_query, att_cap, att_out, exp_denominator_sums,
          max_logits);
    }
  };
  CallUpcastedKVs(kvs, run_if_int8);
}

Expand Down
9 changes: 9 additions & 0 deletions gemma/flash_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ namespace gcpp {
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
void \
DispatchTileFlashAttentionReturnExpSumsAndMaxLogitsMatrixAccumulationInt8( \
hwy::Span<const MatPtr> kvs, size_t q_count, \
const int8_t* HWY_RESTRICT q_base, hwy::Span<const float> q_scales, \
hwy::Span<const size_t> start_pos_per_query, \
hwy::Span<const size_t> last_pos_per_query, const float att_cap, \
MatPtrT<float>& att_out, float* HWY_RESTRICT exp_denominator_sums, \
float* HWY_RESTRICT max_logits); \
\
/* NOLINTNEXTLINE(google-readability-namespace-comments) */ \
} // namespace NAMESPACE

Expand Down
Loading
Loading