From d29e08a004f1c5c3e4415bb7c226036f2e070cff Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Tue, 30 Jun 2026 14:15:41 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/runtime/ops/mul/BinaryOp.cpp | 15 ++++++++------- backends/webgpu/runtime/ops/mul/binary_mul.wgsl | 7 +++++-- backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h | 9 ++++++--- backends/webgpu/runtime/ops/permute/Permute.cpp | 5 +++-- backends/webgpu/runtime/ops/permute/permute.wgsl | 7 +++++-- .../webgpu/runtime/ops/permute/permute_wgsl.h | 9 ++++++--- 6 files changed, 33 insertions(+), 19 deletions(-) diff --git a/backends/webgpu/runtime/ops/mul/BinaryOp.cpp b/backends/webgpu/runtime/ops/mul/BinaryOp.cpp index 2ccb6c0e1bf..408f34cf0df 100644 --- a/backends/webgpu/runtime/ops/mul/BinaryOp.cpp +++ b/backends/webgpu/runtime/ops/mul/BinaryOp.cpp @@ -63,8 +63,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector& args) { uint32_t wg_size = utils::clamp_workgroup_size(device, kBinaryMulWorkgroupSizeX); - uint32_t workgroup_count = - utils::compute_1d_workgroup_count(device, out_meta.numel, wg_size, "mul"); + utils::WgCount workgroup_count = + utils::compute_2d_workgroup_count(device, out_meta.numel, wg_size, "mul"); WGPUConstantEntry wg_size_constant = {}; wg_size_constant.key = {"wg_size", WGPU_STRLEN}; @@ -165,8 +165,8 @@ void mul_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - const size_t dispatch_idx = - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + const size_t dispatch_idx = graph.add_dispatch( + {pipeline, bind_group, workgroup_count.x, "mul", workgroup_count.y}); // Dynamic shapes: rebuild all 3 broadcast TensorMeta UBOs + dispatch. WGPUBuffer o_buf = out_meta_buf, a_buf = in1_meta_buf, b_buf = in2_meta_buf; @@ -199,9 +199,10 @@ void mul_impl(WebGPUGraph& graph, const std::vector& args) { wgpuQueueWriteBuffer(g.queue(), o_buf, 0, &om, sizeof(om)); wgpuQueueWriteBuffer(g.queue(), a_buf, 0, &am, sizeof(am)); wgpuQueueWriteBuffer(g.queue(), b_buf, 0, &bm, sizeof(bm)); - g.dispatch_at(dispatch_idx).workgroup_count_x = - utils::compute_1d_workgroup_count( - g.device(), om.numel, wg_size, "mul(resize)"); + const utils::WgCount wgc = utils::compute_2d_workgroup_count( + g.device(), om.numel, wg_size, "mul(resize)"); + g.dispatch_at(dispatch_idx).workgroup_count_x = wgc.x; + g.dispatch_at(dispatch_idx).workgroup_count_y = wgc.y; }; graph.add_tensor_resize_hook(in1_id, mul_resize); graph.add_tensor_resize_hook(in2_id, mul_resize); diff --git a/backends/webgpu/runtime/ops/mul/binary_mul.wgsl b/backends/webgpu/runtime/ops/mul/binary_mul.wgsl index 1708cf08792..66722968e44 100644 --- a/backends/webgpu/runtime/ops/mul/binary_mul.wgsl +++ b/backends/webgpu/runtime/ops/mul/binary_mul.wgsl @@ -15,8 +15,11 @@ struct TensorMeta { override wg_size: u32 = 64u; @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + // 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel). + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); if (idx >= out_meta.numel) { return; } diff --git a/backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h b/backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h index 4530d70e3dd..1c700615c7f 100644 --- a/backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h +++ b/backends/webgpu/runtime/ops/mul/binary_mul_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from binary_mul.wgsl - DO NOT EDIT. -// wgsl-sha256: e7f77426cbaf48e6085e0d882522c027302ec97ef017b86a2275eed9820f7891 +// wgsl-sha256: cca69c3428f37f293942637e23f664225dec81a56f184bcb63185b6629dd155e inline constexpr const char* kBinaryMulWGSL = R"( @group(0) @binding(0) var input1: array; @group(0) @binding(1) var input2: array; @@ -32,8 +32,11 @@ struct TensorMeta { override wg_size: u32 = 64u; @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { - let idx = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + // 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel). + let idx = gid.x + gid.y * (num_workgroups.x * wg_size); if (idx >= out_meta.numel) { return; } diff --git a/backends/webgpu/runtime/ops/permute/Permute.cpp b/backends/webgpu/runtime/ops/permute/Permute.cpp index 5062c33cec1..c3b9b5c70c6 100644 --- a/backends/webgpu/runtime/ops/permute/Permute.cpp +++ b/backends/webgpu/runtime/ops/permute/Permute.cpp @@ -92,7 +92,7 @@ void permute_impl(WebGPUGraph& graph, const std::vector& args) { uint32_t wg_size = utils::clamp_workgroup_size(device, kPermuteWorkgroupSizeX); - uint32_t workgroup_count = utils::compute_1d_workgroup_count( + utils::WgCount workgroup_count = utils::compute_2d_workgroup_count( device, out_meta.numel, wg_size, "permute"); WGPUConstantEntry wg_size_constant = {}; @@ -176,7 +176,8 @@ void permute_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count}); + graph.add_dispatch( + {pipeline, bind_group, workgroup_count.x, "permute", workgroup_count.y}); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); diff --git a/backends/webgpu/runtime/ops/permute/permute.wgsl b/backends/webgpu/runtime/ops/permute/permute.wgsl index 521cfac1e66..e62fa5624d5 100644 --- a/backends/webgpu/runtime/ops/permute/permute.wgsl +++ b/backends/webgpu/runtime/ops/permute/permute.wgsl @@ -18,8 +18,11 @@ struct Params { override wg_size: u32 = 64u; @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { - let out_bufi = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + // 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel). + let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size); if (out_bufi >= out_meta.numel) { return; } diff --git a/backends/webgpu/runtime/ops/permute/permute_wgsl.h b/backends/webgpu/runtime/ops/permute/permute_wgsl.h index 6ec41cc8446..b3d4684d54b 100644 --- a/backends/webgpu/runtime/ops/permute/permute_wgsl.h +++ b/backends/webgpu/runtime/ops/permute/permute_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from permute.wgsl - DO NOT EDIT. -// wgsl-sha256: d34f59730cda7317589b6ed5691a1ccab8666b9c94e17ac2cb3658b036300197 +// wgsl-sha256: 05884aeb14426c979ea037b066266d8cab11f4fed76ee21ee8778e7fc13ad84e inline constexpr const char* kPermuteWGSL = R"( @group(0) @binding(0) var input: array; @group(0) @binding(1) var output: array; @@ -35,8 +35,11 @@ struct Params { override wg_size: u32 = 64u; @compute @workgroup_size(wg_size, 1, 1) -fn main(@builtin(global_invocation_id) gid: vec3) { - let out_bufi = gid.x; +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_workgroups: vec3) { + // 2D-folded flat index (lifts the 65535 1D-dispatch cap for large numel). + let out_bufi = gid.x + gid.y * (num_workgroups.x * wg_size); if (out_bufi >= out_meta.numel) { return; }