From 4fca5ee5b15a5b444e5eee38e8191b17abcbce5e Mon Sep 17 00:00:00 2001 From: Julian Ng-Thow-Hing Date: Thu, 18 Jun 2026 14:35:55 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/webgpu/CMakeLists.txt | 1 + backends/webgpu/runtime/WebGPUGraph.cpp | 11 ++ backends/webgpu/runtime/WebGPUGraph.h | 8 +- backends/webgpu/runtime/ops/cat/Cat.cpp | 218 +++++++++++++++++++++ backends/webgpu/runtime/ops/cat/cat.wgsl | 41 ++++ backends/webgpu/runtime/ops/cat/cat_wgsl.h | 65 ++++++ 6 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 backends/webgpu/runtime/ops/cat/Cat.cpp create mode 100644 backends/webgpu/runtime/ops/cat/cat.wgsl create mode 100644 backends/webgpu/runtime/ops/cat/cat_wgsl.h diff --git a/backends/webgpu/CMakeLists.txt b/backends/webgpu/CMakeLists.txt index 5289e8c8d17..ef94b629dcf 100644 --- a/backends/webgpu/CMakeLists.txt +++ b/backends/webgpu/CMakeLists.txt @@ -46,6 +46,7 @@ set(WEBGPU_SRCS runtime/ops/unsqueeze/Unsqueeze.cpp runtime/ops/slice/Slice.cpp runtime/ops/permute/Permute.cpp + runtime/ops/cat/Cat.cpp ) add_library(webgpu_backend ${WEBGPU_SRCS}) diff --git a/backends/webgpu/runtime/WebGPUGraph.cpp b/backends/webgpu/runtime/WebGPUGraph.cpp index bbee0df766e..e64fbe6db44 100644 --- a/backends/webgpu/runtime/WebGPUGraph.cpp +++ b/backends/webgpu/runtime/WebGPUGraph.cpp @@ -225,6 +225,7 @@ void WebGPUGraph::build( tensor_mem_obj_ids_.resize(num_vals, -1); ints_.resize(num_vals, 0); int_lists_.resize(num_vals); + value_lists_.resize(num_vals); doubles_.resize(num_vals, 0.0); bools_.resize(num_vals, false); @@ -326,6 +327,16 @@ void WebGPUGraph::build( } break; } + case vkgraph::GraphTypes::ValueList: { + value_types_[i] = ValueType::ValueList; + const auto* items = val->value_as_ValueList()->items(); + if (items) { + for (unsigned j = 0; j < items->size(); j++) { + value_lists_[i].push_back(static_cast(items->Get(j))); + } + } + break; + } case vkgraph::GraphTypes::Double: { value_types_[i] = ValueType::Double; doubles_[i] = 
val->value_as_Double()->double_val(); diff --git a/backends/webgpu/runtime/WebGPUGraph.h b/backends/webgpu/runtime/WebGPUGraph.h index cc94c4f8c46..0d9f59094c3 100644 --- a/backends/webgpu/runtime/WebGPUGraph.h +++ b/backends/webgpu/runtime/WebGPUGraph.h @@ -111,6 +111,10 @@ class WebGPUGraph { const std::vector& get_int_list(int id) const { return int_lists_[id]; } + // Member value ids of a serialized ValueList (op multi-output list). + const std::vector& get_value_list(int id) const { + return value_lists_[id]; + } bool get_bool(int id) const { return bools_[id]; } @@ -219,7 +223,8 @@ class WebGPUGraph { Null, String, SymInt, - IntList + IntList, + ValueList }; ValueType get_value_type(int id) const { @@ -237,6 +242,7 @@ class WebGPUGraph { std::vector tensors_; std::vector ints_; std::vector> int_lists_; + std::vector> value_lists_; std::vector doubles_; std::vector bools_; diff --git a/backends/webgpu/runtime/ops/cat/Cat.cpp b/backends/webgpu/runtime/ops/cat/Cat.cpp new file mode 100644 index 00000000000..0cfb857745c --- /dev/null +++ b/backends/webgpu/runtime/ops/cat/Cat.cpp @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace executorch::backends::webgpu { + +namespace { + +struct CatParams { + uint32_t concat_dim; + uint32_t off_k; + uint32_t _pad[2]; +}; +static_assert( + sizeof(CatParams) == 16, + "CatParams must match the WGSL Params uniform (16-byte aligned)"); + +// cat: 1 dispatch/input -> disjoint out slab at host off_k (Vulkan concat). +void cat_impl(WebGPUGraph& graph, const std::vector& args) { + // args: [tensors (ValueList), dim, out]. 
+ const int list_id = args.at(0); + const int out_id = args.at(args.size() - 1); + + if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) { + throw std::runtime_error("cat: tensors arg is not a ValueList"); + } + if (graph.get_value_type(args.at(1)) != WebGPUGraph::ValueType::Int) { + throw std::runtime_error("cat: dim arg is not a static Int"); + } + if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) { + throw std::runtime_error("cat: out arg is not a tensor"); + } + + WGPUDevice device = graph.device(); + const std::vector& ids = graph.get_value_list(list_id); + if (ids.empty()) { + throw std::runtime_error("cat: empty input list"); + } + + const auto& out_tensor = graph.get_tensor(out_id); + const int ndim = static_cast(out_tensor.dims.size()); + + int64_t dim = graph.get_int(args.at(1)); + if (dim < 0) { + dim += ndim; + } + if (dim < 0 || dim >= ndim) { + throw std::runtime_error("cat: dim out of range"); + } + + // Workgroup size is invariant across inputs: clamp once, share the constant. + uint32_t wg_size = utils::clamp_workgroup_size(device, kCatWorkgroupSizeX); + + // Validate + cache input meta/wgc BEFORE any GPU alloc (no leak on throw). 
+ std::vector in_metas(ids.size()); + std::vector wg_counts(ids.size()); + int64_t concat_sum = 0; + for (size_t k = 0; k < ids.size(); k++) { + const int id = ids[k]; + if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) { + throw std::runtime_error("cat: input list element is not a tensor"); + } + const auto& in_tensor = graph.get_tensor(id); + if (static_cast(in_tensor.dims.size()) != ndim) { + throw std::runtime_error("cat: input rank != output rank"); + } + for (int d = 0; d < ndim; d++) { + if (d != dim && in_tensor.dims[d] != out_tensor.dims[d]) { + throw std::runtime_error("cat: non-concat dim size mismatch"); + } + } + fill_tensor_meta(in_tensor, &in_metas[k]); + if (in_tensor.nbytes != + static_cast(in_metas[k].numel) * sizeof(float)) { + throw std::runtime_error("cat: non-fp32 input (nbytes != numel * 4)"); + } + wg_counts[k] = utils::compute_1d_workgroup_count( + device, in_metas[k].numel, wg_size, "cat"); + concat_sum += in_tensor.dims[dim]; + } + if (concat_sum != out_tensor.dims[dim]) { + throw std::runtime_error("cat: concat dim sizes do not sum to output"); + } + + TensorMeta out_meta; + fill_tensor_meta(out_tensor, &out_meta); + if (out_tensor.nbytes != + static_cast(out_meta.numel) * sizeof(float)) { + throw std::runtime_error("cat: non-fp32 output (nbytes != numel * 4)"); + } + + WGPUBuffer out_meta_buf = + utils::make_uniform(device, &out_meta, sizeof(TensorMeta)); + graph.add_uniform_buffer_bytes(sizeof(TensorMeta)); + + WGPUConstantEntry wg_size_constant = {}; + wg_size_constant.key = {"wg_size", WGPU_STRLEN}; + wg_size_constant.value = static_cast(wg_size); + + // Shared shader/layout; fresh pipeline+bind group per input (no double-free). 
+ WGPUShaderSourceWGSL wgsl_desc = {}; + wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; + wgsl_desc.code = {kCatWGSL, WGPU_STRLEN}; + WGPUShaderModuleDescriptor shader_desc = {}; + shader_desc.nextInChain = &wgsl_desc.chain; + WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc); + + WGPUBindGroupLayoutEntry entries[5] = {}; + entries[0].binding = 0; + entries[0].visibility = WGPUShaderStage_Compute; + entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage; + entries[1].binding = 1; + entries[1].visibility = WGPUShaderStage_Compute; + entries[1].buffer.type = WGPUBufferBindingType_Storage; + entries[2].binding = 2; + entries[2].visibility = WGPUShaderStage_Compute; + entries[2].buffer.type = WGPUBufferBindingType_Uniform; + entries[3].binding = 3; + entries[3].visibility = WGPUShaderStage_Compute; + entries[3].buffer.type = WGPUBufferBindingType_Uniform; + entries[4].binding = 4; + entries[4].visibility = WGPUShaderStage_Compute; + entries[4].buffer.type = WGPUBufferBindingType_Uniform; + + WGPUBindGroupLayoutDescriptor bgl_desc = {}; + bgl_desc.entryCount = 5; + bgl_desc.entries = entries; + WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc); + + WGPUPipelineLayoutDescriptor pl_desc = {}; + pl_desc.bindGroupLayoutCount = 1; + pl_desc.bindGroupLayouts = &bgl; + WGPUPipelineLayout pipeline_layout = + wgpuDeviceCreatePipelineLayout(device, &pl_desc); + + uint32_t off_k = 0; + for (size_t k = 0; k < ids.size(); k++) { + const auto& in_tensor = graph.get_tensor(ids[k]); + + CatParams params = {}; + params.concat_dim = static_cast(dim); + params.off_k = off_k; + + WGPUBuffer in_meta_buf = + utils::make_uniform(device, &in_metas[k], sizeof(TensorMeta)); + WGPUBuffer params_buf = + utils::make_uniform(device, &params, sizeof(CatParams)); + graph.add_uniform_buffer_bytes(sizeof(TensorMeta) + sizeof(CatParams)); + + WGPUComputePipelineDescriptor pipeline_desc = {}; + pipeline_desc.layout = pipeline_layout; + 
pipeline_desc.compute.module = shader; + pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN}; + pipeline_desc.compute.constantCount = 1; + pipeline_desc.compute.constants = &wg_size_constant; + WGPUComputePipeline pipeline = + wgpuDeviceCreateComputePipeline(device, &pipeline_desc); + + WGPUBindGroupEntry bg_entries[5] = {}; + bg_entries[0].binding = 0; + bg_entries[0].buffer = in_tensor.buffer; + bg_entries[0].size = in_tensor.nbytes; + bg_entries[1].binding = 1; + bg_entries[1].buffer = out_tensor.buffer; + bg_entries[1].size = out_tensor.nbytes; + bg_entries[2].binding = 2; + bg_entries[2].buffer = out_meta_buf; + bg_entries[2].size = sizeof(TensorMeta); + bg_entries[3].binding = 3; + bg_entries[3].buffer = in_meta_buf; + bg_entries[3].size = sizeof(TensorMeta); + bg_entries[4].binding = 4; + bg_entries[4].buffer = params_buf; + bg_entries[4].size = sizeof(CatParams); + + WGPUBindGroupDescriptor bg_desc = {}; + bg_desc.layout = bgl; + bg_desc.entryCount = 5; + bg_desc.entries = bg_entries; + WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc); + + graph.add_dispatch({pipeline, bind_group, wg_counts[k]}); + // Drop our refs; this input's bind group keeps its uniforms alive. + wgpuBufferRelease(in_meta_buf); + wgpuBufferRelease(params_buf); + off_k += static_cast(in_tensor.dims[dim]); + } + + wgpuShaderModuleRelease(shader); + wgpuBindGroupLayoutRelease(bgl); + wgpuPipelineLayoutRelease(pipeline_layout); + // Drop our ref to the shared out_meta; the bind groups keep it alive. 
+ wgpuBufferRelease(out_meta_buf); +} + +} // namespace + +WEBGPU_REGISTER_OPERATORS { + WEBGPU_REGISTER_OP(aten.cat.default, cat_impl); +} + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/runtime/ops/cat/cat.wgsl b/backends/webgpu/runtime/ops/cat/cat.wgsl new file mode 100644 index 00000000000..3b1f4aaaa4d --- /dev/null +++ b/backends/webgpu/runtime/ops/cat/cat.wgsl @@ -0,0 +1,41 @@ +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; + +struct TensorMeta { + ndim: u32, + numel: u32, + sizes: vec4, + strides: vec4, +} +@group(0) @binding(2) var out_meta: TensorMeta; +@group(0) @binding(3) var in_meta: TensorMeta; + +struct Params { + concat_dim: u32, + off_k: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let in_bufi = gid.x; + if (in_bufi >= in_meta.numel) { + return; + } + + // Scatter: in coord -> out coord, concat dim shifted by off_k (Vulkan concat). + var rem = in_bufi; + var out_bufi: u32 = 0u; + for (var d: u32 = 0u; d < in_meta.ndim; d = d + 1u) { + let coord = rem / in_meta.strides[d]; + rem = rem % in_meta.strides[d]; + var out_coord = coord; + if (d == params.concat_dim) { + out_coord = coord + params.off_k; + } + out_bufi = out_bufi + out_coord * out_meta.strides[d]; + } + output[out_bufi] = input[in_bufi]; +} diff --git a/backends/webgpu/runtime/ops/cat/cat_wgsl.h b/backends/webgpu/runtime/ops/cat/cat_wgsl.h new file mode 100644 index 00000000000..94d7e2afdc8 --- /dev/null +++ b/backends/webgpu/runtime/ops/cat/cat_wgsl.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from cat.wgsl - DO NOT EDIT. +// wgsl-sha256: d1fcb4da7e32c6295b80d581c093b78d0a4b43a972fe2d5d9d94c4f9ae459f4c +inline constexpr const char* kCatWGSL = R"( +@group(0) @binding(0) var input: array; +@group(0) @binding(1) var output: array; + +struct TensorMeta { + ndim: u32, + numel: u32, + sizes: vec4, + strides: vec4, +} +@group(0) @binding(2) var out_meta: TensorMeta; +@group(0) @binding(3) var in_meta: TensorMeta; + +struct Params { + concat_dim: u32, + off_k: u32, +} +@group(0) @binding(4) var params: Params; + +override wg_size: u32 = 64u; + +@compute @workgroup_size(wg_size, 1, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + let in_bufi = gid.x; + if (in_bufi >= in_meta.numel) { + return; + } + + // Scatter: in coord -> out coord, concat dim shifted by off_k (Vulkan concat). + var rem = in_bufi; + var out_bufi: u32 = 0u; + for (var d: u32 = 0u; d < in_meta.ndim; d = d + 1u) { + let coord = rem / in_meta.strides[d]; + rem = rem % in_meta.strides[d]; + var out_coord = coord; + if (d == params.concat_dim) { + out_coord = coord + params.off_k; + } + out_bufi = out_bufi + out_coord * out_meta.strides[d]; + } + output[out_bufi] = input[in_bufi]; +} +)"; + +inline constexpr uint32_t kCatWorkgroupSizeX = 64; +inline constexpr uint32_t kCatWorkgroupSizeY = 1; +inline constexpr uint32_t kCatWorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu