From d54c2d2b410091396d18b6530cd5026a86e35473 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:48 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/WebGPUGraph.cpp | 9 ++- backends/webgpu/runtime/WebGPUGraph.h | 1 + backends/webgpu/runtime/WebGPUUtils.h | 55 ++++++++++++++++--- backends/webgpu/runtime/ops/add/BinaryOp.cpp | 14 +++-- .../webgpu/runtime/ops/add/binary_add.wgsl | 6 +- .../webgpu/runtime/ops/add/binary_add_wgsl.h | 8 ++- backends/webgpu/runtime/ops/sdpa/Sdpa.cpp | 45 +++++++++------ .../ops/sdpa/sdpa_compute_attn_weights.wgsl | 12 ++-- .../ops/sdpa/sdpa_compute_attn_weights_wgsl.h | 14 +++-- .../runtime/ops/sdpa/sdpa_compute_out.wgsl | 12 ++-- .../runtime/ops/sdpa/sdpa_compute_out_wgsl.h | 14 +++-- .../webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl | 7 ++- .../runtime/ops/sdpa/sdpa_softmax_wgsl.h | 9 ++- 13 files changed, 146 insertions(+), 60 deletions(-) diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 1c2d41f1214..2965444b938 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -775,7 +775,7 @@ void WebGPUGraph::execute() { wgpuComputePassEncoderSetBindGroup( pass, 0, dispatch.bind_group, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( - pass, dispatch.workgroup_count_x, 1, 1); + pass, dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); #ifdef WGPU_BACKEND_ENABLE_PROFILING @@ -783,7 +783,7 @@ void WebGPUGraph::execute() { qp->record( static_cast(i), dispatch.kernel_name, - {dispatch.workgroup_count_x, 1, 1}, + {dispatch.workgroup_count_x, dispatch.workgroup_count_y, 1}, {1, 1, 1}); } #endif // WGPU_BACKEND_ENABLE_PROFILING @@ -855,7 +855,10 @@ void WebGPUGraph::execute() { wgpuComputePassEncoderSetBindGroup( pass, 0, dispatches_[i].bind_group, 0, nullptr); wgpuComputePassEncoderDispatchWorkgroups( - pass, 
dispatches_[i].workgroup_count_x, 1, 1); + pass, + dispatches_[i].workgroup_count_x, + dispatches_[i].workgroup_count_y, + 1); wgpuComputePassEncoderEnd(pass); wgpuComputePassEncoderRelease(pass); } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 779467b4181..6064545ff3c 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -48,6 +48,7 @@ struct WebGPUDispatch { WGPUBindGroup bind_group = nullptr; uint32_t workgroup_count_x = 1; std::string kernel_name; // bench label + uint32_t workgroup_count_y = 1; // 2D fold (>65535); 1 = unchanged 1D path // DMA copy command; default Compute keeps existing positional inits valid. enum class Kind { Compute, Copy }; Kind kind = Kind::Compute; diff --git a/backends/webgpu/runtime/WebGPUUtils.h b/backends/webgpu/runtime/WebGPUUtils.h index 754fc50e573..2225f2d976a 100644 --- a/backends/webgpu/runtime/WebGPUUtils.h +++ b/backends/webgpu/runtime/WebGPUUtils.h @@ -44,6 +44,41 @@ inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) { return desired; } +struct WgCount { + uint32_t x; + uint32_t y; +}; + +// Device's max workgroups per dispatch dimension; the WebGPU spec-default floor +// (65535) if the query fails — never under-reports a real device's capacity. +inline uint32_t queried_max_workgroups(WGPUDevice device) { + WGPULimits limits = {}; + return wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && + limits.maxComputeWorkgroupsPerDimension > 0 + ? limits.maxComputeWorkgroupsPerDimension + : 65535u; +} + +// Pure 2D fold of a 1D workgroup count (device-free, unit-testable): {count,1} +// when count <= max, else {max, div_up(count, max)} so a >max workload fits the +// per-dimension cap; throws if a 3rd dimension would be needed (out of scope). +// The shader reconstructs the linear index from @builtin(num_workgroups). 
+inline WgCount fold_workgroup_count_2d(
+    uint32_t count,
+    uint32_t max_count,
+    const char* op_name) {
+  if (count <= max_count) {
+    return {count, 1u};
+  }
+  uint32_t y = (count - 1u) / max_count + 1u; // count >= 1 here; no u32 wrap
+  if (y > max_count) {
+    throw std::runtime_error(
+        std::string("WebGPU ") + op_name +
+        ": workgroup count needs a 3rd dispatch dimension (unsupported)");
+  }
+  return {max_count, y};
+}
+
 // 1D dispatch count (mirrors Vulkan div_up); throws if > device limit.
 inline uint32_t compute_1d_workgroup_count(
     WGPUDevice device,
@@ -51,13 +86,7 @@ inline uint32_t compute_1d_workgroup_count(
     uint32_t workgroup_size,
     const char* op_name) {
   uint32_t count = div_up(num_threads, workgroup_size);
-  WGPULimits limits = {};
-  uint32_t max_count =
-      wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
-          limits.maxComputeWorkgroupsPerDimension > 0
-      ? limits.maxComputeWorkgroupsPerDimension
-      : 65535u; // WebGPU spec-default floor
-  if (count > max_count) {
+  if (count > queried_max_workgroups(device)) {
     throw std::runtime_error(
         std::string("WebGPU ") + op_name +
         ": workgroup count exceeds the 1D dispatch limit");
@@ -65,6 +94,18 @@ inline uint32_t compute_1d_workgroup_count(
   return count;
 }
 
+// 2D dispatch count: fold the 1D count across x/y when it exceeds the per-dim
+// limit (lifts the cap, e.g. for SDPA prefill). Same fast path as compute_1d.
+inline WgCount compute_2d_workgroup_count(
+    WGPUDevice device,
+    uint32_t num_threads,
+    uint32_t workgroup_size,
+    const char* op_name) {
+  uint32_t count = div_up(num_threads, workgroup_size);
+  return fold_workgroup_count_2d(
+      count, queried_max_workgroups(device), op_name);
+}
+
 // Create a uniform buffer mapped-at-creation, copy `size` bytes in, and unmap.
inline WGPUBuffer make_uniform(WGPUDevice device, const void* data, size_t size) { diff --git a/backends/webgpu/runtime/ops/add/BinaryOp.cpp b/backends/webgpu/runtime/ops/add/BinaryOp.cpp index ddf54b85102..ca454df8b24 100644 --- a/backends/webgpu/runtime/ops/add/BinaryOp.cpp +++ b/backends/webgpu/runtime/ops/add/BinaryOp.cpp @@ -53,8 +53,8 @@ void add_impl(WebGPUGraph& graph, const std::vector& args) { uint32_t wg_size = utils::clamp_workgroup_size(device, kBinaryAddWorkgroupSizeX); - uint32_t workgroup_count = - utils::compute_1d_workgroup_count(device, num_elements, wg_size, "add"); + utils::WgCount workgroup_count = + utils::compute_2d_workgroup_count(device, num_elements, wg_size, "add"); WGPUConstantEntry wg_size_constant = {}; wg_size_constant.key = {"wg_size", WGPU_STRLEN}; @@ -158,7 +158,8 @@ void add_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + graph.add_dispatch( + {pipeline, bind_group, workgroup_count.x, "", workgroup_count.y}); const size_t dispatch_idx = graph.num_dispatches() - 1; // Dynamic shapes: recompute numel/dispatch; out follows the larger operand. 
@@ -180,9 +181,10 @@ void add_impl(WebGPUGraph& graph, const std::vector& args) { p.num_elements = static_cast(numel); p.alpha = alpha; wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); - g.dispatch_at(dispatch_idx).workgroup_count_x = - utils::compute_1d_workgroup_count( - g.device(), static_cast(numel), wg_size, "add(resize)"); + const utils::WgCount wgc = utils::compute_2d_workgroup_count( + g.device(), static_cast(numel), wg_size, "add(resize)"); + g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x; + g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y; }; graph.add_tensor_resize_hook(in1_id, add_resize); graph.add_tensor_resize_hook(in2_id, add_resize); diff --git a/backends/webgpu/runtime/ops/add/binary_add.wgsl b/backends/webgpu/runtime/ops/add/binary_add.wgsl index ac88f184c6b..c1cb9d5ffd7 100644 --- a/backends/webgpu/runtime/ops/add/binary_add.wgsl +++ b/backends/webgpu/runtime/ops/add/binary_add.wgsl @@ -11,8 +11,10 @@ struct Params { override wg_size: u32 = 256; @compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); if (idx >= params.num_elements) { return; } diff --git a/backends/webgpu/runtime/ops/add/binary_add_wgsl.h b/backends/webgpu/runtime/ops/add/binary_add_wgsl.h index 1f2614d3467..829407bb383 100644 --- a/backends/webgpu/runtime/ops/add/binary_add_wgsl.h +++ b/backends/webgpu/runtime/ops/add/binary_add_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from binary_add.wgsl - DO NOT EDIT. 
-// wgsl-sha256: c1ceec80c8d4d3d56986ad91ce0d7f9a57cd8467b8c3aa07a28da70e51d141d9 +// wgsl-sha256: e66bd67465c2a0296e09668df54f87605a4c91015a615f3734cdd0f140a74477 inline constexpr const char* kBinaryAddWGSL = R"( @group(0) @binding(0) var input1: array; @group(0) @binding(1) var input2: array; @@ -28,8 +28,10 @@ struct Params { override wg_size: u32 = 256; @compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); if (idx >= params.num_elements) { return; } diff --git a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp index 9ce7148821c..21111777e72 100644 --- a/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp +++ b/backends/webgpu/runtime/ops/sdpa/Sdpa.cpp @@ -159,6 +159,7 @@ void build_dispatch( WGPUBuffer uniform_buffer, uint64_t uniform_size, uint32_t workgroup_count_x, + uint32_t workgroup_count_y, uint32_t wg_size, bool retain_uniform = false, const char* kernel_name = "") { @@ -232,7 +233,12 @@ void build_dispatch( bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count_x, kernel_name}); + graph.add_dispatch( + {pipeline, + bind_group, + workgroup_count_x, + kernel_name, + workgroup_count_y}); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); @@ -273,6 +279,7 @@ static WGPUBuffer record_update_cache_dispatch( ubuf, sizeof(uc), wgc, + 1, uc_wg, retain_uniform, "update_cache"); @@ -485,7 +492,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { } const int64_t qk_tiles = Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(context_len, kSdpaTileN); - const uint32_t wgc = utils::compute_1d_workgroup_count( + const utils::WgCount wgc = utils::compute_2d_workgroup_count( device, 
static_cast(qk_tiles), qk_wg, "QK"); AttnWeightsParams p = make_attn_weights_params( S, Hq, Hkv, D, context_len, input_pos, g, scale); @@ -501,7 +508,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { 3, ubuf, sizeof(p), - wgc, + wgc.x, + wgc.y, qk_wg, true, "sdpa_compute_attn_weights"); @@ -512,7 +520,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { // Dispatch 4: softmax, one workgroup per (h,s) row of width context_len. { // One workgroup per (h,s) row; wg_size 1 keeps the device dispatch check. - const uint32_t wgc = utils::compute_1d_workgroup_count( + const utils::WgCount wgc = utils::compute_2d_workgroup_count( device, static_cast(Hq * S), 1, "softmax"); SoftmaxParams p = make_softmax_params(Hq, S, context_len); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); @@ -525,7 +533,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { 2, ubuf, sizeof(p), - wgc, + wgc.x, + wgc.y, 0, true, "sdpa_softmax"); @@ -537,7 +546,7 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { { const int64_t av_tiles = Hq * utils::div_up(S, kSdpaTileM) * utils::div_up(D, kSdpaTileN); - const uint32_t wgc = utils::compute_1d_workgroup_count( + const utils::WgCount wgc = utils::compute_2d_workgroup_count( device, static_cast(av_tiles), av_wg, "AV"); ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g); WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p)); @@ -552,7 +561,8 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { 3, ubuf, sizeof(p), - wgc, + wgc.x, + wgc.y, av_wg, true, "sdpa_compute_out"); @@ -629,25 +639,28 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector& args) { wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp)); const int64_t qk_tiles = Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(ctx, kSdpaTileN); - gr.dispatch_at(qk_idx).workgroup_count_x = - 
utils::compute_1d_workgroup_count( - gr.device(), static_cast(qk_tiles), qk_wg, "QK(resize)"); + const utils::WgCount qk_wgc = utils::compute_2d_workgroup_count( + gr.device(), static_cast(qk_tiles), qk_wg, "QK(resize)"); + gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc.x; + gr.dispatch_at(qk_idx).workgroup_count_y = qk_wgc.y; // softmax: one workgroup per (h,s) row. SoftmaxParams sp = make_softmax_params(Hq, s, ctx); wgpuQueueWriteBuffer(gr.queue(), softmax_buf, 0, &sp, sizeof(sp)); - gr.dispatch_at(softmax_idx).workgroup_count_x = - utils::compute_1d_workgroup_count( - gr.device(), static_cast(Hq * s), 1, "softmax(resize)"); + const utils::WgCount sm_wgc = utils::compute_2d_workgroup_count( + gr.device(), static_cast(Hq * s), 1, "softmax(resize)"); + gr.dispatch_at(softmax_idx).workgroup_count_x = sm_wgc.x; + gr.dispatch_at(softmax_idx).workgroup_count_y = sm_wgc.y; // AV: one thread per TM x TN tile; grid = Hq*ceil(S/TM)*ceil(D/TN). ComputeOutParams op = make_compute_out_params(s, Hq, Hkv, D, ctx, g); wgpuQueueWriteBuffer(gr.queue(), av_buf, 0, &op, sizeof(op)); const int64_t av_tiles = Hq * utils::div_up(s, kSdpaTileM) * utils::div_up(D, kSdpaTileN); - gr.dispatch_at(av_idx).workgroup_count_x = - utils::compute_1d_workgroup_count( - gr.device(), static_cast(av_tiles), av_wg, "AV(resize)"); + const utils::WgCount av_wgc = utils::compute_2d_workgroup_count( + gr.device(), static_cast(av_tiles), av_wg, "AV(resize)"); + gr.dispatch_at(av_idx).workgroup_count_x = av_wgc.x; + gr.dispatch_at(av_idx).workgroup_count_y = av_wgc.y; // Output attn has the same shape as q: [.., S, Hq, D]. 
gr.set_cur_dims(out_id, gr.cur_dims(q_id)); diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl index fd15767603d..014f0039048 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl @@ -53,17 +53,21 @@ fn store_qk(s: u32, c: u32, h: u32, raw: f32) { } @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { let nrt = (params.S + TM - 1u) / TM; let nct = (params.context_len + TN - 1u) / TN; let tiles = nrt * nct; let total = tiles * params.Hq; - if (gid.x >= total) { + // 2D dispatch fold: recover the linear tile index across x/y. + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); + if (idx >= total) { return; } - let h = gid.x / tiles; - let rem = gid.x % tiles; + let h = idx / tiles; + let rem = idx % tiles; let row_tile = rem / nct; let col_tile = rem % nct; let kvh = h / params.g; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h index ae250959e0e..d19de1ec214 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_attn_weights.wgsl - DO NOT EDIT. 
-// wgsl-sha256: 4eef09b234fd926cdc0daf18d03e39cf4fd57dfa4bc67724b4878b7dc68d1254 +// wgsl-sha256: ac1616d310a67aeb915ee7b4f650c981cc423c5251de2e7acf8775bacdfd8e56 inline constexpr const char* kSdpaComputeAttnWeightsWGSL = R"( @group(0) @binding(0) var t_attn_weights: array; @group(0) @binding(1) var t_q: array>; @@ -70,17 +70,21 @@ fn store_qk(s: u32, c: u32, h: u32, raw: f32) { } @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { let nrt = (params.S + TM - 1u) / TM; let nct = (params.context_len + TN - 1u) / TN; let tiles = nrt * nct; let total = tiles * params.Hq; - if (gid.x >= total) { + // 2D dispatch fold: recover the linear tile index across x/y. + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); + if (idx >= total) { return; } - let h = gid.x / tiles; - let rem = gid.x % tiles; + let h = idx / tiles; + let rem = idx % tiles; let row_tile = rem / nct; let col_tile = rem % nct; let kvh = h / params.g; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl index 56242b0ddde..713345c0afa 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out.wgsl @@ -64,17 +64,21 @@ fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { } @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { let nrt = (params.S + TM - 1u) / TM; let nct = (params.D + TN - 1u) / TN; let tiles = nrt * nct; let total = tiles * params.Hq; - if (gid.x >= total) { + // 2D dispatch fold: recover the linear tile index across x/y. 
+ let idx = gid.x + gid.y * (num_workgroups.x * wg_size); + if (idx >= total) { return; } - let h = gid.x / tiles; - let rem = gid.x % tiles; + let h = idx / tiles; + let rem = idx % tiles; let row_tile = rem / nct; let col_tile = rem % nct; let kvh = h / params.g; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h index 6bec079ac2b..80af51004f3 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_compute_out_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_compute_out.wgsl - DO NOT EDIT. -// wgsl-sha256: 2ffa0eb520b1054e43a10fd13e6b287bd35777f1cfc29bd39e9d668772528191 +// wgsl-sha256: bbd7a511032fb901f9ee245357a7d681f062ba95bd36ff6ea61d70c71c023bee inline constexpr const char* kSdpaComputeOutWGSL = R"( @group(0) @binding(0) var t_out: array>; @group(0) @binding(1) var t_attn_weights_softmax: array; @@ -81,17 +81,21 @@ fn store_out_vec4(s: u32, d0: u32, h: u32, val: vec4) { } @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { let nrt = (params.S + TM - 1u) / TM; let nct = (params.D + TN - 1u) / TN; let tiles = nrt * nct; let total = tiles * params.Hq; - if (gid.x >= total) { + // 2D dispatch fold: recover the linear tile index across x/y. 
+ let idx = gid.x + gid.y * (num_workgroups.x * wg_size); + if (idx >= total) { return; } - let h = gid.x / tiles; - let rem = gid.x % tiles; + let h = idx / tiles; + let rem = idx % tiles; let row_tile = rem / nct; let col_tile = rem % nct; let kvh = h / params.g; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl index 6ef223c3a98..f0e349ceaff 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax.wgsl @@ -20,9 +20,12 @@ var shared_sum: array; @compute @workgroup_size(WG_SIZE, 1, 1) fn main( @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { + @builtin(local_invocation_id) lid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { // One workgroup per (h, s) row of length context_len (= row_width). - let row_idx = wid.x; + // 2D dispatch fold: recover the linear row index. Keep the `valid` predicate + // below (NOT an early return) — the workgroupBarrier()s require uniform flow. + let row_idx = wid.x + wid.y * num_workgroups.x; let worker_id = lid.x; let base = row_idx * params.row_width; diff --git a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h index 94f0ab5790a..f1e7396fafe 100644 --- a/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h +++ b/backends/webgpu/runtime/ops/sdpa/sdpa_softmax_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from sdpa_softmax.wgsl - DO NOT EDIT. 
-// wgsl-sha256: e2714ec4c2400b37f6fd39c410075c519effc0273354a4f906fb924334809024 +// wgsl-sha256: bdd45cf344663533b243153200c507f41a90295751924e70452abcd5da4cdd5a inline constexpr const char* kSdpaSoftmaxWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_in: array; @@ -37,9 +37,12 @@ var shared_sum: array; @compute @workgroup_size(WG_SIZE, 1, 1) fn main( @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { + @builtin(local_invocation_id) lid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { // One workgroup per (h, s) row of length context_len (= row_width). - let row_idx = wid.x; + // 2D dispatch fold: recover the linear row index. Keep the `valid` predicate + // below (NOT an early return) — the workgroupBarrier()s require uniform flow. + let row_idx = wid.x + wid.y * num_workgroups.x; let worker_id = lid.x; let base = row_idx * params.row_width;