From 1f4f899b640ba22cf75be953f95574bd6ac81d3b Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:52 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/CMakeLists.txt | 3 + .../webgpu/scripts/test_webgpu_native_ci.sh | 4 +- .../webgpu/test/native/test_dispatch_2d.cpp | 88 +++++++++++++++++++ backends/webgpu/test/ops/test_sdpa.py | 3 + backends/webgpu/test/test_webgpu_native.cpp | 12 +++ 5 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 backends/webgpu/test/native/test_dispatch_2d.cpp diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 15a1e7bfd10..22f3c892e05 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -163,6 +163,9 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST) add_webgpu_native_test( webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp ) + add_webgpu_native_test( + webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp + ) # Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) + # its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON. diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index 38195535732..0ed8c88e3b2 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -143,7 +143,7 @@ cmake \ "${EXECUTORCH_ROOT}" # ── Build + run every native test target that exists in this tree ──────────── -TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test) +TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dispatch_2d_test) BIN_DIR="${BUILD_DIR}/backends/webgpu" # Which targets are defined depends on which diffs are landed (native_test + @@ -212,6 +212,8 @@ if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then "${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}" fi [[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test" +# Device-free: pure 2D workgroup-count fold unit test (no .pte, no GPU). +[[ -x "${BIN_DIR}/webgpu_dispatch_2d_test" ]] && "${BIN_DIR}/webgpu_dispatch_2d_test" echo "=== WebGPU native tests on Dawn: all run targets passed ===" diff --git a/backends/webgpu/test/native/test_dispatch_2d.cpp b/backends/webgpu/test/native/test_dispatch_2d.cpp new file mode 100644 index 00000000000..bc4e4583901 --- /dev/null +++ b/backends/webgpu/test/native/test_dispatch_2d.cpp @@ -0,0 +1,88 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +using executorch::backends::webgpu::utils::fold_workgroup_count_2d; +using executorch::backends::webgpu::utils::WgCount; + +namespace { + +// Device-free unit test for the pure 2D workgroup-count fold that lifts the +// 65535 per-dim dispatch cap. Exercises the fold arithmetic only — no GPU. +int g_failures = 0; + +void expect_fold( + uint32_t count, + uint32_t max_count, + uint32_t want_x, + uint32_t want_y) { + WgCount got = fold_workgroup_count_2d(count, max_count, "test"); + bool ok = got.x == want_x && got.y == want_y && + static_cast(got.x) * got.y >= count; + printf( + "%s fold(%u, max=%u) = {%u, %u} (want {%u, %u})\n", + ok ? "PASS:" : "FAIL:", + count, + max_count, + got.x, + got.y, + want_x, + want_y); + if (!ok) { + g_failures++; + } +} + +void expect_throw(uint32_t count, uint32_t max_count) { + bool threw = false; + try { + fold_workgroup_count_2d(count, max_count, "test"); + } catch (const std::exception&) { + threw = true; + } + printf( + "%s fold(%u, max=%u) throws (needs a 3rd dispatch dimension)\n", + threw ? "PASS:" : "FAIL:", + count, + max_count); + if (!threw) { + g_failures++; + } +} + +} // namespace + +int main() { + const uint32_t kMax = 65535u; + // 1D fast path: count <= max -> {count, 1}, byte-identical to the old path. + expect_fold(1u, kMax, 1u, 1u); + expect_fold(kMax - 1u, kMax, kMax - 1u, 1u); + expect_fold(kMax, kMax, kMax, 1u); + // Fold to 2D: count > max -> {max, div_up(count, max)}. + expect_fold(kMax + 1u, kMax, kMax, 2u); + expect_fold(2u * kMax, kMax, kMax, 2u); + expect_fold(2u * kMax + 1u, kMax, kMax, 3u); + // Prefill-scale QK counts (tiled grid = Hq*ceil(S/4)*ceil(ctx/4)/wg) that + // exceed kMax and must fold. + expect_fold(131072u, kMax, kMax, 3u); // S=2048: 32*512*512/64 + expect_fold(2097152u, kMax, kMax, 33u); // deep fold (large-S stress) + // count > max^2 needs a 3rd dispatch dimension -> throws (out of scope). + expect_throw(kMax * kMax + 1u, kMax); + + if (g_failures != 0) { + printf("\nFAIL: %d dispatch_2d fold case(s) failed\n", g_failures); + return 1; + } + printf("\nAll dispatch_2d fold tests passed\n"); + return 0; +} diff --git a/backends/webgpu/test/ops/test_sdpa.py b/backends/webgpu/test/ops/test_sdpa.py index 1f7a8242591..960c62e5203 100644 --- a/backends/webgpu/test/ops/test_sdpa.py +++ b/backends/webgpu/test/ops/test_sdpa.py @@ -61,6 +61,9 @@ class SdpaConfig: SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127), # D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load. SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0), + # 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV (cap+1). + SdpaConfig("llama1b_prefill_512", 32, 8, 64, 512, 512, 0), + SdpaConfig("llama1b_prefill_2048", 32, 8, 64, 2048, 2048, 0), ] diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 6b57254ca33..25d422d9943 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -750,6 +750,18 @@ static const SdpaConfig kSdpaConfigs[] = { 16.0f, /*required=*/false, /*expect_reject=*/true}, + // 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV + // (cap+1). + {"llama1b_prefill_512", 32, 8, 64, 512, 512, 0, 16.0f, /*required=*/true}, + {"llama1b_prefill_2048", + 32, + 8, + 64, + 2048, + 2048, + 0, + 16.0f, + /*required=*/true}, }; // Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync).