Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

Expand All @@ -36,14 +37,6 @@ static_assert(
sizeof(EmbeddingParams) == 32,
"EmbeddingParams must be 32 bytes");

// Returns the element count described by `dims`: the product of every extent.
// An empty dimension list yields 1. Each extent is converted to uint64_t
// before it is multiplied in.
uint64_t numel_of(const std::vector<int64_t>& dims) {
  uint64_t total = 1;
  for (size_t axis = 0; axis < dims.size(); ++axis) {
    total *= static_cast<uint64_t>(dims[axis]);
  }
  return total;
}

// arg order mirrors Vulkan EmbeddingQ4gsw.cpp.
void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int weight_id = args.at(0);
Expand Down Expand Up @@ -102,7 +95,7 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// Leading index dims flatten row-major (mirrors Vulkan num_indices).
const uint64_t out_numel = numel_of(out.dims);
const uint64_t out_numel = utils::numel_of(out.dims);
const uint32_t num_indices = static_cast<uint32_t>(out_numel / embed_dim);
const uint32_t groups_per_row = static_cast<uint32_t>(scales.dims[1]);
const uint32_t blocks_per_row = embed_dim / 32u;
Expand All @@ -119,9 +112,9 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// Per-type byte guards (no runtime dtype): indices i32, weight u8, fp32 rest.
const uint64_t indices_numel = numel_of(indices.dims);
const uint64_t weight_numel = numel_of(weight.dims);
const uint64_t scales_numel = numel_of(scales.dims);
const uint64_t indices_numel = utils::numel_of(indices.dims);
const uint64_t weight_numel = utils::numel_of(weight.dims);
const uint64_t scales_numel = utils::numel_of(scales.dims);
if (indices_numel != num_indices ||
indices.nbytes != indices_numel * sizeof(int32_t) ||
weight.nbytes != weight_numel ||
Expand Down Expand Up @@ -230,13 +223,59 @@ void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector<int>& args) {
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch(
const size_t dispatch_idx = graph.add_dispatch(
{pipeline, bind_group, workgroup_count, "embedding_q4gsw"});

// Dynamic shapes: recompute counts/dispatch; out = indices + [embed_dim].
const uint32_t gs_u = static_cast<uint32_t>(group_size);
WGPUBuffer params_buf = uniform_buffer;
graph.add_tensor_resize_hook(
    indices_id,
    // Runs when the indices tensor changes shape: re-derives the index and
    // block counts, sets the output dims to indices dims + [embed_dim],
    // rewrites the parameter uniform and updates the dispatch's workgroup
    // count. Throws std::runtime_error when there are no indices or the block
    // count does not fit in uint32.
    [indices_id,
     out_id,
     embed_dim,
     blocks_per_row,
     gs_u,
     groups_per_row,
     bytes_per_row,
     wg_size,
     dispatch_idx,
     params_buf](WebGPUGraph& g) {
      const auto& index_dims = g.cur_dims(indices_id);
      const uint64_t index_count = utils::numel_of(index_dims);
      if (index_count == 0) {
        throw std::runtime_error("WebGPU embedding_q4gsw: zero indices");
      }

      // Checked in 64 bits first so the narrowing below cannot truncate.
      const uint64_t block_count = index_count * blocks_per_row;
      if (block_count > UINT32_MAX) {
        throw std::runtime_error(
            "WebGPU embedding_q4gsw: total_blocks exceeds uint32");
      }
      const uint32_t block_count_u32 = static_cast<uint32_t>(block_count);

      // Output shape is the index shape with embed_dim appended.
      std::vector<int64_t> out_dims = index_dims;
      out_dims.push_back(static_cast<int64_t>(embed_dim));
      g.set_cur_dims(out_id, out_dims);

      // Fields fixed at build time come from the captures; only the two
      // counts depend on the live index shape.
      EmbeddingParams live_params = {};
      live_params.num_indices = static_cast<uint32_t>(index_count);
      live_params.total_blocks = block_count_u32;
      live_params.embed_dim = embed_dim;
      live_params.blocks_per_row = blocks_per_row;
      live_params.group_size = gs_u;
      live_params.groups_per_row = groups_per_row;
      live_params.bytes_per_row = bytes_per_row;
      wgpuQueueWriteBuffer(
          g.queue(), params_buf, 0, &live_params, sizeof(live_params));

      g.dispatch_at(dispatch_idx).workgroup_count_x =
          utils::compute_1d_workgroup_count(
              g.device(),
              block_count_u32,
              wg_size,
              "embedding_q4gsw(resize)");
    });

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
38 changes: 35 additions & 3 deletions backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>
Expand Down Expand Up @@ -187,14 +188,45 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
static_assert(
kRmsNormVec4WorkgroupSizeX == 64,
"must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl");
graph.add_dispatch({pipeline, bind_group, num_rows});
const size_t dispatch_idx =
graph.add_dispatch({pipeline, bind_group, num_rows});

// Dynamic shapes: recompute num_rows + rewrite the UBO for the live input.
WGPUBuffer params_buf = uniform_buffer;
graph.add_tensor_resize_hook(
    in_id,
    // Runs when the input tensor changes shape: re-derives the row count from
    // the live dims, rewrites the parameter uniform, updates the dispatch's
    // workgroup count (one workgroup per row) and mirrors the input dims onto
    // the output. Throws std::runtime_error if the element count is not a
    // multiple of row_width, if there are no rows, or if the row count
    // exceeds 65535.
    [in_id, out_id, row_width, epsilon, dispatch_idx, params_buf](
        WebGPUGraph& g) {
      const auto& d = g.cur_dims(in_id);
      const uint64_t numel = utils::numel_of(d);
      // NOTE(review): assumes row_width was validated as non-zero when the
      // op was built; that check is not visible here.
      const uint64_t width = static_cast<uint64_t>(row_width);
      if (numel % width != 0) {
        throw std::runtime_error(
            "WebGPU rms_norm: numel not a multiple of row_width");
      }
      // Validate the row count at full 64-bit width. Narrowing to uint32_t
      // first would let a quotient >= 2^32 wrap to a small value and slip
      // past the limit check below.
      const uint64_t rows64 = numel / width;
      if (rows64 == 0) {
        throw std::runtime_error("WebGPU rms_norm: zero rows");
      }
      if (rows64 > 65535u) {
        throw std::runtime_error(
            "WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)");
      }
      const uint32_t rows = static_cast<uint32_t>(rows64);
      RmsNormParams p = {};
      p.num_rows = rows;
      p.row_width = row_width;
      p.epsilon = epsilon;
      wgpuQueueWriteBuffer(g.queue(), params_buf, 0, &p, sizeof(p));
      g.dispatch_at(dispatch_idx).workgroup_count_x = rows;
      g.set_cur_dims(out_id, d);
    });

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform buffer alive until release.
wgpuBufferRelease(uniform_buffer);
// Graph owns it so the resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
}

} // namespace
Expand Down
99 changes: 83 additions & 16 deletions backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,15 @@ struct RotaryParams {
};
static_assert(sizeof(RotaryParams) == 32, "RotaryParams must be 32 bytes");

// Multiplies together every extent in `dims` and returns the resulting
// element count; with no extents the result is 1. Extents are converted to
// uint64_t before being multiplied in.
uint64_t numel_of(const std::vector<int64_t>& dims) {
  uint64_t product = 1;
  auto it = dims.begin();
  while (it != dims.end()) {
    product *= static_cast<uint64_t>(*it);
    ++it;
  }
  return product;
}
// A rope dispatch: its param-uniform (rewritten on resize) and its index in the
// graph's dispatch list (so a resize hook can update the workgroup count).
// Returned by add_rope_dispatch via positional aggregate initialisation, so
// the member order below must not change.
struct RopeDispatch {
// Parameter uniform buffer for this dispatch. The handle is handed to the
// graph with own_uniform_buffer(), so this copy is for rewriting only.
WGPUBuffer uniform;
// Value returned by graph.add_dispatch(); passed to dispatch_at() on resize.
size_t dispatch_index;
};

// Rotate one (x->out) with the shared shader; freqs shared between xq and xk.
void add_rope_dispatch(
RopeDispatch add_rope_dispatch(
WebGPUGraph& graph,
WGPUDevice device,
WGPUComputePipeline pipeline,
Expand All @@ -58,7 +57,8 @@ void add_rope_dispatch(
uint32_t workgroup_count) {
const uint32_t half_dim = head_dim / 2u;
// out.dims == in.dims (asserted in impl), so this matches the caller's wgc.
const uint32_t num_pairs = static_cast<uint32_t>(numel_of(out.dims) / 2u);
const uint32_t num_pairs =
static_cast<uint32_t>(utils::numel_of(out.dims) / 2u);

RotaryParams params = {};
params.n_heads = n_heads;
Expand Down Expand Up @@ -101,10 +101,12 @@ void add_rope_dispatch(
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch(
const size_t dispatch_index = graph.add_dispatch(
{pipeline, bind_group, workgroup_count, "apply_rotary_emb"});

wgpuBufferRelease(uniform_buffer);
// Graph owns it so a resize hook can rewrite it; freed in the dtor.
graph.own_uniform_buffer(uniform_buffer);
return {uniform_buffer, dispatch_index};
}

// args: [xq, xk, freqs_cos, freqs_sin, out_list(ValueList[xq_out, xk_out])].
Expand Down Expand Up @@ -164,9 +166,9 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
}

// All tensors are fp32; output shapes equal their inputs.
const uint64_t xq_numel = numel_of(xq.dims);
const uint64_t xk_numel = numel_of(xk.dims);
const uint64_t freqs_numel = numel_of(freqs_cos.dims);
const uint64_t xq_numel = utils::numel_of(xq.dims);
const uint64_t xk_numel = utils::numel_of(xk.dims);
const uint64_t freqs_numel = utils::numel_of(freqs_cos.dims);
if (freqs_numel != static_cast<uint64_t>(seq) * half_dim ||
xq.nbytes != xq_numel * sizeof(float) ||
xk.nbytes != xk_numel * sizeof(float) ||
Expand Down Expand Up @@ -246,7 +248,7 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUComputePipeline pipeline_k =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

add_rope_dispatch(
RopeDispatch q_disp = add_rope_dispatch(
graph,
device,
pipeline_q,
Expand All @@ -259,7 +261,7 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
seq,
head_dim,
xq_wgc);
add_rope_dispatch(
RopeDispatch k_disp = add_rope_dispatch(
graph,
device,
pipeline_k,
Expand All @@ -272,6 +274,71 @@ void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector<int>& args) {
seq,
head_dim,
xk_wgc);
WGPUBuffer q_ubuf = q_disp.uniform;
WGPUBuffer k_ubuf = k_disp.uniform;
const size_t q_idx = q_disp.dispatch_index;
const size_t k_idx = k_disp.dispatch_index;

// Dynamic shapes: recompute S/num_pairs + both dispatches; out follows xq/xk.
const int xq_out_id = out_list[0];
const int xk_out_id = out_list[1];
// xq_id trigger suffices: q and k co-resize on S; this updates both.
graph.add_tensor_resize_hook(
    xq_id,
    // Runs when xq changes shape: re-derives the sequence length and pair
    // counts for both q and k, rewrites both parameter uniforms, updates both
    // dispatches' workgroup counts and mirrors the input dims onto the
    // outputs. Throws std::runtime_error if either rank is below 3, if q and
    // k disagree on the sequence length, or if the sequence length or a pair
    // count does not fit in uint32.
    [xq_id,
     xk_id,
     xq_out_id,
     xk_out_id,
     n_heads_q,
     n_heads_k,
     head_dim,
     half_dim,
     wg_size,
     q_idx,
     k_idx,
     q_ubuf,
     k_ubuf](WebGPUGraph& g) {
      constexpr uint64_t kU32Max = 0xFFFFFFFFull;

      const auto& qd = g.cur_dims(xq_id);
      const auto& kd = g.cur_dims(xk_id);
      if (qd.size() < 3 || kd.size() < 3) {
        throw std::runtime_error(
            "apply_rotary_emb(resize): q/k rank must be >= 3");
      }
      // The sequence length is read from the third-from-last dim. Compare
      // the raw 64-bit extents: comparing after narrowing could treat
      // lengths that differ by a multiple of 2^32 as equal.
      const int64_t q_seq = qd[qd.size() - 3];
      const int64_t k_seq = kd[kd.size() - 3];
      // pk = pq (seq=s); require k's seq == s, not silently q's.
      if (k_seq != q_seq) {
        throw std::runtime_error(
            "apply_rotary_emb(resize): q and k seq lengths differ");
      }
      if (q_seq < 0 || static_cast<uint64_t>(q_seq) > kU32Max) {
        throw std::runtime_error(
            "apply_rotary_emb(resize): seq length out of uint32 range");
      }
      const uint32_t s = static_cast<uint32_t>(q_seq);

      // Pair counts are range-checked before narrowing so an oversized
      // tensor fails loudly instead of dispatching a truncated count.
      const uint64_t q_pairs64 = utils::numel_of(qd) / 2u;
      const uint64_t k_pairs64 = utils::numel_of(kd) / 2u;
      if (q_pairs64 > kU32Max || k_pairs64 > kU32Max) {
        throw std::runtime_error(
            "apply_rotary_emb(resize): num_pairs exceeds uint32");
      }
      const uint32_t q_pairs = static_cast<uint32_t>(q_pairs64);
      const uint32_t k_pairs = static_cast<uint32_t>(k_pairs64);

      // freqs stay max-allocated; shader indexes by position (S = prefix).
      RotaryParams pq = {};
      pq.n_heads = n_heads_q;
      pq.seq = s;
      pq.head_dim = head_dim;
      pq.half_dim = half_dim;
      pq.num_pairs = q_pairs;
      // k shares seq/head_dim/half_dim with q; only the head count and the
      // pair count differ.
      RotaryParams pk = pq;
      pk.n_heads = n_heads_k;
      pk.num_pairs = k_pairs;
      wgpuQueueWriteBuffer(g.queue(), q_ubuf, 0, &pq, sizeof(pq));
      wgpuQueueWriteBuffer(g.queue(), k_ubuf, 0, &pk, sizeof(pk));
      g.dispatch_at(q_idx).workgroup_count_x =
          utils::compute_1d_workgroup_count(
              g.device(), q_pairs, wg_size, "apply_rotary_emb(resize)");
      g.dispatch_at(k_idx).workgroup_count_x =
          utils::compute_1d_workgroup_count(
              g.device(), k_pairs, wg_size, "apply_rotary_emb(resize)");
      g.set_cur_dims(xq_out_id, qd);
      g.set_cur_dims(xk_out_id, kd);
    });

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
Expand Down
Loading