Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,9 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
add_webgpu_native_test(
webgpu_dynamic_shape_test test/native/test_dynamic_shape.cpp
)
add_webgpu_native_test(
webgpu_dispatch_2d_test test/native/test_dispatch_2d.cpp
)

# Manifest-driven op-test framework: a generic gtest driver (webgpu_op_test) +
# its device-free util unit test. GTest needs -DEXECUTORCH_BUILD_TESTS=ON.
Expand Down
4 changes: 3 additions & 1 deletion backends/webgpu/scripts/test_webgpu_native_ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ cmake \
"${EXECUTORCH_ROOT}"

# ── Build + run every native test target that exists in this tree ────────────
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test)
TARGETS=(webgpu_native_test webgpu_dispatch_order_test webgpu_scratch_buffer_test webgpu_update_cache_test webgpu_index_test webgpu_dispatch_2d_test)
BIN_DIR="${BUILD_DIR}/backends/webgpu"

# Which targets are defined depends on which diffs are landed (native_test +
Expand Down Expand Up @@ -212,6 +212,8 @@ if [[ "${INDEX_OK}" == "1" && -x "${BIN_DIR}/webgpu_index_test" ]]; then
"${BIN_DIR}/webgpu_index_test" "${INDEX_DIR}"
fi
[[ -x "${BIN_DIR}/webgpu_scratch_buffer_test" ]] && "${BIN_DIR}/webgpu_scratch_buffer_test"
# Device-free: pure 2D workgroup-count fold unit test (no .pte, no GPU).
[[ -x "${BIN_DIR}/webgpu_dispatch_2d_test" ]] && "${BIN_DIR}/webgpu_dispatch_2d_test"

echo "=== WebGPU native tests on Dawn: all run targets passed ==="

Expand Down
88 changes: 88 additions & 0 deletions backends/webgpu/test/native/test_dispatch_2d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>

#include <cstdint>
#include <cstdio>
#include <stdexcept>

using executorch::backends::webgpu::utils::fold_workgroup_count_2d;
using executorch::backends::webgpu::utils::WgCount;

namespace {

// Device-free unit test for the pure 2D workgroup-count fold that lifts the
// 65535 per-dim dispatch cap. Exercises the fold arithmetic only — no GPU.
int g_failures = 0;

void expect_fold(
uint32_t count,
uint32_t max_count,
uint32_t want_x,
uint32_t want_y) {
WgCount got = fold_workgroup_count_2d(count, max_count, "test");
bool ok = got.x == want_x && got.y == want_y &&
static_cast<uint64_t>(got.x) * got.y >= count;
printf(
"%s fold(%u, max=%u) = {%u, %u} (want {%u, %u})\n",
ok ? "PASS:" : "FAIL:",
count,
max_count,
got.x,
got.y,
want_x,
want_y);
if (!ok) {
g_failures++;
}
}

void expect_throw(uint32_t count, uint32_t max_count) {
bool threw = false;
try {
fold_workgroup_count_2d(count, max_count, "test");
} catch (const std::exception&) {
threw = true;
}
printf(
"%s fold(%u, max=%u) throws (needs a 3rd dispatch dimension)\n",
threw ? "PASS:" : "FAIL:",
count,
max_count);
if (!threw) {
g_failures++;
}
}

} // namespace

int main() {
const uint32_t kMax = 65535u;
// 1D fast path: count <= max -> {count, 1}, byte-identical to the old path.
expect_fold(1u, kMax, 1u, 1u);
expect_fold(kMax - 1u, kMax, kMax - 1u, 1u);
expect_fold(kMax, kMax, kMax, 1u);
// Fold to 2D: count > max -> {max, div_up(count, max)}.
expect_fold(kMax + 1u, kMax, kMax, 2u);
expect_fold(2u * kMax, kMax, kMax, 2u);
expect_fold(2u * kMax + 1u, kMax, kMax, 3u);
// Prefill-scale QK counts (tiled grid = Hq*ceil(S/4)*ceil(ctx/4)/wg) that
// exceed kMax and must fold.
expect_fold(131072u, kMax, kMax, 3u); // S=2048: 32*512*512/64
expect_fold(2097152u, kMax, kMax, 33u); // deep fold (large-S stress)
// count > max^2 needs a 3rd dispatch dimension -> throws (out of scope).
expect_throw(kMax * kMax + 1u, kMax);

if (g_failures != 0) {
printf("\nFAIL: %d dispatch_2d fold case(s) failed\n", g_failures);
return 1;
}
printf("\nAll dispatch_2d fold tests passed\n");
return 0;
}
3 changes: 3 additions & 0 deletions backends/webgpu/test/ops/test_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class SdpaConfig:
SdpaConfig("llama1b_decode", 32, 8, 64, 1, 512, 127),
# D=6 is not a multiple of 4: the WebGPU head_dim%4 guard must reject it at load.
SdpaConfig("reject_d6", 4, 4, 6, 4, 16, 0),
# 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV (cap+1).
SdpaConfig("llama1b_prefill_512", 32, 8, 64, 512, 512, 0),
SdpaConfig("llama1b_prefill_2048", 32, 8, 64, 2048, 2048, 0),
]


Expand Down
12 changes: 12 additions & 0 deletions backends/webgpu/test/test_webgpu_native.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,18 @@ static const SdpaConfig kSdpaConfigs[] = {
16.0f,
/*required=*/false,
/*expect_reject=*/true},
// 2D-dispatch cap (>65535 wg): S=512 folds QK; S=2048 folds QK+softmax+AV
// (cap+1).
{"llama1b_prefill_512", 32, 8, 64, 512, 512, 0, 16.0f, /*required=*/true},
{"llama1b_prefill_2048",
32,
8,
64,
2048,
2048,
0,
16.0f,
/*required=*/true},
};

// Ramp denominator; mirror of test_sdpa.py::_RAMP_DENOM (keep in sync).
Expand Down
Loading