Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);

const size_t num_inputs = compute_graph->inputs().size();
const size_t num_outputs = compute_graph->outputs().size();
bool should_propagate_resize = false;
#ifdef ET_EVENT_TRACER_ENABLED
runtime::EventTracer* event_tracer = context.event_tracer();
Expand Down Expand Up @@ -770,14 +771,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
"ETVK_COPY_OUTPUTS",
/* delegate_debug_id = */ -1);
#endif // ET_EVENT_TRACER_ENABLED
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
const size_t o = i + num_inputs;
const size_t output_offset = args.size() - num_outputs;
for (size_t i = 0; i < num_outputs; i++) {
const size_t o = output_offset + i;
const ValueRef oref = compute_graph->outputs()[i].value;
if (compute_graph->val_is_tensor(oref)) {
VK_CHECK_COND(args[o]->isTensor());
maybe_resize_output(compute_graph, i, args[o]->toTensor());
// args holds inputs directly followed by outputs, so the i'th output
// for compute_graph corresponds to the o'th arg
compute_graph->maybe_cast_and_copy_from_staging(
compute_graph->outputs()[i].staging,
args[o]->toTensor().mutable_data_ptr(),
Expand Down
20 changes: 12 additions & 8 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -452,14 +452,15 @@
const utils::AxisMapLayout axis_map_layout) {
ValueRef idx(static_cast<int>(values_.size()));
check_no_active_value_ptrs();
values_.emplace_back(api::vTensor(
context(),
sizes,
dtype,
storage_type,
memory_layout,
false,
axis_map_layout));
values_.emplace_back(
api::vTensor(
context(),
sizes,
dtype,
storage_type,
memory_layout,
false,
axis_map_layout));

if (shared_object_idx >= 0) {
get_shared_object(shared_object_idx).add_user(this, idx);
Expand Down Expand Up @@ -725,6 +726,9 @@
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
if (values_.at(idx).isInt()) {
return static_cast<int32_t>(values_.at(idx).toInt());
}
return get_symint(idx)->get();
}

Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ class ComputeGraph final {
if (value.isBool()) {
return static_cast<T>(value.toBool());
}
if (value.isSymInt()) {
return utils::safe_downcast<T>(read_symint(idx));
}
VK_THROW("Cannot extract scalar from Value with type ", value.type());
}

Expand Down
Loading