diff --git a/backends/cuda/cuda_partitioner.py b/backends/cuda/cuda_partitioner.py index e8f1276d5eb..264542a764b 100644 --- a/backends/cuda/cuda_partitioner.py +++ b/backends/cuda/cuda_partitioner.py @@ -10,6 +10,7 @@ from executorch.backends.cuda.cuda_backend import CudaBackend # usort: skip from executorch.exir._warnings import experimental from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.passes.propagate_device_pass import TARGET_DEVICE_COMPILE_SPEC_KEY @final @@ -19,7 +20,34 @@ class CudaPartitioner(AotiPartitioner): """ CUDA partitioner driven by AOTInductor backend. + + This partitioner adds a target_device compile spec to enable device info + propagation. The PropagateDevicePass will read this spec and mark delegate + output tensors with CUDA device type, which flows through to serialization. """ - def __init__(self, compile_spec: List[CompileSpec]) -> None: + def __init__( + self, + compile_spec: List[CompileSpec], + device_index: int = 0, + ) -> None: + """ + Initialize the CUDA partitioner. + + Args: + compile_spec: List of compile specs for the backend. + device_index: The CUDA device index (default: 0). This is used to + generate the target_device compile spec (e.g., "cuda:0"). 
+ """ + # Add target_device compile spec for device propagation if not already present + has_target_device = any( + spec.key == TARGET_DEVICE_COMPILE_SPEC_KEY for spec in compile_spec + ) + if not has_target_device: + compile_spec = list(compile_spec) + [ + CompileSpec( + TARGET_DEVICE_COMPILE_SPEC_KEY, + f"cuda:{device_index}".encode("utf-8"), + ) + ] super().__init__(CudaBackend.__name__, compile_spec) diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp index 5960142d2b4..39071f731a1 100644 --- a/backends/cuda/runtime/cuda_backend.cpp +++ b/backends/cuda/runtime/cuda_backend.cpp @@ -403,6 +403,26 @@ class ET_EXPERIMENTAL CudaBackend final n_outputs, args.size()) + // Verify device info on all memory-planned, ET-driven IO tensors. + // All input and output tensors should have device_type = CUDA, which + // is set during serialization by PropagateDevicePass based on the + // target_device compile spec from CudaPartitioner. + // + // Note: At this stage, the tensor memory is still on CPU. The device_type + // is metadata indicating where the tensor *should* reside. The backend + // is responsible for copying data to the actual CUDA device. + for (size_t i = 0; i < n_inputs + n_outputs; i++) { + auto* tensor = &(args[i]->toTensor()); + auto device_type = tensor->unsafeGetTensorImpl()->device_type(); + ET_CHECK_OR_RETURN_ERROR( + device_type == executorch::runtime::etensor::DeviceType::CUDA, + InvalidArgument, + "Tensor %zu expected device_type=CUDA (1), got %d. " + "Device info may not be properly propagated from CudaPartitioner.", + i, + static_cast<int>(device_type)); + } + // NOTE: ExecuTorch tensors may be on CPU or GPU due to the skip-copy // optimization. We need to create GPU copies for CUDA kernel execution // using SlimTensor. 
diff --git a/backends/cuda/runtime/shims/tests/targets.bzl b/backends/cuda/runtime/shims/tests/targets.bzl index cf8bc5c93e5..03860e0f09a 100644 --- a/backends/cuda/runtime/shims/tests/targets.bzl +++ b/backends/cuda/runtime/shims/tests/targets.bzl @@ -41,3 +41,6 @@ def define_common_targets(): cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle") cuda_shim_cpp_unittest("aoti_torch_item_bool") cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out") + + # ETensor device info test (uses cuda shims for tensor creation) + cuda_shim_cpp_unittest("etensor_device_info") diff --git a/backends/cuda/runtime/shims/tests/test_etensor_device_info.cpp b/backends/cuda/runtime/shims/tests/test_etensor_device_info.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/cuda/tests/test_cuda_export.py b/backends/cuda/tests/test_cuda_export.py index ff4a9313545..71d2bc4ed12 100644 --- a/backends/cuda/tests/test_cuda_export.py +++ b/backends/cuda/tests/test_cuda_export.py @@ -325,3 +325,121 @@ def test_triton_kernel_mode_off(self): edge_program_manager, "SDPA kernel export with triton_kernel_mode=OFF failed", ) + + def test_device_info_propagated_to_cuda_delegate_outputs(self): + """ + Test that device info is correctly propagated from export to serialization + for CUDA delegate outputs. + + This verifies the device propagation flow: + 1. CudaPartitioner adds target_device="cuda:0" CompileSpec + 2. PropagateDevicePass sets TensorSpec.device = CUDA for delegate outputs + 3. Emitter serializes device info into ExtraTensorInfo.device_type + 4. Serialized tensors have device_type = DeviceType.CUDA + + Note: At this stage, the tensor memory is still on CPU. The CUDA backend + will copy data to GPU device at runtime. Device info tagging is the first + step toward full device-aware memory allocation. 
+ """ + from executorch.exir import schema + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + module = AddModule() + module.eval() + inputs = (torch.randn(2, 3), torch.randn(2, 3)) + + # Export to CUDA with full pipeline + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + self.assertIsNotNone(edge_program_manager, "CUDA export failed") + + # Convert to ExecuTorch and access the serialized program + et_prog = edge_program_manager.to_executorch() + program = et_prog._emitter_output.program + + # Get the execution plan and verify delegate exists + plan = program.execution_plan[0] + self.assertGreater( + len(plan.delegates), + 0, + "Expected at least one delegate in the execution plan", + ) + + # Find all serialized tensors with CUDA device type + cuda_tensors = [] + for value in plan.values: + if isinstance(value.val, schema.Tensor): + tensor = value.val + if ( + tensor.extra_tensor_info is not None + and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA + ): + cuda_tensors.append(tensor) + + # The add operation produces 1 output tensor that should be tagged as CUDA + # because it's a delegate output from the CUDA backend + self.assertGreater( + len(cuda_tensors), + 0, + "Expected at least 1 tensor with CUDA device type for delegate output. " + "Device info should be propagated from CudaPartitioner through " + "PropagateDevicePass to the serialized tensor.", + ) + + def test_input_tensors_remain_cpu_device(self): + """ + Test that input tensors (not delegate outputs) remain on CPU device. + + Input tensors are provided by the user and are not produced by delegates, + so they should not be tagged with CUDA device info. Only delegate outputs + should have device info propagated. 
+ """ + from executorch.exir import schema + + class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + y + + module = AddModule() + module.eval() + inputs = (torch.randn(2, 3), torch.randn(2, 3)) + + # Export to CUDA + edge_program_manager = self._export_to_cuda_with_lower(module, inputs) + et_prog = edge_program_manager.to_executorch() + program = et_prog._emitter_output.program + + plan = program.execution_plan[0] + + # Count tensors by device type + cpu_tensors = [] + cuda_tensors = [] + + for value in plan.values: + if isinstance(value.val, schema.Tensor): + tensor = value.val + if ( + tensor.extra_tensor_info is not None + and tensor.extra_tensor_info.device_type == schema.DeviceType.CUDA + ): + cuda_tensors.append(tensor) + else: + # Either no extra_tensor_info or device_type is CPU (default) + cpu_tensors.append(tensor) + + # We should have both CPU tensors (inputs) and CUDA tensors (delegate outputs) + # The exact count depends on the model structure, but: + # - Inputs should be CPU (2 input tensors) + # - Delegate outputs should be CUDA (1 output tensor) + self.assertGreater( + len(cpu_tensors), + 0, + "Expected CPU tensors for model inputs", + ) + self.assertGreater( + len(cuda_tensors), + 0, + "Expected CUDA tensors for delegate outputs", + )