Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ set(WEBGPU_SRCS
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
runtime/ops/mul/BinaryOp.cpp
runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
runtime/ops/rope/RotaryEmbedding.cpp
runtime/ops/prepack/Prepack.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
11 changes: 6 additions & 5 deletions backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,21 @@ Error WebGPUBackend::execute(
const size_t num_outputs = graph->output_ids().size();

// Copy inputs from EValue tensors to GPU buffers
std::vector<std::pair<const void*, size_t>> inputs;
std::vector<InputData> inputs;
inputs.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; i++) {
const auto& tensor = args[i]->toTensor();
inputs.emplace_back(tensor.const_data_ptr(), tensor.nbytes());
const bool host_is_int64 =
tensor.scalar_type() == executorch::aten::ScalarType::Long;
inputs.push_back({tensor.const_data_ptr(), tensor.nbytes(), host_is_int64});
}
graph->copy_inputs(inputs);

// Fail loud as a runtime Error so a throw never crosses the backend boundary.
try {
graph->copy_inputs(inputs);
graph->update_symints_from_inputs(inputs);
graph->propagate_resize();
} catch (const std::exception& e) {
ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what());
ET_LOG(Error, "WebGPU input copy / symint refresh failed: %s", e.what());
return Error::Internal;
}

Expand Down
247 changes: 194 additions & 53 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ namespace executorch::backends::webgpu {

namespace {

// Op name the AOT exporter emits for a prepacked constant (must match the
// serialized schema). WebGPUGraph::build() compares every op-chain entry's
// name against this string in its pre-scan to decide which constants are
// prepack sources.
constexpr const char* kPrepackOpName = "et_vk.prepack.default";

size_t vk_datatype_size(vkgraph::VkDataType dtype) {
switch (dtype) {
case vkgraph::VkDataType::BOOL:
Expand All @@ -45,6 +49,19 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) {
}
}

// Returns true when `dtype` is one of the integer-like serialized dtypes
// (BOOL, UINT8, INT8, INT32, INT64) and false for every other dtype.
bool vk_datatype_is_int(vkgraph::VkDataType dtype) {
  using DType = vkgraph::VkDataType;
  const bool is_narrow_int =
      dtype == DType::BOOL || dtype == DType::UINT8 || dtype == DType::INT8;
  const bool is_wide_int = dtype == DType::INT32 || dtype == DType::INT64;
  return is_narrow_int || is_wide_int;
}

} // namespace

WebGPUGraph::WebGPUGraph() = default;
Expand All @@ -61,7 +78,7 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
}

void WebGPUGraph::update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs) {
const std::vector<InputData>& inputs) {
for (const auto& src : symint_sources_) {
int pos = -1;
for (size_t i = 0; i < input_ids_.size(); i++) {
Expand Down Expand Up @@ -100,8 +117,8 @@ void WebGPUGraph::update_symints_from_inputs(
// Reads the [0,..,index,..,0] element; symint sources are scalar-ish.
const int64_t offset = static_cast<int64_t>(index) * stride;
// elem_size back-derived from build-time numel (sources are static-shaped).
const void* host = inputs[pos].first;
const size_t elem_size = inputs[pos].second / static_cast<size_t>(numel);
const void* host = inputs[pos].data;
const size_t elem_size = inputs[pos].nbytes / static_cast<size_t>(numel);
int32_t val;
if (elem_size == sizeof(int64_t)) {
val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
Expand Down Expand Up @@ -217,6 +234,10 @@ void WebGPUGraph::build(

const auto* graph = vkgraph::GetVkGraph(flatbuffer_data);

// .pte byte sources for prepack-time constant materialization (build-only).
constant_data_ = constant_data;
named_data_map_ = named_data_map;

// Phase 1: Create all values
const auto* values = graph->values();
const int num_vals = values ? values->size() : 0;
Expand All @@ -226,6 +247,42 @@ void WebGPUGraph::build(
ints_.resize(num_vals, 0);
doubles_.resize(num_vals, 0.0);
bools_.resize(num_vals, false);
value_lists_.resize(num_vals);

// Pre-scan the op chain: a constant may be DEFERRED (no eager GPU buffer; the
// prepack node materializes it once) only if it is a prepack source AND never
// a direct arg of a non-prepack op. ValueList args are expanded so a constant
// reached through a list still counts as a direct use.
std::unordered_set<int> prepack_src_ids;
std::unordered_set<int> direct_use_ids;
const auto* chain_prescan = graph->chain();
if (chain_prescan) {
for (unsigned ci = 0; ci < chain_prescan->size(); ci++) {
const auto* oc = chain_prescan->Get(ci);
const bool is_prepack = oc->name()->str() == kPrepackOpName;
const auto* a = oc->args();
if (!a) {
continue;
}
for (unsigned j = 0; j < a->size(); j++) {
int id = static_cast<int>(a->Get(j));
if (is_prepack && j == 0) {
prepack_src_ids.insert(id);
} else if (!is_prepack) {
direct_use_ids.insert(id);
const auto* v = values ? values->Get(id) : nullptr;
if (v && v->value_type() == vkgraph::GraphTypes::ValueList) {
const auto* items = v->value_as_ValueList()->items();
if (items) {
for (unsigned k = 0; k < items->size(); k++) {
direct_use_ids.insert(static_cast<int>(items->Get(k)));
}
}
}
}
}
}
}

for (int i = 0; i < num_vals; i++) {
const auto* val = values->Get(i);
Expand All @@ -248,56 +305,57 @@ void WebGPUGraph::build(
numel *= dims->Get(j);
}
}
tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype());
tensor.elem_size = vk_datatype_size(vk_tensor->datatype());
tensor.is_int = vk_datatype_is_int(vk_tensor->datatype());
tensor.nbytes = numel * tensor.elem_size;

int constant_id = vk_tensor->constant_id();
int mem_obj_id = vk_tensor->mem_obj_id();

// Constants always get dedicated buffers regardless of mem_obj_id
// Constants are dedicated. Every constant is recorded as a
// ConstantSource and materialized via materialize_constant (one
// CPU->GPU write); a constant consumed ONLY via prepack is deferred
// (no eager buffer -- its prepack node performs that one write).
if (constant_id >= 0 || mem_obj_id < 0) {
tensor_mem_obj_ids_[i] = -1;
WGPUBufferDescriptor buf_desc = {};
buf_desc.size = std::max(tensor.nbytes, size_t(4));
buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
WGPUBufferUsage_CopySrc;
buf_desc.mappedAtCreation = false;
tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc);

if (constant_id >= 0 && constant_data && tensor.nbytes > 0) {

if (constant_id >= 0) {
const auto* constants = graph->constants();
if (constants &&
constant_id < static_cast<int>(constants->size())) {
const auto* vk_bytes = constants->Get(constant_id);
if (vk_bytes->offset() != UINT64_MAX) {
const uint8_t* src = constant_data + vk_bytes->offset();
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, src, tensor.nbytes);
} else if (
vk_bytes->named_key() != nullptr &&
named_data_map != nullptr) {
// Constant stored in the PTE named-data map.
auto buf =
named_data_map->get_data(vk_bytes->named_key()->c_str());
if (!buf.ok()) {
throw std::runtime_error(
std::string("WebGPU: named constant '") +
vk_bytes->named_key()->c_str() +
"' not found in NamedDataMap");
}
if (buf->size() < tensor.nbytes) {
throw std::runtime_error(
std::string("WebGPU: named constant '") +
vk_bytes->named_key()->c_str() + "' undersized: have " +
std::to_string(buf->size()) + " bytes, need " +
std::to_string(tensor.nbytes));
}
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, buf->data(), tensor.nbytes);
buf->Free();
} else {
throw std::runtime_error(
"WebGPU: constant has no inline offset and no named-data key");
}
if (!constants ||
constant_id >= static_cast<int>(constants->size())) {
throw std::runtime_error(
"WebGPU: constant_id set but the constants table is missing "
"or the id is out of range");
}
const auto* vk_bytes = constants->Get(constant_id);
ConstantSource cs;
cs.nbytes = tensor.nbytes;
if (vk_bytes->offset() != UINT64_MAX) {
cs.inline_offset = vk_bytes->offset();
} else if (vk_bytes->named_key() != nullptr) {
cs.named_key = vk_bytes->named_key()->str();
} else {
throw std::runtime_error(
"WebGPU: constant has no inline offset and no named-data key");
}
constant_sources_[i] = std::move(cs);
}

// Defer constants consumed solely via prepack: skip the eager buffer.
const bool defer = constant_id >= 0 &&
prepack_src_ids.count(i) != 0 && direct_use_ids.count(i) == 0;
if (!defer) {
WGPUBufferDescriptor buf_desc = {};
buf_desc.size = std::max(tensor.nbytes, size_t(4));
buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
WGPUBufferUsage_CopySrc;
buf_desc.mappedAtCreation = false;
tensor.buffer = wgpuDeviceCreateBuffer(device_, &buf_desc);

// Same single CPU->GPU write the prepack node uses (no
// duplication).
if (constant_id >= 0) {
materialize_constant(i, tensor.buffer);
}
}
} else {
Expand Down Expand Up @@ -348,6 +406,16 @@ void WebGPUGraph::build(
add_uniform_buffer_bytes(kSymIntUniformBytes);
break;
}
case vkgraph::GraphTypes::ValueList: {
value_types_[i] = ValueType::ValueList;
const auto* items = val->value_as_ValueList()->items();
if (items) {
for (unsigned j = 0; j < items->size(); j++) {
value_lists_[i].push_back(static_cast<int>(items->Get(j)));
}
}
break;
}
default:
value_types_[i] = ValueType::Null;
break;
Expand Down Expand Up @@ -424,6 +492,47 @@ void WebGPUGraph::build(
webgpu_operator_registry().get_op_fn(op_name)(*this, args);
}
}

// Prepack nodes (Phase 3) materialized their constants directly into the
// consumer buffers via materialize_constant; no separate copy pass needed.
// The .pte bytes are freed right after build() returns (WebGPUBackend
// processed->Free()), so clear the build-only source pointers.
constant_data_ = nullptr;
named_data_map_ = nullptr;
}

void WebGPUGraph::materialize_constant(int const_value_id, WGPUBuffer dst) {
auto it = constant_sources_.find(const_value_id);
if (it == constant_sources_.end()) {
throw std::runtime_error(
"WebGPU: no source recorded for constant id " +
std::to_string(const_value_id));
}
const ConstantSource& cs = it->second;
if (cs.nbytes == 0) {
return;
}
if (cs.inline_offset != UINT64_MAX) {
if (constant_data_ == nullptr) {
throw std::runtime_error("WebGPU: inline constant data is null");
}
wgpuQueueWriteBuffer(
queue_, dst, 0, constant_data_ + cs.inline_offset, cs.nbytes);
} else if (!cs.named_key.empty() && named_data_map_ != nullptr) {
auto buf = named_data_map_->get_data(cs.named_key.c_str());
if (!buf.ok()) {
throw std::runtime_error(
"WebGPU: named constant '" + cs.named_key + "' not found");
}
if (buf->size() < cs.nbytes) {
throw std::runtime_error(
"WebGPU: named constant '" + cs.named_key + "' undersized");
}
wgpuQueueWriteBuffer(queue_, dst, 0, buf->data(), cs.nbytes);
buf->Free();
} else {
throw std::runtime_error("WebGPU: constant has no source");
}
}

WGPUShaderModule WebGPUGraph::get_or_create_shader(
Expand Down Expand Up @@ -484,16 +593,47 @@ WGPUBindGroupLayout WebGPUGraph::get_or_create_bgl(
return bgl;
}

void WebGPUGraph::copy_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs) {
void WebGPUGraph::copy_inputs(const std::vector<InputData>& inputs) {
for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) {
if (inputs[i].second == 0) {
const InputData& in = inputs[i];
if (in.nbytes == 0) {
continue;
}
int tid = input_ids_[i];
const auto& tensor = tensors_[tid];
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, inputs[i].first, inputs[i].second);

// Fast path: host and GPU element types match byte-for-byte.
if (in.nbytes == tensor.nbytes) {
wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes);
continue;
}

// Narrow int64 host indices into the int32 buffer (mirrors Vulkan).
const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4;
if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) {
const size_t numel = tensor.nbytes / 4;
const int64_t* src = static_cast<const int64_t*>(in.data);
std::vector<int32_t> narrowed(numel);
for (size_t e = 0; e < numel; e++) {
#ifndef NDEBUG
// Index tensors (tokens/positions) are far below int32 range in
// practice; assert in debug that the narrowing is lossless.
if (static_cast<int32_t>(src[e]) != src[e]) {
throw std::runtime_error("WebGPU: int64 index overflows int32");
}
#endif
narrowed[e] = static_cast<int32_t>(src[e]);
}
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes);
continue;
}

throw std::runtime_error(
"WebGPU: unsupported input copy for input " + std::to_string(i) +
" (host " + std::to_string(in.nbytes) + " bytes" +
(in.host_is_int64 ? " int64" : "") + " vs buffer " +
std::to_string(tensor.nbytes) + " bytes)");
}
}

Expand Down Expand Up @@ -715,10 +855,11 @@ WebGPUMemoryStats WebGPUGraph::memory_stats() const {
for (size_t i = 0; i < value_types_.size(); i++) {
if (value_types_[i] == ValueType::Tensor && tensors_[i].nbytes > 0) {
stats.num_tensors++;
// Shared tensors are tracked via shared_buffer_sizes_
// Shared tensors are tracked via shared_buffer_sizes_; a deferred
// prepack-routed constant has no buffer (no GPU memory) -> not counted.
bool is_shared =
i < tensor_mem_obj_ids_.size() && tensor_mem_obj_ids_[i] >= 0;
if (!is_shared) {
if (!is_shared && tensors_[i].buffer != nullptr) {
stats.unshared_tensor_buffer_bytes += tensors_[i].nbytes;
}
}
Expand Down
Loading
Loading