From 8011be1df523f17d88845262c49008eedff21e76 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 28 Jun 2026 09:22:15 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- .../ops/quantized_linear/QuantizedLinear.cpp | 66 ++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp index 6da05ff2010..828acbd3f0e 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp +++ b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp @@ -17,6 +17,7 @@ #include #include #include +#include namespace executorch::backends::webgpu { @@ -256,12 +257,73 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"}); + const size_t dispatch_idx = + graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"}); + + // Dynamic shapes: recompute dispatch + params.M for the live M. + WGPUBuffer params_buf = uniform_buffer; + graph.add_tensor_resize_hook( + in_id, + [in_id, + out_id, + K, + N, + K_packed, + gs, + padded_N, + has_bias, + wg_size, + use_gemv, + dispatch_idx, + params_buf](WebGPUGraph& g) { + const auto& d = g.cur_dims(in_id); + const uint64_t numel = utils::numel_of(d); + const uint32_t m = + static_cast(numel / static_cast(K)); + if (m == 0u) { + throw std::runtime_error("WebGPU linear_q4gsw: live M == 0"); + } + uint32_t wgc; + if (use_gemv) { + const uint64_t outputs = static_cast(m) * N; + if (outputs > UINT32_MAX) { + throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range"); + } + wgc = utils::clamp_workgroup_count( + g.device(), static_cast(outputs)); + } else { + const int64_t total_tiles = utils::div_up(m, kQ4gswTileM) * + utils::div_up(N, kQ4gswTileN); + if (total_tiles > static_cast(UINT32_MAX)) { + throw std::runtime_error( + "WebGPU linear_q4gsw: tile count exceeds the dispatch limit"); + } + wgc = utils::compute_1d_workgroup_count( + g.device(), + static_cast(total_tiles), + wg_size, + "linear_q4gsw(resize)"); + } + Q4gswParams p = {}; + p.M = m; + p.N = N; + p.K = K; + p.K_packed = K_packed; + p.group_size = gs; + p.padded_N = padded_N; + p.has_bias = has_bias; + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = wgc; + std::vector od(d.begin(), d.end()); + od.back() = static_cast(N); + g.set_cur_dims(out_id, od); + }); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - wgpuBufferRelease(uniform_buffer); + // Graph owns it so the resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } } // namespace