From 9f1fb83eff08b51755ac7ec753415aee59f75822 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:06 -0700 Subject: [PATCH 1/6] [ExecuTorch][WebGPU] Add et_vk.embedding_q4gsw (4-bit groupwise-symmetric quantized embedding) Pull Request resolved: https://github.com/pytorch/executorch/pull/20263 Adds the WebGPU backend handler for `et_vk.embedding_q4gsw.default` (a 4-bit groupwise-symmetric quantized embedding gather) plus the host-side integer-input infra it requires. The op is a single compute dispatch composed of one stage: one thread per 32-element block of each gathered row dequantizes the packed 4-bit table (`q = (nibble - 8) * scale`; even dim = high nibble, odd dim = low) into the fp32 output, mirroring the Vulkan `embedding_q4gsw` reference (flat buffer-backed weight; `is_linear_weight=true` is unsupported and throws). The workgroup size is a `wg_size` pipeline-override constant clamped to the device limit via `WebGPUUtils::clamp_workgroup_size`, the 1D dispatch count goes through `WebGPUUtils::compute_1d_workgroup_count` (validated before any GPU-object allocation), and the embedded WGSL string header is generated by `gen_wgsl_headers.py`. Embedding indices arrive as int64 at the program boundary but the serialized graph stores them as int32, so the shared input path is extended with a host-side `InputData` view (`{data, nbytes, host_is_int64}`) and `copy_inputs` gains three branches: a byte-for-byte fast path when host and GPU sizes match, an int64->int32 narrowing copy when the buffer is int32 and the host input is twice as wide (mirrors the Vulkan `kLong`->`kInt` staging cast), and a fail-loud throw otherwise. `WebGPUTensor` gains `elem_size`/`is_int` to drive the narrowing decision, and `update_symints_from_inputs` takes the same `InputData` vector so `execute()` builds a single input list consumed by both. ghstack-source-id: 395549280 @exported-using-ghexport Differential Revision: [D108428753](https://our.internmc.facebook.com/intern/diff/D108428753/) --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUBackend.cpp | 11 +- backends/webgpu/runtime/WebGPUGraph.cpp | 64 ++++- backends/webgpu/runtime/WebGPUGraph.h | 15 +- .../ops/embedding_q4gsw/EmbeddingQ4gsw.cpp | 248 ++++++++++++++++++ .../ops/embedding_q4gsw/embedding_q4gsw.wgsl | 50 ++++ .../embedding_q4gsw/embedding_q4gsw_wgsl.h | 74 ++++++ 7 files changed, 446 insertions(+), 17 deletions(-) create mode 100644 backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp create mode 100644 backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl create mode 100644 backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index e17ccbdbb0d..1eb75128997 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -39,6 +39,7 @@ set(WEBGPU_SRCS runtime/ops/select_as_symint/SelectAsSymint.cpp runtime/ops/quantized_linear/QuantizedLinear.cpp runtime/ops/mul/BinaryOp.cpp + runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUBackend.cpp b/backends/webgpu/runtime/WebGPUBackend.cpp index aed769da4a4..ceca89d1710 100644 --- a/backends/webgpu/runtime/WebGPUBackend.cpp +++ b/backends/webgpu/runtime/WebGPUBackend.cpp @@ -98,20 +98,21 @@ Error WebGPUBackend::execute( const size_t num_outputs = graph->output_ids().size(); // Copy inputs from EValue tensors to GPU buffers - std::vector> inputs; + std::vector inputs; inputs.reserve(num_inputs); for (size_t i = 0; i < num_inputs; i++) { const auto& tensor = args[i]->toTensor(); - inputs.emplace_back(tensor.const_data_ptr(), tensor.nbytes()); + const bool host_is_int64 = + tensor.scalar_type() == executorch::aten::ScalarType::Long; + inputs.push_back({tensor.const_data_ptr(), tensor.nbytes(), host_is_int64}); } - graph->copy_inputs(inputs); - // Fail loud as a runtime Error so a throw never crosses the backend boundary. try { + graph->copy_inputs(inputs); graph->update_symints_from_inputs(inputs); graph->propagate_resize(); } catch (const std::exception& e) { - ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what()); + ET_LOG(Error, "WebGPU input copy / symint refresh failed: %s", e.what()); return Error::Internal; } diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 1c977d130dd..d3974eab194 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -45,6 +45,19 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) { } } +bool vk_datatype_is_int(vkgraph::VkDataType dtype) { + switch (dtype) { + case vkgraph::VkDataType::BOOL: + case vkgraph::VkDataType::UINT8: + case vkgraph::VkDataType::INT8: + case vkgraph::VkDataType::INT32: + case vkgraph::VkDataType::INT64: + return true; + default: + return false; + } +} + } // namespace WebGPUGraph::WebGPUGraph() = default; @@ -61,7 +74,7 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) { } void WebGPUGraph::update_symints_from_inputs( - const std::vector>& inputs) { + const std::vector& inputs) { for (const auto& src : symint_sources_) { int pos = -1; for (size_t i = 0; i < input_ids_.size(); i++) { @@ -100,8 +113,8 @@ void WebGPUGraph::update_symints_from_inputs( // Reads the [0,..,index,..,0] element; symint sources are scalar-ish. const int64_t offset = static_cast(index) * stride; // elem_size back-derived from build-time numel (sources are static-shaped). - const void* host = inputs[pos].first; - const size_t elem_size = inputs[pos].second / static_cast(numel); + const void* host = inputs[pos].data; + const size_t elem_size = inputs[pos].nbytes / static_cast(numel); int32_t val; if (elem_size == sizeof(int64_t)) { val = static_cast(static_cast(host)[offset]); @@ -248,7 +261,9 @@ void WebGPUGraph::build( numel *= dims->Get(j); } } - tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype()); + tensor.elem_size = vk_datatype_size(vk_tensor->datatype()); + tensor.is_int = vk_datatype_is_int(vk_tensor->datatype()); + tensor.nbytes = numel * tensor.elem_size; int constant_id = vk_tensor->constant_id(); int mem_obj_id = vk_tensor->mem_obj_id(); @@ -484,16 +499,47 @@ WGPUBindGroupLayout WebGPUGraph::get_or_create_bgl( return bgl; } -void WebGPUGraph::copy_inputs( - const std::vector>& inputs) { +void WebGPUGraph::copy_inputs(const std::vector& inputs) { for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) { - if (inputs[i].second == 0) { + const InputData& in = inputs[i]; + if (in.nbytes == 0) { continue; } int tid = input_ids_[i]; const auto& tensor = tensors_[tid]; - wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, inputs[i].first, inputs[i].second); + + // Fast path: host and GPU element types match byte-for-byte. + if (in.nbytes == tensor.nbytes) { + wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes); + continue; + } + + // Narrow int64 host indices into the int32 buffer (mirrors Vulkan). + const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4; + if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) { + const size_t numel = tensor.nbytes / 4; + const int64_t* src = static_cast(in.data); + std::vector narrowed(numel); + for (size_t e = 0; e < numel; e++) { +#ifndef NDEBUG + // Index tensors (tokens/positions) are far below int32 range in + // practice; assert in debug that the narrowing is lossless. + if (static_cast(src[e]) != src[e]) { + throw std::runtime_error("WebGPU: int64 index overflows int32"); + } +#endif + narrowed[e] = static_cast(src[e]); + } + wgpuQueueWriteBuffer( + queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes); + continue; + } + + throw std::runtime_error( + "WebGPU: unsupported input copy for input " + std::to_string(i) + + " (host " + std::to_string(in.nbytes) + " bytes" + + (in.host_is_int64 ? " int64" : "") + " vs buffer " + + std::to_string(tensor.nbytes) + " bytes)"); } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 3cff09ecb6d..5bd5b93b524 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -25,6 +25,16 @@ struct WebGPUTensor { WGPUBuffer buffer = nullptr; std::vector dims; size_t nbytes = 0; + // Serialized (GPU-side) element type, used to narrow wider host inputs. + size_t elem_size = 0; + bool is_int = false; +}; + +// Host-side view of one graph input, passed to copy_inputs. +struct InputData { + const void* data = nullptr; + size_t nbytes = 0; + bool host_is_int64 = false; }; struct WebGPUDispatch { @@ -75,7 +85,7 @@ class WebGPUGraph { const executorch::runtime::NamedDataMap* named_data_map = nullptr); // Copy input tensor data from host pointers into GPU buffers. - void copy_inputs(const std::vector>& inputs); + void copy_inputs(const std::vector& inputs); // Execute all recorded dispatches. void execute(); @@ -138,8 +148,7 @@ class WebGPUGraph { } // Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl. - void update_symints_from_inputs( - const std::vector>& inputs); + void update_symints_from_inputs(const std::vector& inputs); // Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize. void add_resize_hook(int symint_id, std::function fn) { diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp new file mode 100644 index 00000000000..5801b650f27 --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp @@ -0,0 +1,248 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes). +struct EmbeddingParams { + uint32_t embed_dim; + uint32_t blocks_per_row; + uint32_t num_indices; + uint32_t group_size; + uint32_t groups_per_row; + uint32_t bytes_per_row; + uint32_t total_blocks; + uint32_t _pad; +}; +static_assert( + sizeof(EmbeddingParams) == 32, + "EmbeddingParams must be 32 bytes"); + +uint64_t numel_of(const std::vector& dims) { + uint64_t n = 1; + for (int64_t d : dims) { + n *= static_cast(d); + } + return n; +} + +// arg order mirrors Vulkan EmbeddingQ4gsw.cpp. +void embedding_q4gsw_impl(WebGPUGraph& graph, const std::vector& args) { + const int weight_id = args.at(0); + const int scales_id = args.at(1); + const int group_size_id = args.at(2); + const int indices_id = args.at(3); + const int is_linear_weight_id = args.at(4); + const int out_id = args.at(5); + + WGPUDevice device = graph.device(); + + const auto& weight = graph.get_tensor(weight_id); + const auto& scales = graph.get_tensor(scales_id); + const auto& indices = graph.get_tensor(indices_id); + const auto& out = graph.get_tensor(out_id); + + // Only the flat weight path is supported (linear-block unsupported). + bool is_linear = false; + if (graph.get_value_type(is_linear_weight_id) == + WebGPUGraph::ValueType::Bool) { + is_linear = graph.get_bool(is_linear_weight_id); + } else if ( + graph.get_value_type(is_linear_weight_id) == + WebGPUGraph::ValueType::Int) { + is_linear = graph.get_int(is_linear_weight_id) != 0; + } else { + throw std::runtime_error( + "WebGPU embedding_q4gsw: is_linear_weight must be Bool or Int"); + } + if (is_linear) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: is_linear_weight=true is unsupported"); + } + + if (weight.dims.size() < 2 || scales.dims.size() < 2 || out.dims.empty() || + indices.dims.empty()) { + throw std::runtime_error("WebGPU embedding_q4gsw: malformed dims"); + } + + const uint32_t embed_dim = static_cast(out.dims.back()); + if (embed_dim == 0 || embed_dim % 32 != 0) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: embed_dim must be a nonzero multiple of 32"); + } + if (static_cast(weight.dims[1]) * 2 != embed_dim) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: weight row stride mismatch (embed_dim/2)"); + } + + int64_t group_size = 0; + if (graph.get_value_type(group_size_id) == WebGPUGraph::ValueType::Int) { + group_size = graph.get_int(group_size_id); + } + if (group_size <= 0) { + throw std::runtime_error("WebGPU embedding_q4gsw: group_size <= 0"); + } + + // Leading index dims flatten row-major (mirrors Vulkan num_indices). + const uint64_t out_numel = numel_of(out.dims); + const uint32_t num_indices = static_cast(out_numel / embed_dim); + const uint32_t groups_per_row = static_cast(scales.dims[1]); + const uint32_t blocks_per_row = embed_dim / 32u; + const uint32_t bytes_per_row = embed_dim / 2u; + const uint64_t total_blocks = + static_cast(num_indices) * blocks_per_row; + if (static_cast(groups_per_row) * group_size != embed_dim) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: groups_per_row * group_size != embed_dim"); + } + if (weight.buffer == nullptr || scales.buffer == nullptr || + indices.buffer == nullptr || out.buffer == nullptr) { + throw std::runtime_error("WebGPU embedding_q4gsw: null buffer binding"); + } + + // Per-type byte guards (no runtime dtype): indices i32, weight u8, fp32 rest. + const uint64_t indices_numel = numel_of(indices.dims); + const uint64_t weight_numel = numel_of(weight.dims); + const uint64_t scales_numel = numel_of(scales.dims); + if (indices_numel != num_indices || + indices.nbytes != indices_numel * sizeof(int32_t) || + weight.nbytes != weight_numel || + scales.nbytes != scales_numel * sizeof(float) || + out.nbytes != out_numel * sizeof(float)) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: dtype/byte-size mismatch " + "(indices int32, weight uint8, scales/out fp32)"); + } + if (total_blocks > UINT32_MAX) { + throw std::runtime_error( + "WebGPU embedding_q4gsw: total_blocks exceeds uint32 dispatch range"); + } + + // 1D dispatch: one thread per 32-dim block; validate before any alloc. + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kEmbeddingQ4gswWorkgroupSizeX); + const uint32_t workgroup_count = utils::compute_1d_workgroup_count( + device, static_cast(total_blocks), wg_size, "embedding_q4gsw"); + + EmbeddingParams params = {}; + params.embed_dim = embed_dim; + params.blocks_per_row = blocks_per_row; + params.num_indices = num_indices; // std140 layout only; shader derives it + params.group_size = static_cast(group_size); + params.groups_per_row = groups_per_row; + params.bytes_per_row = bytes_per_row; + params.total_blocks = static_cast(total_blocks); + + WGPUBufferDescriptor uniform_desc = {}; + uniform_desc.size = sizeof(EmbeddingParams); + uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + uniform_desc.mappedAtCreation = true; + WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); + void* mapped = + wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(EmbeddingParams)); + std::memcpy(mapped, ¶ms, sizeof(EmbeddingParams)); + wgpuBufferUnmap(uniform_buffer); + graph.add_uniform_buffer_bytes(sizeof(EmbeddingParams)); + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kEmbeddingQ4gswWGSL, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + // Bind group layout: out (rw) + indices/weight/scales (ro storage) + uniform. + WGPUBindGroupLayoutEntry entries[5] = {}; + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_Storage; + for (uint32_t i = 1; i <= 3; i++) { + entries[i].binding = i; + entries[i].visibility = WGPUShaderStage_Compute; + entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + } + entries[4].binding = 4; + entries[4].visibility = WGPUShaderStage_Compute; + entries[4].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 5; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUConstantEntry wg_size_constant = {}; + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; + wg_size_constant.value = static_cast(wg_size); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + pipeline_desc.compute.constantCount = 1; + pipeline_desc.compute.constants = &wg_size_constant; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + WGPUBindGroupEntry bg_entries[5] = {}; + bg_entries[0].binding = 0; + bg_entries[0].buffer = out.buffer; + bg_entries[0].size = out.nbytes; + bg_entries[1].binding = 1; + bg_entries[1].buffer = indices.buffer; + bg_entries[1].size = indices.nbytes; + bg_entries[2].binding = 2; + bg_entries[2].buffer = weight.buffer; + bg_entries[2].size = weight.nbytes; + bg_entries[3].binding = 3; + bg_entries[3].buffer = scales.buffer; + bg_entries[3].size = scales.nbytes; + bg_entries[4].binding = 4; + bg_entries[4].buffer = uniform_buffer; + bg_entries[4].size = sizeof(EmbeddingParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 5; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + graph.add_dispatch( + {pipeline, bind_group, workgroup_count, "embedding_q4gsw"}); + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + wgpuBufferRelease(uniform_buffer); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.embedding_q4gsw.default, embedding_q4gsw_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl new file mode 100644 index 00000000000..f16f3760d1c --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw.wgsl @@ -0,0 +1,50 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_indices: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; + +struct Params { + embed_dim: u32, + blocks_per_row: u32, + num_indices: u32, + group_size: u32, + groups_per_row: u32, + bytes_per_row: u32, + total_blocks: u32, + _pad: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per 32-dim block of one gathered row (flat-buffer weight path). +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let block = gid.x; + if (block >= params.total_blocks) { + return; + } + let indices_idx = block / params.blocks_per_row; + let base_dim = (block % params.blocks_per_row) * 32u; + + // token assumed in-range (mirrors Vulkan; no vocab clamp). + let token = u32(t_indices[indices_idx]); + let row_byte_base = token * params.bytes_per_row; + let out_base = indices_idx * params.embed_dim + base_dim; + + for (var t: u32 = 0u; t < 32u; t = t + 1u) { + let dim = base_dim + t; + let byte_idx = row_byte_base + (dim >> 1u); + let word = t_weight[byte_idx >> 2u]; + let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + var nib: u32; + if ((dim & 1u) == 0u) { + nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + } else { + nib = b & 0x0Fu; // odd dim -> low nibble + } + let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] + let scale = t_scales[token * params.groups_per_row + dim / params.group_size]; + t_out[out_base + t] = q * scale; + } +} diff --git a/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h new file mode 100644 index 00000000000..e44c06a2ac5 --- /dev/null +++ b/backends/webgpu/runtime/ops/embedding_q4gsw/embedding_q4gsw_wgsl.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from embedding_q4gsw.wgsl - DO NOT EDIT. +// wgsl-sha256: 1fec9ed315696a88bb7db6c16454fc80e08ff73b0e39720b54515fda4ee1ef7c +inline constexpr const char* kEmbeddingQ4gswWGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_indices: array; +@group(0) @binding(2) var t_weight: array; +@group(0) @binding(3) var t_scales: array; + +struct Params { + embed_dim: u32, + blocks_per_row: u32, + num_indices: u32, + group_size: u32, + groups_per_row: u32, + bytes_per_row: u32, + total_blocks: u32, + _pad: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per 32-dim block of one gathered row (flat-buffer weight path). +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let block = gid.x; + if (block >= params.total_blocks) { + return; + } + let indices_idx = block / params.blocks_per_row; + let base_dim = (block % params.blocks_per_row) * 32u; + + // token assumed in-range (mirrors Vulkan; no vocab clamp). + let token = u32(t_indices[indices_idx]); + let row_byte_base = token * params.bytes_per_row; + let out_base = indices_idx * params.embed_dim + base_dim; + + for (var t: u32 = 0u; t < 32u; t = t + 1u) { + let dim = base_dim + t; + let byte_idx = row_byte_base + (dim >> 1u); + let word = t_weight[byte_idx >> 2u]; + let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; + var nib: u32; + if ((dim & 1u) == 0u) { + nib = (b >> 4u) & 0x0Fu; // even dim -> high nibble + } else { + nib = b & 0x0Fu; // odd dim -> low nibble + } + let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] + let scale = t_scales[token * params.groups_per_row + dim / params.group_size]; + t_out[out_base + t] = q * scale; + } +} +)"; + +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeX = 64; +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeY = 1; +inline constexpr uint32_t kEmbeddingQ4gswWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu From 56de95a8f693d0cf4e83cb1d29dbc67ee4f354c5 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:07 -0700 Subject: [PATCH 2/6] [ExecuTorch][WebGPU] et_vk.embedding_q4gsw test suite (export + native golden) Pull Request resolved: https://github.com/pytorch/executorch/pull/20289 Splits the `et_vk.embedding_q4gsw` tests into their own diff (op below, tests above), matching the `sdpa`/`update_cache`/`linear_q4gsw` op+tests convention, and brings them to the same rigor: a multi-shape config sweep run on-device, an fp64 dual-oracle, and required-when-present gating. ghstack-source-id: 395549281 @exported-using-ghexport Differential Revision: [D108668383](https://our.internmc.facebook.com/intern/diff/D108668383/) --- .../webgpu/scripts/test_webgpu_native_ci.sh | 18 ++ .../test/ops/embedding_q4gsw/__init__.py | 5 + .../embedding_q4gsw/test_embedding_q4gsw.py | 161 ++++++++++++++++++ backends/webgpu/test/test_webgpu_native.cpp | 141 ++++++++++++++- 4 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 backends/webgpu/test/ops/embedding_q4gsw/__init__.py create mode 100644 backends/webgpu/test/ops/embedding_q4gsw/test_embedding_q4gsw.py diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index ba6a48c62be..e4cf460f13f 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -45,12 +45,24 @@ DISPATCH_ORDER_DIR="/tmp/dispatch_order" DISPATCH_ORDER_OK=1 UPDATE_CACHE_DIR="/tmp/update_cache" UPDATE_CACHE_OK=1 +EMBEDDING_MODEL="/tmp/webgpu_embedding_q4gsw.pte" +EMBEDDING_INDICES="/tmp/webgpu_embedding_q4gsw_indices.bin" +EMBEDDING_GOLDEN="/tmp/webgpu_embedding_q4gsw_golden.bin" +EMBEDDING_LLAMA1B_MODEL="/tmp/webgpu_embedding_q4gsw_llama1b.pte" +EMBEDDING_LLAMA1B_INDICES="/tmp/webgpu_embedding_q4gsw_llama1b_indices.bin" +EMBEDDING_LLAMA1B_GOLDEN="/tmp/webgpu_embedding_q4gsw_llama1b_golden.bin" $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import export_all_quantized_linear_models export_all_quantized_linear_models('/tmp') " || echo "WARN: q4gsw export failed; required configs will FAIL in webgpu_native_test" +$PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.embedding_q4gsw.test_embedding_q4gsw import export_embedding_q4gsw_model +export_embedding_q4gsw_model('${EMBEDDING_MODEL}', '${EMBEDDING_GOLDEN}', '${EMBEDDING_INDICES}') +export_embedding_q4gsw_model('${EMBEDDING_LLAMA1B_MODEL}', '${EMBEDDING_LLAMA1B_GOLDEN}', '${EMBEDDING_LLAMA1B_INDICES}', 'llama1b') +" || echo "WARN: embedding_q4gsw export failed; embedding configs will FAIL in webgpu_native_test" + $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.dispatch_order.test_dispatch_order import export_dispatch_order_cases export_dispatch_order_cases('${DISPATCH_ORDER_DIR}') @@ -136,6 +148,12 @@ if [[ -x "${BIN_DIR}/webgpu_native_test" ]] && "${PYTHON_EXECUTABLE}" -c "import executorch" 2>/dev/null; then env WEBGPU_TEST_SDPA_DIR=/tmp/ \ WEBGPU_TEST_QUANTIZED_LINEAR_DIR=/tmp/ \ + WEBGPU_TEST_EMBEDDING_Q4GSW_MODEL="${EMBEDDING_MODEL}" \ + WEBGPU_TEST_EMBEDDING_Q4GSW_INDICES="${EMBEDDING_INDICES}" \ + WEBGPU_TEST_EMBEDDING_Q4GSW_GOLDEN="${EMBEDDING_GOLDEN}" \ + WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_MODEL="${EMBEDDING_LLAMA1B_MODEL}" \ + WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_INDICES="${EMBEDDING_LLAMA1B_INDICES}" \ + WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_GOLDEN="${EMBEDDING_LLAMA1B_GOLDEN}" \ "${BIN_DIR}/webgpu_native_test" else echo "(skipping webgpu_native_test: executorch wheel absent — exports did not run)" diff --git a/backends/webgpu/test/ops/embedding_q4gsw/__init__.py b/backends/webgpu/test/ops/embedding_q4gsw/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/webgpu/test/ops/embedding_q4gsw/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/webgpu/test/ops/embedding_q4gsw/test_embedding_q4gsw.py b/backends/webgpu/test/ops/embedding_q4gsw/test_embedding_q4gsw.py new file mode 100644 index 00000000000..120e987478b --- /dev/null +++ b/backends/webgpu/test/ops/embedding_q4gsw/test_embedding_q4gsw.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""4-bit groupwise-symmetric quantized embedding (`et_vk.embedding_q4gsw`) export ++ golden for the WebGPU backend. + +Quantizes an nn.Embedding with the Llama EmbeddingQuantHandler recipe (int4 +groupwise-symmetric, packed) which lowers to `quantized_decomposed.embedding_4bit` +and fuses under VulkanPartitioner into `et_vk.embedding_q4gsw.default` +(is_linear_weight=False). Writes a torch-computed golden (the native binary has no +ATen) via the registered et_vk reference op + the raw int32 indices the native +test loads and compares. + +Two shapes are exercised: a tiny one and a Llama-3.2-1B-scale one (EMBED=2048, +GROUP=64) so the per-group scale indexing (32 groups/row) + dequant are validated +at the real embedding dim, not just a single 64-wide row. +""" + +import unittest +from collections import namedtuple + +import executorch.backends.vulkan.custom_ops_lib # noqa: F401 + +import torch +from executorch.backends.vulkan import VulkanPartitioner +from executorch.examples.models.llama.source_transformation.quantize import ( + EmbeddingQuantHandler, +) +from executorch.exir import to_edge_transform_and_lower + +# vocab rows, embed columns (embed % 32 == 0), group-wise scales, gather indices. +Shape = namedtuple("Shape", ["name", "vocab", "embed", "group", "indices"]) +SHAPES = [ + Shape("small", 64, 64, 32, [1, 5, 63, 0]), + # Llama-3.2-1B embedding dim + group (small vocab keeps the export light). + Shape("llama1b", 512, 2048, 64, [1, 5, 511, 0]), +] + + +class _EmbeddingModel(torch.nn.Module): + def __init__(self, vocab: int, embed: int) -> None: + super().__init__() + self.emb = torch.nn.Embedding(vocab, embed) + + def forward(self, idx: torch.Tensor) -> torch.Tensor: + return self.emb(idx) + + +def _make_quantized_model(shape: Shape) -> torch.nn.Module: + torch.manual_seed(0) + return ( + EmbeddingQuantHandler( + _EmbeddingModel(shape.vocab, shape.embed).eval(), + device="cpu", + bitwidth=4, + group_size=shape.group, + packed=True, + quantize_with_hqq=False, + ) + .quantized_model() + .eval() + ) + + +def _indices(shape: Shape) -> torch.Tensor: + return torch.tensor(shape.indices, dtype=torch.long) + + +def _quant_params(qm: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor, int]: + sd = qm.state_dict() + weight = next( + v for k, v in sd.items() if k.endswith("weight") and v.dtype == torch.uint8 + ) + scales = next(v for k, v in sd.items() if k.endswith("scales")) + if scales.ndim == 1: + scales = scales.unsqueeze(1) + embed = weight.shape[1] * 2 + group_size = embed // scales.shape[1] + return weight, scales, group_size + + +def _golden(qm: torch.nn.Module, idx: torch.Tensor) -> torch.Tensor: + # Reference = the registered et_vk dequant+gather op (non-linear branch). + weight, scales, group_size = _quant_params(qm) + return torch.ops.et_vk.embedding_q4gsw.default( + weight, scales, group_size, idx, False + ) + + +def _export(qm: torch.nn.Module, idx: torch.Tensor): + ep = torch.export.export(qm, (idx,)) + return to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + +class TestEmbeddingQ4gsw(unittest.TestCase): + def test_export_delegates(self) -> None: + for shape in SHAPES: + with self.subTest(shape=shape.name): + et = _export(_make_quantized_model(shape), _indices(shape)) + found = any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ) + self.assertTrue( + found, "Expected a VulkanBackend delegate (embedding_q4gsw fusion)" + ) + + def test_golden_matches_eager(self) -> None: + # The torch golden (et_vk reference) must equal torch dequant+gather, so a + # buggy golden can't fake-pass the native kernel. Run at both shapes so the + # Llama-scale per-group scale indexing (32 groups/row) is covered. + for shape in SHAPES: + with self.subTest(shape=shape.name): + qm = _make_quantized_model(shape) + idx = _indices(shape) + weight, scales, group_size = _quant_params(qm) + vocab = weight.shape[0] + embed = weight.shape[1] * 2 + # fp64 reference dequant, vectorized (no fp32 rounding in oracle). + w = weight.to(torch.int64) + nib = torch.empty((vocab, embed), dtype=torch.int64) + nib[:, 0::2] = (w >> 4) & 0xF # even dim -> high nibble + nib[:, 1::2] = w & 0xF # odd dim -> low nibble + scale_exp = scales.to(torch.float64).repeat_interleave( + group_size, dim=1 + ) + deq = (nib - 8).to(torch.float64) * scale_exp + eager = torch.nn.functional.embedding(idx, deq) + golden = _golden(qm, idx) + torch.testing.assert_close(golden.double(), eager, atol=1e-5, rtol=1e-5) + + +def export_embedding_q4gsw_model( + pte_path: str, golden_path: str, indices_path: str, shape_name: str = "small" +) -> None: + """Write the embedding_q4gsw .pte + torch golden (raw LE fp32) + raw LE int32 + indices (downcast from int64 for the int32-typed GPU buffer). `shape_name` + selects an entry from SHAPES (default the tiny shape; "llama1b" = EMBED=2048).""" + shape = next(s for s in SHAPES if s.name == shape_name) + qm = _make_quantized_model(shape) + idx = _indices(shape) + golden = _golden(qm, idx).detach().numpy().astype("((i % 17) - 8) / 16.0f; } -// Per-element dual tolerance (abs OR rel), parameterized like sdpa_within_tol. +// Fwd decl of the per-element abs-OR-rel tolerance helper (defined below). +static bool quant_within_tol( + const float* out, + const float* golden, + int n, + float atol, + float rtol, + float* ma, + float* mr); + +static std::vector load_indices( + const std::string& path, + size_t numel) { + // Load raw little-endian int32 indices written by the export .py. + std::vector g(numel); + FILE* f = std::fopen(path.c_str(), "rb"); + if (!f) { + return {}; + } + size_t n = std::fread(g.data(), sizeof(int32_t), numel, f); + std::fclose(f); + if (n != numel) { + return {}; + } + return g; +} + +static bool test_embedding_q4gsw( + const std::string& model_path, + const std::string& indices_path, + const std::string& golden_path, + int num_indices, + int embed, + const char* label) { + // q4gsw embedding-gather vs torch golden; shapes per test_embedding_q4gsw.py. + const int out_numel = num_indices * embed; + printf( + "\n--- Test: embedding_q4gsw (%s: indices=%d, embed=%d) ---\n", + label, + num_indices, + embed); + + Module module(model_path); + auto err = module.load_forward(); + if (err != Error::Ok) { + printf("FAIL: could not load forward method (error %d)\n", (int)err); + return false; + } + printf("Model loaded: %s\n", model_path.c_str()); + + std::vector idx32 = load_indices(indices_path, num_indices); + std::vector golden = load_golden(golden_path, out_numel); + if (idx32.empty() || golden.empty()) { + printf( + "FAIL: could not load indices %s / golden %s\n", + indices_path.c_str(), + golden_path.c_str()); + return false; + } + + // int64 at the program boundary; copy_inputs narrows to the int32 buffer. + std::vector idx64(idx32.begin(), idx32.end()); + auto idx = make_tensor_ptr({num_indices}, std::move(idx64)); + + auto result = module.forward({EValue(idx)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + const auto& outputs = result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + printf("FAIL: no tensor output\n"); + return false; + } + const auto& out_tensor = outputs[0].toTensor(); + if (out_tensor.numel() != out_numel) { + printf( + "FAIL: output numel %zu != expected %d\n", + (size_t)out_tensor.numel(), + out_numel); + return false; + } + const float* out_data = out_tensor.const_data_ptr(); + + float max_abs_err = 0.0f, max_rel_err = 0.0f; + const bool pass = quant_within_tol( + out_data, + golden.data(), + out_numel, + 1e-3f, + 1e-3f, + &max_abs_err, + &max_rel_err); + printf( + "Max abs error: %e Max rel error: %e (checked %d elements)\n", + max_abs_err, + max_rel_err, + out_numel); + if (!pass) { + printf("FAIL: embedding_q4gsw exceeds tolerance 1e-3 (abs AND rel)\n"); + return false; + } + printf("PASS: embedding_q4gsw test\n"); + return true; +} + static bool quant_within_tol( const float* out, const float* golden, @@ -1342,6 +1447,31 @@ int main(int argc, char** argv) { } } + // embedding_q4gsw on-GPU configs: small + llama1b (env-gated, + // run-if-present). + struct EmbConfig { + const char* name; + const char* model_env; + const char* indices_env; + const char* golden_env; + int num_indices; + int embed; + }; + const EmbConfig emb_configs[] = { + {"small", + "WEBGPU_TEST_EMBEDDING_Q4GSW_MODEL", + "WEBGPU_TEST_EMBEDDING_Q4GSW_INDICES", + "WEBGPU_TEST_EMBEDDING_Q4GSW_GOLDEN", + 4, + 64}, + {"llama1b", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_MODEL", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_INDICES", + "WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_GOLDEN", + 4, + 2048}, + }; + // SDPA sweep: configs self-discover their sdpa_.pte/.golden.bin under // this directory (default "" = the embedded-file root / cwd). Set // WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/). @@ -1389,6 +1519,15 @@ int main(int argc, char** argv) { ok = false; } + for (const auto& c : emb_configs) { + const char* m = std::getenv(c.model_env); + const char* ip = std::getenv(c.indices_env); + const char* g = std::getenv(c.golden_env); + if (m && ip && g && *m && *ip && *g) { + ok = test_embedding_q4gsw(m, ip, g, c.num_indices, c.embed, c.name) && ok; + } + } + bool sdpa_ran = false; bool sdpa_ok = test_sdpa_sweep(sdpa_dir, &sdpa_ran); if (sdpa_ran) { From 96289c189220c38d6901b28c6760cf8d611b625f Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:07 -0700 Subject: [PATCH 3/6] [ExecuTorch][WebGPU] Add et_vk.apply_rotary_emb (interleaved RoPE) + ValueList multi-output Pull Request resolved: https://github.com/pytorch/executorch/pull/20264 Adds the WebGPU backend handler for `et_vk.apply_rotary_emb.default` (interleaved Llama rotary positional embedding) plus the `ValueList` graph-value support its multi-output signature requires. The op rotates the query and key tensors by a shared `freqs_cos`/`freqs_sin` pair and is composed of two dispatches of one WGSL kernel: each thread handles one (even, odd) element pair of a head row (`out[2i] = x[2i]*cos - x[2i+1]*sin`, `out[2i+1] = x[2i]*sin + x[2i+1]*cos`), one dispatch writing `xq_out` and one writing `xk_out`, mirroring the Vulkan `apply_rotary_emb` reference (buffer-only, fp32, the interleaved `.default` variant). Each dispatch owns a distinct compute pipeline (the graph destructor releases per dispatch, so a shared handle would double-free); the workgroup size is a `wg_size` pipeline-override constant clamped to the device limit, both 1D dispatch counts go through `WebGPUUtils::compute_1d_workgroup_count` and are validated before any GPU-object allocation, and the embedded WGSL header is generated by `gen_wgsl_headers.py`. The two outputs (`xq_out`, `xk_out`) are serialized by the Vulkan exporter as a single `ValueList` graph value, which the runtime did not previously model. This adds the `ValueType::ValueList` value kind, a `value_lists_` table populated during `build()`, and a `get_value_list` accessor the handler uses to resolve the output ids. While in that code path it also closes a latent gap: a constant tensor whose `constant_id` is set but whose constants table is missing or out of range now throws (fail-loud) rather than silently leaving the buffer uninitialized. ghstack-source-id: 395549282 @exported-using-ghexport Differential Revision: [D108428756](https://our.internmc.facebook.com/intern/diff/D108428756/) --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUGraph.cpp | 19 ++ backends/webgpu/runtime/WebGPUGraph.h | 16 +- .../runtime/ops/rope/RotaryEmbedding.cpp | 288 ++++++++++++++++++ .../runtime/ops/rope/rotary_embedding.wgsl | 46 +++ .../runtime/ops/rope/rotary_embedding_wgsl.h | 70 +++++ 6 files changed, 439 insertions(+), 1 deletion(-) create mode 100644 backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp create mode 100644 backends/webgpu/runtime/ops/rope/rotary_embedding.wgsl create mode 100644 backends/webgpu/runtime/ops/rope/rotary_embedding_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 1eb75128997..8bf1674d872 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -40,6 +40,7 @@ set(WEBGPU_SRCS runtime/ops/quantized_linear/QuantizedLinear.cpp runtime/ops/mul/BinaryOp.cpp runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp + runtime/ops/rope/RotaryEmbedding.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index d3974eab194..65aaaf6c681 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -239,6 +239,7 @@ void WebGPUGraph::build( ints_.resize(num_vals, 0); doubles_.resize(num_vals, 0.0); bools_.resize(num_vals, false); + value_lists_.resize(num_vals); for (int i = 0; i < num_vals; i++) { const auto* val = values->Get(i); @@ -313,7 +314,15 @@ void WebGPUGraph::build( throw std::runtime_error( "WebGPU: constant has no inline offset and no named-data key"); } + } else { + throw std::runtime_error( + "WebGPU: constant_id set but the constants table is missing " + "or the id is out of range"); } + } else if (constant_id >= 0 && tensor.nbytes > 0) { + // constant_id set but constant_data null -> fail loud. + throw std::runtime_error( + "WebGPU: constant_id set but constant_data is null"); } } else { // Shared buffer: track required size, defer allocation to pass 2 @@ -363,6 +372,16 @@ void WebGPUGraph::build( add_uniform_buffer_bytes(kSymIntUniformBytes); break; } + case vkgraph::GraphTypes::ValueList: { + value_types_[i] = ValueType::ValueList; + const auto* items = val->value_as_ValueList()->items(); + if (items) { + for (unsigned j = 0; j < items->size(); j++) { + value_lists_[i].push_back(static_cast(items->Get(j))); + } + } + break; + } default: value_types_[i] = ValueType::Null; break; diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index 5bd5b93b524..a914c8710ce 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -119,6 +119,10 @@ class WebGPUGraph { bool get_bool(int id) const { return bools_[id]; } + // Member value ids of a serialized ValueList (op multi-output list). + const std::vector& get_value_list(int id) const { + return value_lists_[id]; + } // Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO. // set_symint writes the buffer + marks dirty only if the value changed. @@ -215,7 +219,16 @@ class WebGPUGraph { return static_cast(value_types_.size()); } - enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt }; + enum class ValueType { + Tensor, + Int, + Double, + Bool, + Null, + String, + SymInt, + ValueList + }; ValueType get_value_type(int id) const { return value_types_[id]; @@ -233,6 +246,7 @@ class WebGPUGraph { std::vector ints_; std::vector doubles_; std::vector bools_; + std::vector> value_lists_; // SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse. struct SymIntSlot { diff --git a/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp b/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp new file mode 100644 index 00000000000..cf4fa0a1ca2 --- /dev/null +++ b/backends/webgpu/runtime/ops/rope/RotaryEmbedding.cpp @@ -0,0 +1,288 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Uniform layout matching the WGSL Params struct (16-byte aligned, 32 bytes). +struct RotaryParams { + uint32_t n_heads; + uint32_t seq; + uint32_t head_dim; + uint32_t half_dim; + uint32_t num_pairs; + uint32_t _pad0; + uint32_t _pad1; + uint32_t _pad2; +}; +static_assert(sizeof(RotaryParams) == 32, "RotaryParams must be 32 bytes"); + +uint64_t numel_of(const std::vector& dims) { + uint64_t n = 1; + for (int64_t d : dims) { + n *= static_cast(d); + } + return n; +} + +// Rotate one (x->out) with the shared shader; freqs shared between xq and xk. +void add_rope_dispatch( + WebGPUGraph& graph, + WGPUDevice device, + WGPUComputePipeline pipeline, + WGPUBindGroupLayout bgl, + const WebGPUTensor& x, + const WebGPUTensor& out, + const WebGPUTensor& freqs_cos, + const WebGPUTensor& freqs_sin, + uint32_t n_heads, + uint32_t seq, + uint32_t head_dim, + uint32_t workgroup_count) { + const uint32_t half_dim = head_dim / 2u; + // out.dims == in.dims (asserted in impl), so this matches the caller's wgc. + const uint32_t num_pairs = static_cast(numel_of(out.dims) / 2u); + + RotaryParams params = {}; + params.n_heads = n_heads; + params.seq = seq; + params.head_dim = head_dim; + params.half_dim = half_dim; + params.num_pairs = num_pairs; + + WGPUBufferDescriptor uniform_desc = {}; + uniform_desc.size = sizeof(RotaryParams); + uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst; + uniform_desc.mappedAtCreation = true; + WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc); + void* mapped = + wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RotaryParams)); + std::memcpy(mapped, ¶ms, sizeof(RotaryParams)); + wgpuBufferUnmap(uniform_buffer); + graph.add_uniform_buffer_bytes(sizeof(RotaryParams)); + + WGPUBindGroupEntry bg_entries[5] = {}; + bg_entries[0].binding = 0; + bg_entries[0].buffer = out.buffer; + bg_entries[0].size = out.nbytes; + bg_entries[1].binding = 1; + bg_entries[1].buffer = x.buffer; + bg_entries[1].size = x.nbytes; + bg_entries[2].binding = 2; + bg_entries[2].buffer = freqs_cos.buffer; + bg_entries[2].size = freqs_cos.nbytes; + bg_entries[3].binding = 3; + bg_entries[3].buffer = freqs_sin.buffer; + bg_entries[3].size = freqs_sin.nbytes; + bg_entries[4].binding = 4; + bg_entries[4].buffer = uniform_buffer; + bg_entries[4].size = sizeof(RotaryParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 5; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + graph.add_dispatch( + {pipeline, bind_group, workgroup_count, "apply_rotary_emb"}); + + wgpuBufferRelease(uniform_buffer); +} + +// args: [xq, xk, freqs_cos, freqs_sin, out_list(ValueList[xq_out, xk_out])]. +void apply_rotary_emb_impl(WebGPUGraph& graph, const std::vector& args) { + const int xq_id = args.at(0); + const int xk_id = args.at(1); + const int freqs_cos_id = args.at(2); + const int freqs_sin_id = args.at(3); + + const std::vector& out_list = graph.get_value_list(args.at(4)); + if (out_list.size() != 2) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: expected an output ValueList of size 2"); + } + + WGPUDevice device = graph.device(); + + const auto& xq = graph.get_tensor(xq_id); + const auto& xk = graph.get_tensor(xk_id); + const auto& freqs_cos = graph.get_tensor(freqs_cos_id); + const auto& freqs_sin = graph.get_tensor(freqs_sin_id); + const auto& xq_out = graph.get_tensor(out_list[0]); + const auto& xk_out = graph.get_tensor(out_list[1]); + + // Vulkan shape contract: xq/xk (B,S,n_heads,head_dim), freqs (S,head_dim/2). + if (xq.dims.size() < 3 || xk.dims.size() < 3 || freqs_cos.dims.size() < 2) { + throw std::runtime_error("WebGPU apply_rotary_emb: malformed dims"); + } + const uint32_t head_dim = static_cast(xq.dims.back()); + const uint32_t seq = static_cast(xq.dims[xq.dims.size() - 3]); + const uint32_t n_heads_q = static_cast(xq.dims[xq.dims.size() - 2]); + const uint32_t n_heads_k = static_cast(xk.dims[xk.dims.size() - 2]); + const uint32_t seq_k = static_cast(xk.dims[xk.dims.size() - 3]); + const uint32_t half_dim = static_cast(freqs_cos.dims.back()); + + if (head_dim == 0 || head_dim % 2 != 0) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: head_dim must be a nonzero multiple of 2"); + } + if (static_cast(xk.dims.back()) != head_dim || seq_k != seq) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: xq/xk head_dim and seq must match"); + } + if (half_dim * 2u != head_dim) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: head_dim != 2 * freqs_cos last dim"); + } + if (freqs_cos.dims != freqs_sin.dims) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: freqs_cos and freqs_sin shapes differ"); + } + + if (xq.buffer == nullptr || xk.buffer == nullptr || + freqs_cos.buffer == nullptr || freqs_sin.buffer == nullptr || + xq_out.buffer == nullptr || xk_out.buffer == nullptr) { + throw std::runtime_error("WebGPU apply_rotary_emb: null buffer binding"); + } + + // All tensors are fp32; output shapes equal their inputs. + const uint64_t xq_numel = numel_of(xq.dims); + const uint64_t xk_numel = numel_of(xk.dims); + const uint64_t freqs_numel = numel_of(freqs_cos.dims); + if (freqs_numel != static_cast(seq) * half_dim || + xq.nbytes != xq_numel * sizeof(float) || + xk.nbytes != xk_numel * sizeof(float) || + freqs_cos.nbytes != freqs_numel * sizeof(float) || + freqs_sin.nbytes != freqs_numel * sizeof(float) || + xq_out.nbytes != xq_numel * sizeof(float) || + xk_out.nbytes != xk_numel * sizeof(float)) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: dtype/byte-size mismatch (all fp32) or " + "freqs shape != [seq, head_dim/2]"); + } + + if (xq_numel / 2u > UINT32_MAX || xk_numel / 2u > UINT32_MAX) { + throw std::runtime_error( + "WebGPU apply_rotary_emb: pair count exceeds uint32 dispatch range"); + } + + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kRotaryEmbeddingWorkgroupSizeX); + // Validate both dispatches before any GPU-object alloc (no leak on throw). + const uint32_t xq_wgc = utils::compute_1d_workgroup_count( + device, + static_cast(xq_numel / 2u), + wg_size, + "apply_rotary_emb"); + const uint32_t xk_wgc = utils::compute_1d_workgroup_count( + device, + static_cast(xk_numel / 2u), + wg_size, + "apply_rotary_emb"); + + WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kRotaryEmbeddingWGSL, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + // Bind group: out (rw) + in/freqs_cos/freqs_sin (ro) + uniform. + WGPUBindGroupLayoutEntry entries[5] = {}; + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_Storage; + for (uint32_t i = 1; i <= 3; i++) { + entries[i].binding = i; + entries[i].visibility = WGPUShaderStage_Compute; + entries[i].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + } + entries[4].binding = 4; + entries[4].visibility = WGPUShaderStage_Compute; + entries[4].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 5; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + WGPUConstantEntry wg_size_constant = {}; + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; + wg_size_constant.value = static_cast(wg_size); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + pipeline_desc.compute.constantCount = 1; + pipeline_desc.compute.constants = &wg_size_constant; + // One pipeline per dispatch; a shared handle would double-free. + WGPUComputePipeline pipeline_q = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + WGPUComputePipeline pipeline_k = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + add_rope_dispatch( + graph, + device, + pipeline_q, + bgl, + xq, + xq_out, + freqs_cos, + freqs_sin, + n_heads_q, + seq, + head_dim, + xq_wgc); + add_rope_dispatch( + graph, + device, + pipeline_k, + bgl, + xk, + xk_out, + freqs_cos, + freqs_sin, + n_heads_k, + seq, + head_dim, + xk_wgc); + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + // pipeline_q/pipeline_k owned by their dispatches; graph dtor frees. +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/rope/rotary_embedding.wgsl b/backends/webgpu/runtime/ops/rope/rotary_embedding.wgsl new file mode 100644 index 00000000000..11c52b2a6db --- /dev/null +++ b/backends/webgpu/runtime/ops/rope/rotary_embedding.wgsl @@ -0,0 +1,46 @@ +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_in: array; +@group(0) @binding(2) var t_freqs_cos: array; +@group(0) @binding(3) var t_freqs_sin: array; + +struct Params { + n_heads: u32, + seq: u32, + head_dim: u32, + half_dim: u32, + num_pairs: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per (even,odd) pair; interleaved Llama RoPE, shared xq/xk shader. +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let pair = gid.x; + if (pair >= params.num_pairs) { + return; + } + let half_dim = params.half_dim; + let pair_i = pair % half_dim; + let t1 = pair / half_dim; + let head = t1 % params.n_heads; + let t2 = t1 / params.n_heads; + let s = t2 % params.seq; + let b = t2 / params.seq; + + let base = + (((b * params.seq + s) * params.n_heads + head) * params.head_dim) + + 2u * pair_i; + let freqs_idx = s * half_dim + pair_i; + + let c = t_freqs_cos[freqs_idx]; + let si = t_freqs_sin[freqs_idx]; + let x_r = t_in[base]; + let x_i = t_in[base + 1u]; + t_out[base] = x_r * c - x_i * si; + t_out[base + 1u] = x_r * si + x_i * c; +} diff --git a/backends/webgpu/runtime/ops/rope/rotary_embedding_wgsl.h b/backends/webgpu/runtime/ops/rope/rotary_embedding_wgsl.h new file mode 100644 index 00000000000..b369fe9cdfb --- /dev/null +++ b/backends/webgpu/runtime/ops/rope/rotary_embedding_wgsl.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from rotary_embedding.wgsl - DO NOT EDIT. +// wgsl-sha256: c60f1ce1c214864bf577617e560404e8b4cc6750c3e96874559ab6bfc1f17ad6 +inline constexpr const char* kRotaryEmbeddingWGSL = R"( +@group(0) @binding(0) var t_out: array; +@group(0) @binding(1) var t_in: array; +@group(0) @binding(2) var t_freqs_cos: array; +@group(0) @binding(3) var t_freqs_sin: array; + +struct Params { + n_heads: u32, + seq: u32, + head_dim: u32, + half_dim: u32, + num_pairs: u32, + _pad0: u32, + _pad1: u32, + _pad2: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +// One thread per (even,odd) pair; interleaved Llama RoPE, shared xq/xk shader. +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let pair = gid.x; + if (pair >= params.num_pairs) { + return; + } + let half_dim = params.half_dim; + let pair_i = pair % half_dim; + let t1 = pair / half_dim; + let head = t1 % params.n_heads; + let t2 = t1 / params.n_heads; + let s = t2 % params.seq; + let b = t2 / params.seq; + + let base = + (((b * params.seq + s) * params.n_heads + head) * params.head_dim) + + 2u * pair_i; + let freqs_idx = s * half_dim + pair_i; + + let c = t_freqs_cos[freqs_idx]; + let si = t_freqs_sin[freqs_idx]; + let x_r = t_in[base]; + let x_i = t_in[base + 1u]; + t_out[base] = x_r * c - x_i * si; + t_out[base + 1u] = x_r * si + x_i * c; +} +)"; + +inline constexpr uint32_t kRotaryEmbeddingWorkgroupSizeX = 64; +inline constexpr uint32_t kRotaryEmbeddingWorkgroupSizeY = 1; +inline constexpr uint32_t kRotaryEmbeddingWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu From 200c506cf10abf7dbb9f6ad6c5d3d7f465d4dc8a Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:08 -0700 Subject: [PATCH 4/6] [ExecuTorch][WebGPU] et_vk.apply_rotary_emb test suite (export + native golden) Pull Request resolved: https://github.com/pytorch/executorch/pull/20290 Splits the `et_vk.apply_rotary_emb` tests into their own diff (op below, tests above), matching the `sdpa`/`update_cache`/`linear_q4gsw` convention, and brings them to the same rigor: a multi-shape config sweep run on-device (prefill + decode) and a library dual-oracle at both shapes. ghstack-source-id: 395549287 @exported-using-ghexport Differential Revision: [D108668384](https://our.internmc.facebook.com/intern/diff/D108668384/) --- .../webgpu/scripts/test_webgpu_native_ci.sh | 18 +++ backends/webgpu/test/ops/rope/__init__.py | 5 + backends/webgpu/test/ops/rope/test_rope.py | 129 +++++++++++++++ backends/webgpu/test/test_webgpu_native.cpp | 151 ++++++++++++++++++ 4 files changed, 303 insertions(+) create mode 100644 backends/webgpu/test/ops/rope/__init__.py create mode 100644 backends/webgpu/test/ops/rope/test_rope.py diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index e4cf460f13f..100e48dfbfd 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -51,6 +51,12 @@ EMBEDDING_GOLDEN="/tmp/webgpu_embedding_q4gsw_golden.bin" EMBEDDING_LLAMA1B_MODEL="/tmp/webgpu_embedding_q4gsw_llama1b.pte" EMBEDDING_LLAMA1B_INDICES="/tmp/webgpu_embedding_q4gsw_llama1b_indices.bin" EMBEDDING_LLAMA1B_GOLDEN="/tmp/webgpu_embedding_q4gsw_llama1b_golden.bin" +ROPE_MODEL="/tmp/webgpu_rope.pte" +ROPE_XQ_GOLDEN="/tmp/webgpu_rope_xq_golden.bin" +ROPE_XK_GOLDEN="/tmp/webgpu_rope_xk_golden.bin" +ROPE_DECODE_MODEL="/tmp/webgpu_rope_decode.pte" +ROPE_DECODE_XQ_GOLDEN="/tmp/webgpu_rope_decode_xq_golden.bin" +ROPE_DECODE_XK_GOLDEN="/tmp/webgpu_rope_decode_xk_golden.bin" $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import export_all_quantized_linear_models @@ -63,6 +69,12 @@ export_embedding_q4gsw_model('${EMBEDDING_MODEL}', '${EMBEDDING_GOLDEN}', '${EMB export_embedding_q4gsw_model('${EMBEDDING_LLAMA1B_MODEL}', '${EMBEDDING_LLAMA1B_GOLDEN}', '${EMBEDDING_LLAMA1B_INDICES}', 'llama1b') " || echo "WARN: embedding_q4gsw export failed; embedding configs will FAIL in webgpu_native_test" +$PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.rope.test_rope import export_rope_model +export_rope_model('${ROPE_MODEL}', '${ROPE_XQ_GOLDEN}', '${ROPE_XK_GOLDEN}') +export_rope_model('${ROPE_DECODE_MODEL}', '${ROPE_DECODE_XQ_GOLDEN}', '${ROPE_DECODE_XK_GOLDEN}', 'decode') +" || echo "WARN: rope export failed; apply_rotary_emb configs will FAIL in webgpu_native_test" + $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.dispatch_order.test_dispatch_order import export_dispatch_order_cases export_dispatch_order_cases('${DISPATCH_ORDER_DIR}') @@ -154,6 +166,12 @@ if [[ -x "${BIN_DIR}/webgpu_native_test" ]] && WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_MODEL="${EMBEDDING_LLAMA1B_MODEL}" \ WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_INDICES="${EMBEDDING_LLAMA1B_INDICES}" \ WEBGPU_TEST_EMBEDDING_Q4GSW_LLAMA1B_GOLDEN="${EMBEDDING_LLAMA1B_GOLDEN}" \ + WEBGPU_TEST_ROPE_MODEL="${ROPE_MODEL}" \ + WEBGPU_TEST_ROPE_XQ_GOLDEN="${ROPE_XQ_GOLDEN}" \ + WEBGPU_TEST_ROPE_XK_GOLDEN="${ROPE_XK_GOLDEN}" \ + WEBGPU_TEST_ROPE_DECODE_MODEL="${ROPE_DECODE_MODEL}" \ + WEBGPU_TEST_ROPE_DECODE_XQ_GOLDEN="${ROPE_DECODE_XQ_GOLDEN}" \ + WEBGPU_TEST_ROPE_DECODE_XK_GOLDEN="${ROPE_DECODE_XK_GOLDEN}" \ "${BIN_DIR}/webgpu_native_test" else echo "(skipping webgpu_native_test: executorch wheel absent — exports did not run)" diff --git a/backends/webgpu/test/ops/rope/__init__.py b/backends/webgpu/test/ops/rope/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/webgpu/test/ops/rope/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/webgpu/test/ops/rope/test_rope.py b/backends/webgpu/test/ops/rope/test_rope.py new file mode 100644 index 00000000000..136beb9d8c8 --- /dev/null +++ b/backends/webgpu/test/ops/rope/test_rope.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Interleaved rotary positional embedding (`et_vk.apply_rotary_emb`) export + +goldens for the WebGPU backend. + +Exports the Llama interleaved RoPE (use_hf_rope=False) with freqs_cos/freqs_sin +as runtime forward inputs (no constant prepack), which fuses under +VulkanPartitioner into `et_vk.apply_rotary_emb.default` (two outputs xq_out, +xk_out serialized as a ValueList). Inputs are deterministic /16 ramps so the +native binary reconstructs them bit-for-bit; the two torch-computed goldens are +written for the native binary to compare (it has no ATen). + +Two shapes are exercised: a multi-token prefill shape and a single-token (S=1) +decode shape at the Llama-3.2-1B head config (GQA 32:8), so the seq=1 / batch +decompositions and the position->freqs indexing are covered at decode too. +""" + +import unittest +from collections import namedtuple + +import executorch.backends.vulkan.custom_ops_lib # noqa: F401 + +import torch +from executorch.backends.vulkan import VulkanPartitioner +from executorch.examples.models.llama.rope import apply_rotary_emb, RotaryEmbedding +from executorch.exir import to_edge_transform_and_lower + +# B batch, S tokens, NH query heads, NKV kv heads (NH != NKV so the two outputs +# are distinguishable by numel), HD head dim (even; HD/2 rotation pairs). +Shape = namedtuple("Shape", ["name", "b", "s", "nh", "nkv", "hd"]) +SHAPES = [ + Shape("multi", 1, 5, 8, 2, 64), + # Single-token decode at Llama-3.2-1B head config (GQA 32:8, head_dim 64). + Shape("decode", 1, 1, 32, 8, 64), +] + + +def _ramp(numel: int, mod: int, off: int) -> torch.Tensor: + # ((i % mod) - off) / 16: exact in fp32, matches test_webgpu_native.cpp. + idx = torch.arange(numel, dtype=torch.int64) + return ((idx % mod) - off).to(torch.float32) / 16.0 + + +def _inputs( + shape: Shape, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + xq = _ramp(shape.b * shape.s * shape.nh * shape.hd, 17, 8).reshape( + shape.b, shape.s, shape.nh, shape.hd + ) + xk = _ramp(shape.b * shape.s * shape.nkv * shape.hd, 13, 6).reshape( + shape.b, shape.s, shape.nkv, shape.hd + ) + freqs_cos = _ramp(shape.s * (shape.hd // 2), 11, 5).reshape(shape.s, shape.hd // 2) + freqs_sin = _ramp(shape.s * (shape.hd // 2), 7, 3).reshape(shape.s, shape.hd // 2) + return xq, xk, freqs_cos, freqs_sin + + +def _golden( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + # Reference = the registered et_vk op the kernel implements. + return torch.ops.et_vk.apply_rotary_emb.default(xq, xk, freqs_cos, freqs_sin) + + +def _export(inputs): + # Export the real Llama RoPE module (not a hand-written copy) so the test + # exercises the same pattern the partitioner matches in production models. + ep = torch.export.export(RotaryEmbedding().eval(), inputs) + return to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + +class TestRope(unittest.TestCase): + def test_export_delegates(self) -> None: + for shape in SHAPES: + with self.subTest(shape=shape.name): + et = _export(_inputs(shape)) + found = any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ) + self.assertTrue( + found, "Expected a VulkanBackend delegate (apply_rotary_emb fusion)" + ) + + def test_golden_matches_eager(self) -> None: + # The et_vk golden must equal the real Llama apply_rotary_emb, so a buggy + # golden can't fake-pass the native kernel. Run at both shapes so the S=1 + # decode position->freqs indexing is covered. + for shape in SHAPES: + with self.subTest(shape=shape.name): + xq, xk, fc, fs = _inputs(shape) + gq, gk = _golden(xq, xk, fc, fs) + eq, ek = apply_rotary_emb(xq, xk, fc, fs) + torch.testing.assert_close(gq, eq, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(gk, ek, atol=1e-5, rtol=1e-5) + + +def export_rope_model( + pte_path: str, xq_golden_path: str, xk_golden_path: str, shape_name: str = "multi" +) -> None: + """Write the apply_rotary_emb .pte + the xq_out and xk_out torch goldens + (raw LE fp32). Inputs are /16 ramps reconstructed in the native test. + `shape_name` selects an entry from SHAPES (default the multi-token shape).""" + shape = next(s for s in SHAPES if s.name == shape_name) + xq, xk, fc, fs = _inputs(shape) + gq, gk = _golden(xq, xk, fc, fs) + et = _export((xq, xk, fc, fs)) + with open(pte_path, "wb") as f: + f.write(et.buffer) + gq.detach().numpy().astype("((i % mod) - off) / 16.0f; + }; + std::vector xq(xq_numel), xk(xk_numel), fc(freqs_numel), + fs(freqs_numel); + for (int i = 0; i < xq_numel; i++) { + xq[i] = ramp(i, 17, 8); + } + for (int i = 0; i < xk_numel; i++) { + xk[i] = ramp(i, 13, 6); + } + for (int i = 0; i < freqs_numel; i++) { + fc[i] = ramp(i, 11, 5); + fs[i] = ramp(i, 7, 3); + } + + auto xqt = make_tensor_ptr({1, S, NH, HD}, std::vector(xq)); + auto xkt = make_tensor_ptr({1, S, NKV, HD}, std::vector(xk)); + auto fct = make_tensor_ptr({S, HD / 2}, std::vector(fc)); + auto fst = make_tensor_ptr({S, HD / 2}, std::vector(fs)); + + auto result = + module.forward({EValue(xqt), EValue(xkt), EValue(fct), EValue(fst)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + const auto& outputs = result.get(); + + // Outputs in graph order [0]=xq_out, [1]=xk_out (positional; the numel check + // below guards a swap, since NH != NKV under GQA). + if (outputs.size() < 2 || !outputs[0].isTensor() || !outputs[1].isTensor()) { + printf("FAIL: expected 2 tensor outputs, got %zu\n", outputs.size()); + return false; + } + const auto& xq_t = outputs[0].toTensor(); + const auto& xk_t = outputs[1].toTensor(); + if (xq_t.numel() != xq_numel || xk_t.numel() != xk_numel) { + printf( + "FAIL: output shapes [%zu,%zu] != expected [%d,%d]\n", + (size_t)xq_t.numel(), + (size_t)xk_t.numel(), + xq_numel, + xk_numel); + return false; + } + const float* xq_out = xq_t.const_data_ptr(); + const float* xk_out = xk_t.const_data_ptr(); + + std::vector gq = load_golden(xq_golden_path, xq_numel); + std::vector gk = load_golden(xk_golden_path, xk_numel); + if (gq.empty() || gk.empty()) { + printf( + "FAIL: could not load goldens %s / %s\n", + xq_golden_path.c_str(), + xk_golden_path.c_str()); + return false; + } + + // Per-element abs-OR-rel on xq and xk (shared helper, defined above). + float maq = 0.0f, mrq = 0.0f, mak = 0.0f, mrk = 0.0f; + const bool pass_q = + quant_within_tol(xq_out, gq.data(), xq_numel, 1e-3f, 1e-3f, &maq, &mrq); + const bool pass_k = + quant_within_tol(xk_out, gk.data(), xk_numel, 1e-3f, 1e-3f, &mak, &mrk); + const float max_abs_err = std::max(maq, mak); + const float max_rel_err = std::max(mrq, mrk); + + printf( + "Max abs error: %e Max rel error: %e (checked %d elements)\n", + max_abs_err, + max_rel_err, + xq_numel + xk_numel); + if (!(pass_q && pass_k)) { + printf("FAIL: apply_rotary_emb exceeds tolerance 1e-3 (abs AND rel)\n"); + return false; + } + printf("PASS: apply_rotary_emb test\n"); + return true; +} + // Reconstruct _ramp_input bit-for-bit, run the op, compare to the fp64 golden. static bool test_q4gsw_config( const Q4gswConfig& cfg, @@ -1472,6 +1583,37 @@ int main(int argc, char** argv) { 2048}, }; + // apply_rotary_emb on-GPU configs: multi + decode (env-gated, + // run-if-present). + struct RopeConfig { + const char* name; + const char* model_env; + const char* xq_env; + const char* xk_env; + int S; + int NH; + int NKV; + int HD; + }; + const RopeConfig rope_configs[] = { + {"multi", + "WEBGPU_TEST_ROPE_MODEL", + "WEBGPU_TEST_ROPE_XQ_GOLDEN", + "WEBGPU_TEST_ROPE_XK_GOLDEN", + 5, + 8, + 2, + 64}, + {"decode", + "WEBGPU_TEST_ROPE_DECODE_MODEL", + "WEBGPU_TEST_ROPE_DECODE_XQ_GOLDEN", + "WEBGPU_TEST_ROPE_DECODE_XK_GOLDEN", + 1, + 32, + 8, + 64}, + }; + // SDPA sweep: configs self-discover their sdpa_.pte/.golden.bin under // this directory (default "" = the embedded-file root / cwd). Set // WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/). @@ -1528,6 +1670,15 @@ int main(int argc, char** argv) { } } + for (const auto& c : rope_configs) { + const char* m = std::getenv(c.model_env); + const char* xq = std::getenv(c.xq_env); + const char* xk = std::getenv(c.xk_env); + if (m && xq && xk && *m && *xq && *xk) { + ok = test_rope(m, xq, xk, c.S, c.NH, c.NKV, c.HD, c.name) && ok; + } + } + bool sdpa_ran = false; bool sdpa_ok = test_sdpa_sweep(sdpa_dir, &sdpa_ran); if (sdpa_ran) { From eb80092e0f4b6728b3224c9a552cde46a3d3b1ea Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:08 -0700 Subject: [PATCH 5/6] [ExecuTorch][WebGPU] Add et_vk.prepack (constant-tensor packing) for E2E weight loading Pull Request resolved: https://github.com/pytorch/executorch/pull/20265 Adds the WebGPU backend handler for `et_vk.prepack.default`, the node the VulkanPartitioner wraps around every constant feeding a delegated op so the constant is materialized into its dedicated GPU buffer before inference. For the WebGPU backend's buffer-flat/fp32 model, prepack is an identity layout (same dims, dtype, and bytes), so the handler runs no compute shader: it validates that `src` and `out` match (dims, `elem_size`, `nbytes`, non-null buffers; every check throws fail-loud) and records a one-time `src`->`out` buffer-to-buffer copy via the new `WebGPUGraph::add_prepack_copy`. The recorded copies run once in a new `build()` Phase 4 (after the op-dispatch chain is recorded), mirroring the Vulkan delegate's separate `prepack()` init phase (distinct from per-inference `execute()`). Ordering is guaranteed by the WebGPU queue -- the prepack submit precedes the first `execute()` submit on the same queue, so the copied data is visible without an explicit device poll (Dawn has no `wgpuDevicePoll`, and the backend relies on queue ordering plus the output-map wait elsewhere). `src.elem_size` is the `WebGPUTensor` field added by the embedding op lower in this stack, so prepack stacks above it. ghstack-source-id: 395549289 @exported-using-ghexport Differential Revision: [D108428754](https://our.internmc.facebook.com/intern/diff/D108428754/) --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUGraph.cpp | 174 +++++++++++++----- backends/webgpu/runtime/WebGPUGraph.h | 21 +++ .../webgpu/runtime/ops/prepack/Prepack.cpp | 55 ++++++ 4 files changed, 202 insertions(+), 49 deletions(-) create mode 100644 backends/webgpu/runtime/ops/prepack/Prepack.cpp diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 8bf1674d872..f7cd85f9758 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -41,6 +41,7 @@ set(WEBGPU_SRCS runtime/ops/mul/BinaryOp.cpp runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp runtime/ops/rope/RotaryEmbedding.cpp + runtime/ops/prepack/Prepack.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index 65aaaf6c681..b7fb4313400 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu { namespace { +// Op name the AOT exporter emits for a prepacked constant (must match the +// serialized schema); compared in the prepack pre-scan below. +constexpr const char* kPrepackOpName = "et_vk.prepack.default"; + size_t vk_datatype_size(vkgraph::VkDataType dtype) { switch (dtype) { case vkgraph::VkDataType::BOOL: @@ -230,6 +234,10 @@ void WebGPUGraph::build( const auto* graph = vkgraph::GetVkGraph(flatbuffer_data); + // .pte byte sources for prepack-time constant materialization (build-only). + constant_data_ = constant_data; + named_data_map_ = named_data_map; + // Phase 1: Create all values const auto* values = graph->values(); const int num_vals = values ? values->size() : 0; @@ -241,6 +249,41 @@ void WebGPUGraph::build( bools_.resize(num_vals, false); value_lists_.resize(num_vals); + // Pre-scan the op chain: a constant may be DEFERRED (no eager GPU buffer; the + // prepack node materializes it once) only if it is a prepack source AND never + // a direct arg of a non-prepack op. ValueList args are expanded so a constant + // reached through a list still counts as a direct use. + std::unordered_set prepack_src_ids; + std::unordered_set direct_use_ids; + const auto* chain_prescan = graph->chain(); + if (chain_prescan) { + for (unsigned ci = 0; ci < chain_prescan->size(); ci++) { + const auto* oc = chain_prescan->Get(ci); + const bool is_prepack = oc->name()->str() == kPrepackOpName; + const auto* a = oc->args(); + if (!a) { + continue; + } + for (unsigned j = 0; j < a->size(); j++) { + int id = static_cast(a->Get(j)); + if (is_prepack && j == 0) { + prepack_src_ids.insert(id); + } else if (!is_prepack) { + direct_use_ids.insert(id); + const auto* v = values ? values->Get(id) : nullptr; + if (v && v->value_type() == vkgraph::GraphTypes::ValueList) { + const auto* items = v->value_as_ValueList()->items(); + if (items) { + for (unsigned k = 0; k < items->size(); k++) { + direct_use_ids.insert(static_cast(items->Get(k))); + } + } + } + } + } + } + } + for (int i = 0; i < num_vals; i++) { const auto* val = values->Get(i); if (!val || val->value_type() == vkgraph::GraphTypes::NONE) { @@ -269,60 +312,51 @@ void WebGPUGraph::build( int constant_id = vk_tensor->constant_id(); int mem_obj_id = vk_tensor->mem_obj_id(); - // Constants always get dedicated buffers regardless of mem_obj_id + // Constants are dedicated. Every constant is recorded as a + // ConstantSource and materialized via materialize_constant (one + // CPU->GPU write); a constant consumed ONLY via prepack is deferred + // (no eager buffer -- its prepack node performs that one write). if (constant_id >= 0 || mem_obj_id < 0) { tensor_mem_obj_ids_[i] = -1; - WGPUBufferDescriptor buf_desc = {}; - buf_desc.size = std::max(tensor.nbytes, size_t(4)); - buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | - WGPUBufferUsage_CopySrc; - buf_desc.mappedAtCreation = false; - tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc); - - if (constant_id >= 0 && constant_data && tensor.nbytes > 0) { + + if (constant_id >= 0) { const auto* constants = graph->constants(); - if (constants && - constant_id < static_cast(constants->size())) { - const auto* vk_bytes = constants->Get(constant_id); - if (vk_bytes->offset() != UINT64_MAX) { - const uint8_t* src = constant_data + vk_bytes->offset(); - wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, src, tensor.nbytes); - } else if ( - vk_bytes->named_key() != nullptr && - named_data_map != nullptr) { - // Constant stored in the PTE named-data map. - auto buf = - named_data_map->get_data(vk_bytes->named_key()->c_str()); - if (!buf.ok()) { - throw std::runtime_error( - std::string("WebGPU: named constant '") + - vk_bytes->named_key()->c_str() + - "' not found in NamedDataMap"); - } - if (buf->size() < tensor.nbytes) { - throw std::runtime_error( - std::string("WebGPU: named constant '") + - vk_bytes->named_key()->c_str() + "' undersized: have " + - std::to_string(buf->size()) + " bytes, need " + - std::to_string(tensor.nbytes)); - } - wgpuQueueWriteBuffer( - queue_, tensor.buffer, 0, buf->data(), tensor.nbytes); - buf->Free(); - } else { - throw std::runtime_error( - "WebGPU: constant has no inline offset and no named-data key"); - } - } else { + if (!constants || + constant_id >= static_cast(constants->size())) { throw std::runtime_error( "WebGPU: constant_id set but the constants table is missing " "or the id is out of range"); } - } else if (constant_id >= 0 && tensor.nbytes > 0) { - // constant_id set but constant_data null -> fail loud. - throw std::runtime_error( - "WebGPU: constant_id set but constant_data is null"); + const auto* vk_bytes = constants->Get(constant_id); + ConstantSource cs; + cs.nbytes = tensor.nbytes; + if (vk_bytes->offset() != UINT64_MAX) { + cs.inline_offset = vk_bytes->offset(); + } else if (vk_bytes->named_key() != nullptr) { + cs.named_key = vk_bytes->named_key()->str(); + } else { + throw std::runtime_error( + "WebGPU: constant has no inline offset and no named-data key"); + } + constant_sources_[i] = std::move(cs); + } + + // Defer constants consumed solely via prepack: skip the eager buffer. + const bool defer = constant_id >= 0 && + prepack_src_ids.count(i) != 0 && direct_use_ids.count(i) == 0; + if (!defer) { + WGPUBufferDescriptor buf_desc = {}; + buf_desc.size = std::max(tensor.nbytes, size_t(4)); + buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst | + WGPUBufferUsage_CopySrc; + buf_desc.mappedAtCreation = false; + tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc); + + // Same single CPU->GPU write the prepack node uses (no + // duplication). + if (constant_id >= 0) { + materialize_constant(i, tensor.buffer); + } } } else { // Shared buffer: track required size, defer allocation to pass 2 @@ -458,6 +492,47 @@ void WebGPUGraph::build( webgpu_operator_registry().get_op_fn(op_name)(*this, args); } } + + // Prepack nodes (Phase 3) materialized their constants directly into the + // consumer buffers via materialize_constant; no separate copy pass needed. + // The .pte bytes are freed right after build() returns (WebGPUBackend + // processed->Free()), so clear the build-only source pointers. + constant_data_ = nullptr; + named_data_map_ = nullptr; +} + +void WebGPUGraph::materialize_constant(int const_value_id, WGPUBuffer dst) { + auto it = constant_sources_.find(const_value_id); + if (it == constant_sources_.end()) { + throw std::runtime_error( + "WebGPU: no source recorded for constant id " + + std::to_string(const_value_id)); + } + const ConstantSource& cs = it->second; + if (cs.nbytes == 0) { + return; + } + if (cs.inline_offset != UINT64_MAX) { + if (constant_data_ == nullptr) { + throw std::runtime_error("WebGPU: inline constant data is null"); + } + wgpuQueueWriteBuffer( + queue_, dst, 0, constant_data_ + cs.inline_offset, cs.nbytes); + } else if (!cs.named_key.empty() && named_data_map_ != nullptr) { + auto buf = named_data_map_->get_data(cs.named_key.c_str()); + if (!buf.ok()) { + throw std::runtime_error( + "WebGPU: named constant '" + cs.named_key + "' not found"); + } + if (buf->size() < cs.nbytes) { + throw std::runtime_error( + "WebGPU: named constant '" + cs.named_key + "' undersized"); + } + wgpuQueueWriteBuffer(queue_, dst, 0, buf->data(), cs.nbytes); + buf->Free(); + } else { + throw std::runtime_error("WebGPU: constant has no source"); + } } WGPUShaderModule WebGPUGraph::get_or_create_shader( @@ -780,10 +855,11 @@ WebGPUMemoryStats WebGPUGraph::memory_stats() const { for (size_t i = 0; i < value_types_.size(); i++) { if (value_types_[i] == ValueType::Tensor && tensors_[i].nbytes > 0) { stats.num_tensors++; - // Shared tensors are tracked via shared_buffer_sizes_ + // Shared tensors are tracked via shared_buffer_sizes_; a deferred + // prepack-routed constant has no buffer (no GPU memory) -> not counted. bool is_shared = i < tensor_mem_obj_ids_.size() && tensor_mem_obj_ids_[i] >= 0; - if (!is_shared) { + if (!is_shared && tensors_[i].buffer != nullptr) { stats.unshared_tensor_buffer_bytes += tensors_[i].nbytes; } } diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index a914c8710ce..3572f751a06 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -50,6 +50,15 @@ struct OutputCopy { size_t nbytes = 0; }; +// CPU-side record for a prepack-routed constant; mirrors Vulkan's TensorRef +// (sizes + a data reference, not a live GPU tensor). The prepack node is the +// sole materialization, so the constant needs no eager GPU buffer. +struct ConstantSource { + uint64_t inline_offset = UINT64_MAX; // offset into constant_data_; else key + std::string named_key; // non-empty => fetch from named_data_map_ + size_t nbytes = 0; +}; + struct ExecuteConfig { size_t chunk_size = 0; size_t initial_chunk_size = 0; @@ -180,6 +189,11 @@ class WebGPUGraph { dispatches_.push_back(dispatch); } + // Materialize a recorded prepack-routed constant into dst via one CPU->GPU + // transfer. Build-time only (the .pte bytes are freed after build()). + // Mirrors Vulkan prepack_standard. + void materialize_constant(int const_value_id, WGPUBuffer dst); + void add_uniform_buffer_bytes(size_t bytes) { uniform_buffer_bytes_ += bytes; } @@ -286,6 +300,13 @@ class WebGPUGraph { std::vector dispatches_; + // Prepack-routed constant sources (offset/named-key + size); the prepack node + // materializes these once. constant_data_/named_data_map_ point at the .pte + // bytes and are valid only during build(). + const uint8_t* constant_data_ = nullptr; + const executorch::runtime::NamedDataMap* named_data_map_ = nullptr; + std::unordered_map constant_sources_; + ExecuteConfig execute_config_; // Caches for reusing GPU objects across dispatches. diff --git a/backends/webgpu/runtime/ops/prepack/Prepack.cpp b/backends/webgpu/runtime/ops/prepack/Prepack.cpp new file mode 100644 index 00000000000..71414f91787 --- /dev/null +++ b/backends/webgpu/runtime/ops/prepack/Prepack.cpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +namespace executorch::backends::webgpu { + +namespace { + +// Materialize a constant into the prepack-output buffer via one CPU->GPU write. +void prepack_impl(WebGPUGraph& graph, const std::vector& args) { + // et_vk.prepack.default args: [src (constant), out]. + if (args.size() != 2) { + throw std::runtime_error("WebGPU prepack: expected 2 args (src, out)"); + } + const auto& src = graph.get_tensor(args.at(0)); + const auto& out = graph.get_tensor(args.at(1)); + + if (src.dims != out.dims) { + throw std::runtime_error("WebGPU prepack: src/out shape mismatch"); + } + if (src.elem_size != out.elem_size) { + throw std::runtime_error( + "WebGPU prepack: src/out dtype mismatch (cast unsupported)"); + } + if (src.nbytes != out.nbytes) { + throw std::runtime_error("WebGPU prepack: src/out byte-size mismatch"); + } + if (out.buffer == nullptr) { + throw std::runtime_error("WebGPU prepack: null out buffer binding"); + } + + // Sole materialization: write the .pte bytes once, straight into the + // consumer's buffer (no eager src buffer, no buffer->buffer copy). + // Correctness of this write-once relies on `out` being a dedicated buffer + // (the partitioner gives prepack outputs mem_obj_id=-1, so it is never + // memory-plan aliased with a transient that execute() would later overwrite). + graph.materialize_constant(args.at(0), out.buffer); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(et_vk.prepack.default, prepack_impl); +} + +} // namespace executorch::backends::webgpu From 46265613f0663b8a1cadb8a91b4edf66d46f18e8 Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Sun, 21 Jun 2026 22:02:09 -0700 Subject: [PATCH 6/6] [ExecuTorch][WebGPU] et_vk.prepack test suite (export + native golden) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/20292 Test suite for the `et_vk.prepack` constant-materialization op, split into its own diff (op below, tests above) per the per-op test-split convention. The prepack op is how a serialized constant becomes a GPU tensor: the constant arrives as a CPU-side reference (sizes + a pointer into the .pte bytes), and the prepack node is the sole materialization — one CPU->GPU transfer straight into the consumer's buffer. The model `M(x) = x + w` (w a constant) routes `w` through a prepack node, so the delegate must run the materialization for the output to equal `x + w` rather than `x + 0`. ghstack-source-id: 395555139 @exported-using-ghexport Differential Revision: [D108678631](https://our.internmc.facebook.com/intern/diff/D108678631/) --- .../webgpu/scripts/test_webgpu_native_ci.sh | 19 +++ backends/webgpu/test/ops/prepack/__init__.py | 5 + .../webgpu/test/ops/prepack/test_prepack.py | 142 ++++++++++++++++++ backends/webgpu/test/test_webgpu_native.cpp | 110 ++++++++++++++ 4 files changed, 276 insertions(+) create mode 100644 backends/webgpu/test/ops/prepack/__init__.py create mode 100644 backends/webgpu/test/ops/prepack/test_prepack.py diff --git a/backends/webgpu/scripts/test_webgpu_native_ci.sh b/backends/webgpu/scripts/test_webgpu_native_ci.sh index 100e48dfbfd..84b5349ef2d 100644 --- a/backends/webgpu/scripts/test_webgpu_native_ci.sh +++ b/backends/webgpu/scripts/test_webgpu_native_ci.sh @@ -57,6 +57,12 @@ ROPE_XK_GOLDEN="/tmp/webgpu_rope_xk_golden.bin" ROPE_DECODE_MODEL="/tmp/webgpu_rope_decode.pte" ROPE_DECODE_XQ_GOLDEN="/tmp/webgpu_rope_decode_xq_golden.bin" ROPE_DECODE_XK_GOLDEN="/tmp/webgpu_rope_decode_xk_golden.bin" +PREPACK_MODEL="/tmp/webgpu_prepack.pte" +PREPACK_GOLDEN="/tmp/webgpu_prepack_golden.bin" +PREPACK2_MODEL="/tmp/webgpu_prepack_two_const.pte" +PREPACK2_GOLDEN="/tmp/webgpu_prepack_two_const_golden.bin" +PREPACK_TIED_MODEL="/tmp/webgpu_prepack_tied_const.pte" +PREPACK_TIED_GOLDEN="/tmp/webgpu_prepack_tied_const_golden.bin" $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.quantized_linear.test_quantized_linear import export_all_quantized_linear_models @@ -75,6 +81,13 @@ export_rope_model('${ROPE_MODEL}', '${ROPE_XQ_GOLDEN}', '${ROPE_XK_GOLDEN}') export_rope_model('${ROPE_DECODE_MODEL}', '${ROPE_DECODE_XQ_GOLDEN}', '${ROPE_DECODE_XK_GOLDEN}', 'decode') " || echo "WARN: rope export failed; apply_rotary_emb configs will FAIL in webgpu_native_test" +$PYTHON_EXECUTABLE -c " +from executorch.backends.webgpu.test.ops.prepack.test_prepack import export_prepack_model, export_prepack_two_const_model, export_prepack_tied_const_model +export_prepack_model('${PREPACK_MODEL}', '${PREPACK_GOLDEN}') +export_prepack_two_const_model('${PREPACK2_MODEL}', '${PREPACK2_GOLDEN}') +export_prepack_tied_const_model('${PREPACK_TIED_MODEL}', '${PREPACK_TIED_GOLDEN}') +" || echo "WARN: prepack export failed; prepack configs will FAIL in webgpu_native_test" + $PYTHON_EXECUTABLE -c " from executorch.backends.webgpu.test.ops.dispatch_order.test_dispatch_order import export_dispatch_order_cases export_dispatch_order_cases('${DISPATCH_ORDER_DIR}') @@ -172,6 +185,12 @@ if [[ -x "${BIN_DIR}/webgpu_native_test" ]] && WEBGPU_TEST_ROPE_DECODE_MODEL="${ROPE_DECODE_MODEL}" \ WEBGPU_TEST_ROPE_DECODE_XQ_GOLDEN="${ROPE_DECODE_XQ_GOLDEN}" \ WEBGPU_TEST_ROPE_DECODE_XK_GOLDEN="${ROPE_DECODE_XK_GOLDEN}" \ + WEBGPU_TEST_PREPACK_MODEL="${PREPACK_MODEL}" \ + WEBGPU_TEST_PREPACK_GOLDEN="${PREPACK_GOLDEN}" \ + WEBGPU_TEST_PREPACK2_MODEL="${PREPACK2_MODEL}" \ + WEBGPU_TEST_PREPACK2_GOLDEN="${PREPACK2_GOLDEN}" \ + WEBGPU_TEST_PREPACK_TIED_MODEL="${PREPACK_TIED_MODEL}" \ + WEBGPU_TEST_PREPACK_TIED_GOLDEN="${PREPACK_TIED_GOLDEN}" \ "${BIN_DIR}/webgpu_native_test" else echo "(skipping webgpu_native_test: executorch wheel absent — exports did not run)" diff --git a/backends/webgpu/test/ops/prepack/__init__.py b/backends/webgpu/test/ops/prepack/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/webgpu/test/ops/prepack/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/webgpu/test/ops/prepack/test_prepack.py b/backends/webgpu/test/ops/prepack/test_prepack.py new file mode 100644 index 00000000000..0769177143f --- /dev/null +++ b/backends/webgpu/test/ops/prepack/test_prepack.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Constant-tensor prepack (`et_vk.prepack`) export + golden for the WebGPU +backend. + +The VulkanPartitioner wraps every constant feeding a delegated op in an +`et_vk.prepack.default` node that materializes the constant into a GPU buffer at +init. Model `M(x) = x + w` (w a constant) routes `w` through prepack, so the +delegate must run the prepack copy for the output to equal `x + w` rather than +`x + 0 = x`. The input is a deterministic /16 ramp so the native binary +reconstructs it bit-for-bit; the torch-computed golden is written for the native +binary to compare (it has no ATen). +""" + +import unittest + +import executorch.backends.vulkan.custom_ops_lib # noqa: F401 + +import torch +from executorch.backends.vulkan import VulkanPartitioner +from executorch.exir import to_edge_transform_and_lower + +# 4x4 constant weight, small enough to dump and reason about by hand. +N = 4 + + +class _AddConst(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + # arange weight: non-zero everywhere so an unrun prepack (out = x + 0 = x) + # is unambiguously distinguishable from a correct one (out = x + w). + self.w = torch.nn.Parameter( + torch.arange(N * N, dtype=torch.float32).reshape(N, N) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.w + + +class _AddTwoConst(torch.nn.Module): + # Two constants => two prepack nodes (the multi-copy path E2E Llama needs); + # add-only so it stays delegated with just this stack's registered ops. + def __init__(self) -> None: + super().__init__() + self.w1 = torch.nn.Parameter( + torch.arange(N * N, dtype=torch.float32).reshape(N, N) + ) + self.w2 = torch.nn.Parameter( + torch.arange(N * N, dtype=torch.float32).reshape(N, N) * 0.5 - 3.0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.w1 + self.w2 + + +class _AddTiedConst(torch.nn.Module): + # Two BYTE-IDENTICAL constants => two prepack nodes sharing ONE SHA256 + # named-data key (tied/duplicate weights). Exercises the prepack handler + # materializing the same key twice (independent get_data + Free per call). + def __init__(self) -> None: + super().__init__() + self.w1 = torch.nn.Parameter( + torch.arange(N * N, dtype=torch.float32).reshape(N, N) + ) + self.w2 = torch.nn.Parameter( + torch.arange(N * N, dtype=torch.float32).reshape(N, N) + ) + # Pin the tied premise; the dedup to one key is assumed, not asserted. + assert torch.equal(self.w1, self.w2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.w1 + self.w2 + + +def _inputs() -> tuple[torch.Tensor]: + # ((i % 13) - 6) / 16: exact in fp32, matches test_webgpu_native.cpp. + idx = torch.arange(N * N, dtype=torch.int64) + x = (((idx % 13) - 6).to(torch.float32) / 16.0).reshape(N, N) + return (x,) + + +def _export(model, inputs): + ep = torch.export.export(model.eval(), inputs) + return to_edge_transform_and_lower( + ep, partitioner=[VulkanPartitioner()] + ).to_executorch() + + +class TestPrepack(unittest.TestCase): + def test_export_delegates(self) -> None: + # Each model must fully delegate -- every constant wrapped in a prepack + # node inside a VulkanBackend delegate (single, multi-const, tied). + for name, model in ( + ("x + w", _AddConst()), + ("x + w1 + w2", _AddTwoConst()), + ("x + w + w (tied)", _AddTiedConst()), + ): + with self.subTest(model=name): + et = _export(model, _inputs()) + found = any( + d.id == "VulkanBackend" + for plan in et.executorch_program.execution_plan + for d in plan.delegates + ) + self.assertTrue(found, f"Expected a VulkanBackend delegate: {name}") + + +def _write(model, pte_path: str, golden_path: str) -> None: + (x,) = _inputs() + golden = model.eval()(x) + et = _export(model, (x,)) + with open(pte_path, "wb") as f: + f.write(et.buffer) + golden.detach().numpy().astype(" None: + """Write the x + w .pte + torch golden (raw LE fp32). One prepacked constant. + The input is a /16 ramp reconstructed in the native test.""" + _write(_AddConst(), pte_path, golden_path) + + +def export_prepack_two_const_model(pte_path: str, golden_path: str) -> None: + """Write the x + w1 + w2 .pte + golden. Two prepacked constants, exercising + the multi-copy path.""" + _write(_AddTwoConst(), pte_path, golden_path) + + +def export_prepack_tied_const_model(pte_path: str, golden_path: str) -> None: + """Write the x + w1 + w2 .pte + golden where w1 and w2 are BYTE-IDENTICAL, + so they share one named-data key -> two prepack nodes materialize the same + key (verifies per-call buffer ownership / no double-free on tied weights).""" + _write(_AddTiedConst(), pte_path, golden_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp index 6a607bcab17..ad7ad2f2fc2 100644 --- a/backends/webgpu/test/test_webgpu_native.cpp +++ b/backends/webgpu/test/test_webgpu_native.cpp @@ -536,6 +536,74 @@ static bool test_rope( return true; } +static bool test_prepack( + const std::string& model_path, + const std::string& golden_path, + const std::string& label = "x + const w") { + // et_vk.prepack copy vs golden; unrun copy leaves zeros. See test_prepack.py. + constexpr int n = 4; + constexpr int numel = n * n; + printf("\n--- Test: prepack (%s, %dx%d) ---\n", label.c_str(), n, n); + + Module module(model_path); + auto err = module.load_forward(); + if (err != Error::Ok) { + printf("FAIL: could not load forward method (error %d)\n", (int)err); + return false; + } + printf("Model loaded: %s\n", model_path.c_str()); + + std::vector golden = load_golden(golden_path, numel); + if (golden.empty()) { + printf("FAIL: could not load golden %s\n", golden_path.c_str()); + return false; + } + + // ((i % 13) - 6) / 16: exact in fp32, matches test_prepack.py::_inputs. + std::vector x_data(numel); + for (int i = 0; i < numel; i++) { + x_data[i] = static_cast((i % 13) - 6) / 16.0f; + } + auto x = make_tensor_ptr({n, n}, std::vector(x_data)); + + auto result = module.forward({EValue(x)}); + if (!result.ok()) { + printf("FAIL: forward failed (error %d)\n", (int)result.error()); + return false; + } + const auto& outputs = result.get(); + if (outputs.empty() || !outputs[0].isTensor()) { + printf("FAIL: no tensor output\n"); + return false; + } + const auto& out_tensor = outputs[0].toTensor(); + if (out_tensor.numel() != numel) { + printf( + "FAIL: output numel %zu != expected %d\n", + (size_t)out_tensor.numel(), + numel); + return false; + } + const float* out_data = out_tensor.const_data_ptr(); + + float max_abs_err = 0.0f, max_rel_err = 0.0f; + // Per-element abs-OR-rel (quant_within_tol): a global rel gate spuriously + // fails near-zero outputs where rel error explodes. + const bool within = quant_within_tol( + out_data, golden.data(), numel, 1e-3f, 1e-3f, &max_abs_err, &max_rel_err); + printf( + "Max abs error: %e Max rel error: %e (checked %d elements)\n", + max_abs_err, + max_rel_err, + numel); + if (!within) { + printf("FAIL: prepack exceeds tolerance 1e-3\n"); + return false; + } + printf("PASS: prepack test\n"); + return true; +} + // Reconstruct _ramp_input bit-for-bit, run the op, compare to the fp64 golden. static bool test_q4gsw_config( const Q4gswConfig& cfg, @@ -1614,6 +1682,30 @@ int main(int argc, char** argv) { 64}, }; + std::string prepack_model_path, prepack_golden_path; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_MODEL")) { + prepack_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_GOLDEN")) { + prepack_golden_path = env; + } + + std::string prepack2_model_path, prepack2_golden_path; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_MODEL")) { + prepack2_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK2_GOLDEN")) { + prepack2_golden_path = env; + } + + std::string prepack_tied_model_path, prepack_tied_golden_path; + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_MODEL")) { + prepack_tied_model_path = env; + } + if (const char* env = std::getenv("WEBGPU_TEST_PREPACK_TIED_GOLDEN")) { + prepack_tied_golden_path = env; + } + // SDPA sweep: configs self-discover their sdpa_.pte/.golden.bin under // this directory (default "" = the embedded-file root / cwd). Set // WEBGPU_TEST_SDPA_DIR to point at the exported .pte directory (e.g. /tmp/). @@ -1679,6 +1771,24 @@ int main(int argc, char** argv) { } } + if (!prepack_model_path.empty() && !prepack_golden_path.empty()) { + ok = test_prepack(prepack_model_path, prepack_golden_path) && ok; + } + + if (!prepack2_model_path.empty() && !prepack2_golden_path.empty()) { + ok = test_prepack( + prepack2_model_path, prepack2_golden_path, "x + w1 + w2") && + ok; + } + + if (!prepack_tied_model_path.empty() && !prepack_tied_golden_path.empty()) { + ok = test_prepack( + prepack_tied_model_path, + prepack_tied_golden_path, + "x + w + w (tied weights, shared key)") && + ok; + } + bool sdpa_ran = false; bool sdpa_ok = test_sdpa_sweep(sdpa_dir, &sdpa_ran); if (sdpa_ran) {