[executorch] Propagate device metadata from partitioner result onto TensorSpecs #18078
exir/backend/backend_api.py
@@ -32,6 +32,8 @@
     create_submodule_from_nodes,
     LoweredBackendModule,
 )
+from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
+from executorch.exir.passes.spec_prop_pass import make_spec
 from executorch.exir.program._fake_program import (
     get_fake_program,
     update_to_real_program,
@@ -242,6 +244,9 @@
         call_delegate_node.meta["val"] = [
             out_arg.meta["val"] for out_arg in submodule_output_node.args[0]
         ]
+        call_delegate_node.meta["spec"] = tuple(
+            make_spec(out_arg.meta["val"]) for out_arg in submodule_output_node.args[0]
+        )
         call_submodule_node.replace_all_uses_with(call_delegate_node)
         owning_graph_module.graph.erase_node(call_submodule_node)
         if is_submodule:
@@ -427,6 +432,9 @@
         tagged_exported_program,
     )

+    # Propagate device metadata from delegate CompileSpecs onto TensorSpecs
+    PropagateDevicePass()(tagged_graph_module)
+
     # Partitioner added delegation tags to the graph module nodes,
     # we make sure to remove them after we finished partition_and_lower
     for node in tagged_graph_module.graph.nodes:
@@ -765,6 +773,13 @@
         method_to_tagged_exported_program,
     )

+    # Propagate device metadata from delegate CompileSpecs onto TensorSpecs
+    for (
+        method_name,
+        tagged_exported_program,
+    ) in method_to_tagged_exported_program.items():
+        PropagateDevicePass()(tagged_exported_program.graph_module)
+
     for method_name in method_to_edge_program.keys():
         if method_name in method_to_tagged_exported_program:
             tagged_exported_program = method_to_tagged_exported_program[method_name]
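For orientation, here is a hedged sketch, not part of this diff, of the partitioner side of the contract: a partitioner opts into device propagation by attaching a CompileSpec under the "target_device" key that the new pass below defines. The backend id "MyCudaBackend" and the tag name are hypothetical placeholders.

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import DelegationSpec

# Hypothetical backend id; a real partitioner would use its registered id.
delegation_spec = DelegationSpec(
    backend_id="MyCudaBackend",
    compile_specs=[CompileSpec("target_device", b"cuda:0")],
)
# Inside a Partitioner.partition() implementation, each tagged subgraph would
# receive this spec, e.g. partition_tags["tag0"] = delegation_spec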
exir/passes/propagate_device_pass.py (new file)
@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
from typing import Optional

import executorch.exir.schema as schema

import torch
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.tensor import TensorSpec
from torch.fx.passes.infra.pass_base import PassBase, PassResult

logger: logging.Logger = logging.getLogger(__name__)

# CompileSpec key convention for specifying the target device.
# Partitioners that target a specific device should include a CompileSpec entry
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

# Mapping from torch.device type strings to schema.DeviceType.
_DEVICE_STR_TO_ET_DEVICE: dict[str, schema.DeviceType] = {
    "cpu": schema.DeviceType.CPU,
    "cuda": schema.DeviceType.CUDA,
}


def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
    """
    Parse a target_device CompileSpec value (e.g., b"cuda:0") into
    (DeviceType, device_index).
    """
    device_str = value.decode("utf-8")
    torch_device = torch.device(device_str)
    device_type = _DEVICE_STR_TO_ET_DEVICE.get(torch_device.type, schema.DeviceType.CPU)
    device_index = torch_device.index if torch_device.index is not None else 0
    return device_type, device_index
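As a quick illustration of the parsing helper above (a sketch, assuming schema.DeviceType exposes CPU and CUDA members as the mapping in this file declares):

# Illustrative expectations for _parse_device_spec_value; not a test in this PR.
assert _parse_device_spec_value(b"cuda:1") == (schema.DeviceType.CUDA, 1)
assert _parse_device_spec_value(b"cuda") == (schema.DeviceType.CUDA, 0)  # missing index -> 0
assert _parse_device_spec_value(b"cpu") == (schema.DeviceType.CPU, 0)
# Unknown device types fall back to CPU rather than raising.
assert _parse_device_spec_value(b"mps") == (schema.DeviceType.CPU, 0)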
def _get_lowered_module(
    graph_module: torch.fx.GraphModule,
    delegate_call_node: torch.fx.Node,
) -> Optional[LoweredBackendModule]:
    """
    Given an executorch_call_delegate node, retrieve the associated
    LoweredBackendModule from the graph module.
    The first argument to executorch_call_delegate is a get_attr node
    whose target names the LoweredBackendModule attribute.
    """
    if len(delegate_call_node.args) < 1:
        return None
    lowered_node = delegate_call_node.args[0]
    if not isinstance(lowered_node, torch.fx.Node) or lowered_node.op != "get_attr":
        return None
    lowered_module = getattr(graph_module, lowered_node.target, None)
    if isinstance(lowered_module, LoweredBackendModule):
        return lowered_module
    return None


def _get_target_device_from_compile_specs(
    lowered_module: LoweredBackendModule,
) -> Optional[tuple[schema.DeviceType, int]]:
    """
    Look for a CompileSpec with key TARGET_DEVICE_COMPILE_SPEC_KEY and return
    the corresponding (DeviceType, device_index), or None if not found.
    """
    for spec in lowered_module.compile_specs:
        if spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY:
            return _parse_device_spec_value(spec.value)
    return None


def _set_device_on_spec(
    spec: TensorSpec,
    device_type: schema.DeviceType,
) -> None:
    """Set the device attribute on a TensorSpec."""
    spec.device = device_type
Contributor:
Are these fields already in the TensorSpec class definition? Are they initialized to just cpu and 0?

Contributor (Author):
Yes and yes.
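To make the exchange above concrete, a hedged sketch of the confirmed defaults; it assumes make_spec accepts a plain tensor the same way it accepts the fake tensors from node.meta["val"] earlier in this diff:

import torch
import executorch.exir.schema as schema
from executorch.exir.passes.spec_prop_pass import make_spec

# Per the author's answer, a fresh TensorSpec starts on CPU (device index 0),
# which is why non-delegated tensors need no handling by this pass.
spec = make_spec(torch.empty(2, 2))
assert spec.device == schema.DeviceType.CPU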
class PropagateDevicePass(PassBase):
    """
    After to_backend, walk the graph and set device metadata on TensorSpecs
    based on partitioner-assigned delegation info.

    Rules:
    1. Delegated nodes: Output tensors of a delegate call are marked with the
       target device derived from the delegate's CompileSpec (key="target_device").
    2. Non-delegated nodes: Remain on CPU (default).
    3. Getitem nodes that extract from a delegate call inherit the device from
       the delegate call's output spec at the corresponding index.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        changed = False
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target == executorch_call_delegate:
                lowered_module = _get_lowered_module(graph_module, node)
                if lowered_module is None:
Contributor:
We should throw here, no?

Contributor (Author):
Let me throw here. I don't think it will be None.

                    continue

                result = _get_target_device_from_compile_specs(lowered_module)
Contributor:
This effectively assumes that we know the device 'name' AOT. In theory, we can have a multi-device delegate, and then the runtime might interpret this name differently, which can cause some confusion. I.e., I am not sure about using generic names like 'gpu', but I am also not sure about following PyTorch's eager/JIT-style naming convention, where you won't switch devices underneath.

Contributor (Author):
May I have your suggestions on the ExecuTorch device name? Currently we set the device name AOT and intentionally decouple our device attribute from the pytorch/pytorch device concept; we created an enum in the ETensor schema for all devices we currently support. This way we can support as many devices as we want. For the situation you mentioned, if another backend like Vulkan needs its own GPU device, it should add a new entry to the enum. We should avoid using generic names like 'gpu'.

Contributor:
Multi-device graph serialization will necessitate multiple graphs. We can maybe make an exception for input tensors, but for any intermediate, the runtime needs to know which device it is loading intermediates onto. Device is fixed at export, AOT. If you want some generic shader-style lib where the GPU type is decided lazily, then you will have to use a generic key like 'gpu'.

                if result is None:
Contributor:
Why does it not return cpu by default?

Contributor (Author):
The default value for every tensor is already CPU.

                    continue

                target_device_type, _device_index = result
                # Mark all output TensorSpecs of this delegate call node
                specs = node.meta.get("spec")
                if specs is None:
                    continue

                if isinstance(specs, TensorSpec):
                    _set_device_on_spec(specs, target_device_type)
                    changed = True
                elif isinstance(specs, (tuple, list)):
                    for s in specs:
                        if isinstance(s, TensorSpec):
                            _set_device_on_spec(s, target_device_type)
                            changed = True

                logger.debug(
                    "PropagateDevicePass: set device=%s on delegate node %s "
                    "(backend=%s)",
                    target_device_type,
                    node.name,
                    lowered_module.backend_id,
                )

        # Second pass: propagate device through getitem nodes that extract
        # individual outputs from a delegate call.
        for node in graph_module.graph.nodes:
Contributor:
Can we just do one pass? You can look at the users of the delegate node to find the getitem nodes.

Contributor (Author):
Yes, we can, but I feel two passes are more structural: one is specific to the delegate input and the other to the delegate output.
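For reference, a minimal sketch of the reviewer's single-pass alternative: while handling each delegate call in the first loop, visit its users directly instead of re-walking the graph. This is not what the PR implements; the names reuse those from the pass above.

import operator

# Would run inside the first loop, right after stamping the delegate call's specs.
for user in node.users:
    if user.op == "call_function" and user.target is operator.getitem:
        user_spec = user.meta.get("spec")
        if isinstance(user_spec, TensorSpec):
            _set_device_on_spec(user_spec, target_device_type)
            changed = True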
            if node.op == "call_function" and node.target.__name__ == "getitem":
                source_node = node.args[0]
                if (
                    isinstance(source_node, torch.fx.Node)
                    and source_node.op == "call_function"
                    and source_node.target == executorch_call_delegate
                ):
                    spec = node.meta.get("spec")
                    source_specs = source_node.meta.get("spec")
                    idx = node.args[1]
                    if (
                        spec is not None
                        and isinstance(spec, TensorSpec)
                        and source_specs is not None
                        and isinstance(source_specs, (tuple, list))
                        and isinstance(idx, int)
                        and idx < len(source_specs)
                    ):
                        source_spec = source_specs[idx]
                        if isinstance(source_spec, TensorSpec):
                            _set_device_on_spec(spec, source_spec.device)
                            changed = True

        return PassResult(graph_module, changed)
Contributor:
Nit: this type of util really should be placed in a single spot. There are other things like this in the passes. Let's take it as a follow-up to have Claude search for generic utils like this and centralize them.
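Putting it together, a hedged end-to-end sketch mirroring the backend_api.py call sites above; tagged_graph_module is assumed to come out of partitioning and lowering as in the first hunk:

from executorch.exir.passes.propagate_device_pass import PropagateDevicePass

# Run the pass in place on a graph module that already carries
# executorch_call_delegate nodes with "spec" metadata.
pass_result = PropagateDevicePass()(tagged_graph_module)
if pass_result is not None and pass_result.modified:
    tagged_graph_module = pass_result.graph_module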