Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ set(WEBGPU_SRCS
runtime/ops/unsqueeze/Unsqueeze.cpp
runtime/ops/slice/Slice.cpp
runtime/ops/permute/Permute.cpp
runtime/ops/cat/Cat.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
12 changes: 12 additions & 0 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ void WebGPUGraph::build(
tensor_mem_obj_ids_.resize(num_vals, -1);
ints_.resize(num_vals, 0);
int_lists_.resize(num_vals);
value_lists_.resize(num_vals);
doubles_.resize(num_vals, 0.0);
bools_.resize(num_vals, false);

Expand Down Expand Up @@ -326,6 +327,17 @@ void WebGPUGraph::build(
}
break;
}
case vkgraph::GraphTypes::ValueList: {
value_types_[i] = ValueType::ValueList;
const auto* items = val->value_as_ValueList()->items();
if (items) {
value_lists_[i].reserve(items->size());
for (unsigned j = 0; j < items->size(); j++) {
value_lists_[i].push_back(static_cast<int>(items->Get(j)));
}
}
break;
}
case vkgraph::GraphTypes::Double: {
value_types_[i] = ValueType::Double;
doubles_[i] = val->value_as_Double()->double_val();
Expand Down
8 changes: 7 additions & 1 deletion backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ class WebGPUGraph {
const std::vector<int64_t>& get_int_list(int id) const {
return int_lists_[id];
}
// Member value ids of a serialized ValueList (op multi-output list).
const std::vector<int>& get_value_list(int id) const {
return value_lists_[id];
}
bool get_bool(int id) const {
return bools_[id];
}
Expand Down Expand Up @@ -219,7 +223,8 @@ class WebGPUGraph {
Null,
String,
SymInt,
IntList
IntList,
ValueList
};

ValueType get_value_type(int id) const {
Expand All @@ -237,6 +242,7 @@ class WebGPUGraph {
std::vector<WebGPUTensor> tensors_;
std::vector<int64_t> ints_;
std::vector<std::vector<int64_t>> int_lists_;
std::vector<std::vector<int>> value_lists_;
std::vector<double> doubles_;
std::vector<bool> bools_;

Expand Down
218 changes: 218 additions & 0 deletions backends/webgpu/runtime/ops/cat/Cat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
#include <executorch/backends/webgpu/runtime/ops/cat/cat_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstdint>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

struct CatParams {
uint32_t concat_dim;
uint32_t off_k;
uint32_t _pad[2];
};
static_assert(
sizeof(CatParams) == 16,
"CatParams must match the WGSL Params uniform (16-byte aligned)");

// cat: 1 dispatch/input -> disjoint out slab at host off_k (Vulkan concat).
void cat_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// args: [tensors (ValueList), dim, out].
const int list_id = args.at(0);
const int out_id = args.at(args.size() - 1);

if (graph.get_value_type(list_id) != WebGPUGraph::ValueType::ValueList) {
throw std::runtime_error("cat: tensors arg is not a ValueList");
}
if (graph.get_value_type(args.at(1)) != WebGPUGraph::ValueType::Int) {
throw std::runtime_error("cat: dim arg is not a static Int");
}
if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("cat: out arg is not a tensor");
}

WGPUDevice device = graph.device();
const std::vector<int>& ids = graph.get_value_list(list_id);
if (ids.empty()) {
throw std::runtime_error("cat: empty input list");
}

const auto& out_tensor = graph.get_tensor(out_id);
const int ndim = static_cast<int>(out_tensor.dims.size());

int64_t dim = graph.get_int(args.at(1));
if (dim < 0) {
dim += ndim;
}
if (dim < 0 || dim >= ndim) {
throw std::runtime_error("cat: dim out of range");
}

// Workgroup size is invariant across inputs: clamp once, share the constant.
uint32_t wg_size = utils::clamp_workgroup_size(device, kCatWorkgroupSizeX);

// Validate + cache input meta/wgc BEFORE any GPU alloc (no leak on throw).
std::vector<TensorMeta> in_metas(ids.size());
std::vector<uint32_t> wg_counts(ids.size());
int64_t concat_sum = 0;
for (size_t k = 0; k < ids.size(); k++) {
const int id = ids[k];
if (graph.get_value_type(id) != WebGPUGraph::ValueType::Tensor) {
throw std::runtime_error("cat: input list element is not a tensor");
}
const auto& in_tensor = graph.get_tensor(id);
if (static_cast<int>(in_tensor.dims.size()) != ndim) {
throw std::runtime_error("cat: input rank != output rank");
}
for (int d = 0; d < ndim; d++) {
if (d != dim && in_tensor.dims[d] != out_tensor.dims[d]) {
throw std::runtime_error("cat: non-concat dim size mismatch");
}
}
fill_tensor_meta(in_tensor, &in_metas[k]);
if (in_tensor.nbytes !=
static_cast<size_t>(in_metas[k].numel) * sizeof(float)) {
throw std::runtime_error("cat: non-fp32 input (nbytes != numel * 4)");
}
wg_counts[k] = utils::compute_1d_workgroup_count(
device, in_metas[k].numel, wg_size, "cat");
concat_sum += in_tensor.dims[dim];
}
if (concat_sum != out_tensor.dims[dim]) {
throw std::runtime_error("cat: concat dim sizes do not sum to output");
}

TensorMeta out_meta;
fill_tensor_meta(out_tensor, &out_meta);
if (out_tensor.nbytes !=
static_cast<size_t>(out_meta.numel) * sizeof(float)) {
throw std::runtime_error("cat: non-fp32 output (nbytes != numel * 4)");
}

WGPUBuffer out_meta_buf =
utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
graph.add_uniform_buffer_bytes(sizeof(TensorMeta));

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

// Shared shader/layout; fresh pipeline+bind group per input (no double-free).
WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kCatWGSL, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

WGPUBindGroupLayoutEntry entries[5] = {};
entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_Storage;
entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_Uniform;
entries[3].binding = 3;
entries[3].visibility = WGPUShaderStage_Compute;
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
entries[4].binding = 4;
entries[4].visibility = WGPUShaderStage_Compute;
entries[4].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 5;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

uint32_t off_k = 0;
for (size_t k = 0; k < ids.size(); k++) {
const auto& in_tensor = graph.get_tensor(ids[k]);

CatParams params = {};
params.concat_dim = static_cast<uint32_t>(dim);
params.off_k = off_k;

WGPUBuffer in_meta_buf =
utils::make_uniform(device, &in_metas[k], sizeof(TensorMeta));
WGPUBuffer params_buf =
utils::make_uniform(device, &params, sizeof(CatParams));
graph.add_uniform_buffer_bytes(sizeof(TensorMeta) + sizeof(CatParams));

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[5] = {};
bg_entries[0].binding = 0;
bg_entries[0].buffer = in_tensor.buffer;
bg_entries[0].size = in_tensor.nbytes;
bg_entries[1].binding = 1;
bg_entries[1].buffer = out_tensor.buffer;
bg_entries[1].size = out_tensor.nbytes;
bg_entries[2].binding = 2;
bg_entries[2].buffer = out_meta_buf;
bg_entries[2].size = sizeof(TensorMeta);
bg_entries[3].binding = 3;
bg_entries[3].buffer = in_meta_buf;
bg_entries[3].size = sizeof(TensorMeta);
bg_entries[4].binding = 4;
bg_entries[4].buffer = params_buf;
bg_entries[4].size = sizeof(CatParams);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 5;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, wg_counts[k]});
// Drop our refs; this input's bind group keeps its uniforms alive.
wgpuBufferRelease(in_meta_buf);
wgpuBufferRelease(params_buf);
off_k += static_cast<uint32_t>(in_tensor.dims[dim]);
}

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref to the shared out_meta; the bind groups keep it alive.
wgpuBufferRelease(out_meta_buf);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.cat.default, cat_impl);
}

} // namespace executorch::backends::webgpu
41 changes: 41 additions & 0 deletions backends/webgpu/runtime/ops/cat/cat.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;

struct Params {
concat_dim: u32,
off_k: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let in_bufi = gid.x;
if (in_bufi >= in_meta.numel) {
return;
}

// Scatter: in coord -> out coord, concat dim shifted by off_k (Vulkan concat).
var rem = in_bufi;
var out_bufi: u32 = 0u;
for (var d: u32 = 0u; d < in_meta.ndim; d = d + 1u) {
let coord = rem / in_meta.strides[d];
rem = rem % in_meta.strides[d];
var out_coord = coord;
if (d == params.concat_dim) {
out_coord = coord + params.off_k;
}
out_bufi = out_bufi + out_coord * out_meta.strides[d];
}
output[out_bufi] = input[in_bufi];
}
65 changes: 65 additions & 0 deletions backends/webgpu/runtime/ops/cat/cat_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from cat.wgsl - DO NOT EDIT.
// wgsl-sha256: d1fcb4da7e32c6295b80d581c093b78d0a4b43a972fe2d5d9d94c4f9ae459f4c
inline constexpr const char* kCatWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
@group(0) @binding(3) var<uniform> in_meta: TensorMeta;

struct Params {
concat_dim: u32,
off_k: u32,
}
@group(0) @binding(4) var<uniform> params: Params;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let in_bufi = gid.x;
if (in_bufi >= in_meta.numel) {
return;
}

// Scatter: in coord -> out coord, concat dim shifted by off_k (Vulkan concat).
var rem = in_bufi;
var out_bufi: u32 = 0u;
for (var d: u32 = 0u; d < in_meta.ndim; d = d + 1u) {
let coord = rem / in_meta.strides[d];
rem = rem % in_meta.strides[d];
var out_coord = coord;
if (d == params.concat_dim) {
out_coord = coord + params.off_k;
}
out_bufi = out_bufi + out_coord * out_meta.strides[d];
}
output[out_bufi] = input[in_bufi];
}
)";

inline constexpr uint32_t kCatWorkgroupSizeX = 64;
inline constexpr uint32_t kCatWorkgroupSizeY = 1;
inline constexpr uint32_t kCatWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
Loading