diff --git a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp index 6da05ff2010..29acec2be3a 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp +++ b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp @@ -17,6 +17,8 @@ #include #include #include +#include +#include namespace executorch::backends::webgpu { @@ -39,6 +41,42 @@ static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes"); constexpr int64_t kQ4gswTileM = 4; constexpr int64_t kQ4gswTileN = 4; +// Workgroup count for a linear_q4gsw dispatch (GEMV coop4 or tiled GEMM), with +// the range/zero guards shared by the build-time path and the resize hook. +uint32_t compute_q4gsw_workgroup_count( + WGPUDevice device, + bool use_gemv, + uint32_t m, + uint32_t n, + uint32_t wg_size, + const char* op_name) { + if (use_gemv) { + // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N. + const uint64_t outputs = + static_cast(m) * static_cast(n); + if (outputs == 0u || outputs > UINT32_MAX) { + throw std::runtime_error( + std::string("WebGPU ") + op_name + ": M*N out of range"); + } + const uint32_t wgc = + utils::clamp_workgroup_count(device, static_cast(outputs)); + if (wgc == 0u) { + throw std::runtime_error( + std::string("WebGPU ") + op_name + ": zero GEMV dispatch"); + } + return wgc; + } + const int64_t total_tiles = utils::div_up(m, kQ4gswTileM) * + utils::div_up(n, kQ4gswTileN); + if (total_tiles > static_cast(UINT32_MAX)) { + throw std::runtime_error( + std::string("WebGPU ") + op_name + + ": tile count exceeds the 1D dispatch limit"); + } + return utils::compute_1d_workgroup_count( + device, static_cast(total_tiles), wg_size, op_name); +} + // et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out]. void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { const int in_id = args.at(0); @@ -122,29 +160,8 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u); const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL; - uint32_t workgroup_count; - if (use_gemv) { - // coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N. - const uint64_t outputs = - static_cast(M) * static_cast(N); - if (outputs == 0u || outputs > UINT32_MAX) { - throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range"); - } - workgroup_count = - utils::clamp_workgroup_count(device, static_cast(outputs)); - if (workgroup_count == 0u) { - throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch"); - } - } else { - const int64_t total_tiles = utils::div_up(M, kQ4gswTileM) * - utils::div_up(N, kQ4gswTileN); - if (total_tiles > static_cast(UINT32_MAX)) { - throw std::runtime_error( - "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); - } - workgroup_count = utils::compute_1d_workgroup_count( - device, static_cast(total_tiles), wg_size, "linear_q4gsw"); - } + const uint32_t workgroup_count = compute_q4gsw_workgroup_count( + device, use_gemv, M, N, wg_size, "linear_q4gsw"); // Optional bias: real buffer if present, else a dummy for the fixed layout. uint32_t has_bias = 0; @@ -256,12 +273,68 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { bg_desc.entries = bg_entries; WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); - graph.add_dispatch({pipeline, bind_group, workgroup_count, "linear_q4gsw"}); + const size_t dispatch_idx = graph.add_dispatch( + {pipeline, bind_group, workgroup_count, "linear_q4gsw"}); + + // Dynamic shapes: recompute dispatch + params.M for the live M. + WGPUBuffer params_buf = uniform_buffer; + graph.add_tensor_resize_hook( + in_id, + [in_id, + out_id, + M, + K, + N, + K_packed, + gs, + padded_N, + has_bias, + wg_size, + use_gemv, + dispatch_idx, + params_buf](WebGPUGraph& g) { + const auto& d = g.cur_dims(in_id); + if (d.empty()) { + throw std::runtime_error("WebGPU linear_q4gsw: empty input dims"); + } + const uint64_t numel = utils::numel_of(d); + if (numel % static_cast(K) != 0u) { + throw std::runtime_error( + "WebGPU linear_q4gsw: live input numel not a multiple of K"); + } + const uint32_t m = + static_cast(numel / static_cast(K)); + if (m == 0u) { + throw std::runtime_error("WebGPU linear_q4gsw: live M == 0"); + } + // Buffers/bind-groups were sized for the build-time max M; a larger + // live M would write out of bounds. + if (m > M) { + throw std::runtime_error( + "WebGPU linear_q4gsw: live M exceeds the build-time max"); + } + const uint32_t wgc = compute_q4gsw_workgroup_count( + g.device(), use_gemv, m, N, wg_size, "linear_q4gsw(resize)"); + Q4gswParams p = {}; + p.M = m; + p.N = N; + p.K = K; + p.K_packed = K_packed; + p.group_size = gs; + p.padded_N = padded_N; + p.has_bias = has_bias; + wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p)); + g.dispatch_at(dispatch_idx).workgroup_count_x = wgc; + std::vector od(d.begin(), d.end()); + od.back() = static_cast(N); + g.set_cur_dims(out_id, od); + }); wgpuShaderModuleRelease(shader); wgpuBindGroupLayoutRelease(bgl); wgpuPipelineLayoutRelease(pipeline_layout); - wgpuBufferRelease(uniform_buffer); + // Graph owns it so the resize hook can rewrite it; freed in the dtor. + graph.own_uniform_buffer(uniform_buffer); } } // namespace