diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 29c2ed076a..8671dd5860 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -91,6 +91,7 @@ def cross_compile_for_windows( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -171,6 +172,7 @@ def cross_compile_for_windows( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -318,6 +320,7 @@ def cross_compile_for_windows( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -432,6 +435,7 @@ def compile( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -527,6 +531,7 @@ def compile( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -706,6 +711,7 @@ def compile( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, @@ -1226,6 +1232,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH, dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY, + cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES, reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES, @@ -1302,6 +1309,7 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX. runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled". lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage @@ -1458,6 +1466,7 @@ def convert_exported_program_to_serialized_trt_engine( "timing_cache_path": timing_cache_path, "runtime_cache_path": runtime_cache_path, "dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy, + "cuda_graph_strategy": cuda_graph_strategy, "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 784066cc75..00bc39bce5 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -70,6 +70,7 @@ DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" +CUDA_GRAPH_STRATEGY = "disabled" USE_PYTHON_RUNTIME = False if platform.system() == "Linux": diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 3fe18e0a0d..694b1a7000 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -17,6 +17,7 @@ AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, CPU_MEMORY_BUDGET, + CUDA_GRAPH_STRATEGY, DECOMPOSE_ATTENTION, DISABLE_TF32, DLA_GLOBAL_DRAM_SIZE, @@ -95,6 +96,7 @@ class CompilationSettings: timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning). runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT. dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy". + cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled". cache_built_engines (bool): Whether to save the compiled TRT engines to storage reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. @@ -150,6 +152,7 @@ class CompilationSettings: dynamic_shapes_kernel_specialization_strategy: str = ( DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY ) + cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY lazy_engine_init: bool = LAZY_ENGINE_INIT cache_built_engines: bool = CACHE_BUILT_ENGINES reuse_cached_engines: bool = REUSE_CACHED_ENGINES diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 3fae323704..537444a0b9 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -127,6 +127,59 @@ def __del__(self) -> None: def set_use_output_allocator(self, enable: bool) -> None: self.use_output_allocator_outputs = enable + def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None: + """Verify every TRT submodule is safe for monolithic stream capture. + + Whole-graph CUDA graph mode wraps mixed TRT + PyTorch ops in a + single outer ``torch.cuda.CUDAGraph`` capture. On TRT-RTX, each + engine must opt out of RTX-native CUDA graphs (which would + interfere with the outer capture) and must pass the + ``IExecutionContext.is_stream_capturable`` check. Raises + ``RuntimeError`` if any TRT engine is not monolithically + capturable. No-op on non-RTX builds. + """ + from torch_tensorrt._features import ENABLED_FEATURES + + if not ENABLED_FEATURES.tensorrt_rtx: + return + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + TorchTensorRTModule, + ) + from torch_tensorrt.dynamo.runtime._TRTEngine import ( + TRTEngine, + _get_cuda_graph_strategy, + ) + + for name, mod in self.compiled_module.named_modules(): + if not ( + isinstance(mod, TorchTensorRTModule) + and isinstance(mod.engine, TRTEngine) + ): + continue + engine = mod.engine + if not engine._is_monolithic_capturable(stream): + raise RuntimeError( + f"CUDA graph capture failed: TRT submodule '{name}' is " + "not monolithically capturable (lazy kernel " + "specialization or non-capturable stream). Whole-graph " + "CUDA graph mode with mixed TRT + PyTorch ops requires " + "all TRT engines to be capturable. Consider using " + "cuda_graph_strategy='whole_graph_capture' with " + "set_cudagraphs_mode(True) instead of enable_cudagraphs()." + ) + # Disable RTX-native CUDA graphs on this engine so they don't + # interfere with the outer monolithic capture. + if engine._rtx_native_cudagraphs: + engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "disabled" + ) + engine.context = engine._create_execution_context() + engine._rtx_native_cudagraphs = False + logger.info( + f"Disabled RTX-native CUDA graphs for '{name}' " + "(using outer monolithic capture instead)" + ) + def forward( self, *args: Any, **kwargs: Any ) -> torch.Tensor | Tuple[torch.Tensor, ...]: @@ -212,6 +265,7 @@ def forward( with torch.cuda.stream(self._engine_stream): if need_cudagraphs_record: + self._check_monolithic_capturability(self._engine_stream) self.cudagraph = torch.cuda.CUDAGraph() with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): self._output_buffers = self.compiled_module(*args, **kwargs) diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index c9c6f8a433..e8496f2d3d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -11,6 +11,7 @@ import base64 import copy import logging +import os import pickle import tempfile from contextlib import nullcontext @@ -44,6 +45,7 @@ deserialize_binding_names, parse_device_info, ) +from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( _is_switch_required, @@ -56,6 +58,24 @@ logger = logging.getLogger(__name__) + +def _get_dynamic_shapes_kernel_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT enum. Only meaningful on TensorRT-RTX builds.""" + return { + "lazy": trt.DynamicShapesKernelSpecializationStrategy.LAZY, + "eager": trt.DynamicShapesKernelSpecializationStrategy.EAGER, + "none": trt.DynamicShapesKernelSpecializationStrategy.NONE, + }.get(strategy_str, trt.DynamicShapesKernelSpecializationStrategy.LAZY) + + +def _get_cuda_graph_strategy(strategy_str: str) -> Any: + """Map strategy string to TRT CudaGraphStrategy enum. Only meaningful on RTX.""" + return { + "disabled": trt.CudaGraphStrategy.DISABLED, + "whole_graph_capture": trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + }.get(strategy_str, trt.CudaGraphStrategy.DISABLED) + + # --------------------------------------------------------------------------- # TRT I/O helpers # --------------------------------------------------------------------------- @@ -219,7 +239,13 @@ def __init__( torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - self._runtime_config = None + # Initialized to ``None`` here so the destructor can safely save the + # cache even if ``_setup_engine`` never runs. + self.runtime_config: Any = None + self.runtime_cache: Any = None + # When true, ``_execute_standard`` must skip manual torch.cuda.CUDAGraph + # capture because TRT-RTX handles it internally. + self._rtx_native_cudagraphs: bool = False # NCCL communicator is bound lazily on the first forward pass for # engines compiled with native multi-device collective layers. self._nccl_comm: Optional[Any] = None @@ -228,7 +254,7 @@ def __init__( self._setup_engine() def __del__(self) -> None: - self.reset_captured_graph() + self.close() def __deepcopy__(self, memo: dict[int, Any]) -> "TRTEngine": """Rebuild from serialized layout so ``copy.deepcopy`` skips unpickleable TRT handles.""" @@ -279,7 +305,11 @@ def __setstate__(self, state: Any) -> None: torch_tensorrt.runtime.get_cudagraphs_mode() ) self.resource_allocation_strategy = 0 - self._runtime_config = None + # See ``__init__`` for the rationale: pre-init these so a destructor + # firing on a partially-loaded engine never trips an ``AttributeError``. + self.runtime_config = None + self.runtime_cache = None + self._rtx_native_cudagraphs = False # NCCL communicators cannot be pickled; rebind lazily on the next # forward pass via setup_nccl_comm(). self._nccl_comm = None @@ -371,14 +401,21 @@ def get_serialized_metadata(self) -> str: return self.serialized_metadata def close(self) -> None: - """Release CUDA graph resources (called explicitly or via __del__).""" + """Persist the runtime cache and release CUDA graph resources.""" + self._save_runtime_cache() self.reset_captured_graph() def _create_execution_context(self) -> trt.IExecutionContext: - strategy = trt.ExecutionContextAllocationStrategy.STATIC - if self.resource_allocation_strategy: - strategy = trt.ExecutionContextAllocationStrategy.USER_MANAGED - context = self.cuda_engine.create_execution_context(strategy) + if ENABLED_FEATURES.tensorrt_rtx: + assert self.runtime_config is not None + context = self.cuda_engine.create_execution_context(self.runtime_config) + else: + strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) + context = self.cuda_engine.create_execution_context(strategy) assert context is not None, "Failed to create execution context" return context @@ -393,6 +430,15 @@ def _setup_engine(self) -> None: logger.debug(f"Weight streaming budget set to {budget_bytes}B") self.cuda_engine.weight_streaming_budget_v2 = budget_bytes + # On TensorRT-RTX, build the IRuntimeConfig (runtime cache, + # dynamic-shape kernel specialization strategy, and CUDA graph + # strategy) up front so the one-and-only execution context picks it up. + if ENABLED_FEATURES.tensorrt_rtx: + self._setup_runtime_config() + self._rtx_native_cudagraphs = ( + self.settings.cuda_graph_strategy != "disabled" + ) + self.context = self._create_execution_context() if self._has_nccl_ops: @@ -442,6 +488,10 @@ def _setup_engine(self) -> None: dtype._from(self.cuda_engine.get_tensor_dtype(output_name)).to(torch.dtype) for output_name in self.out_binding_names ] + self.input_shapes = [ + self.cuda_engine.get_tensor_shape(input_name) + for input_name in self.in_binding_names + ] self.output_shapes = [ self.cuda_engine.get_tensor_shape(output_name) for output_name in self.out_binding_names @@ -453,6 +503,118 @@ def _setup_engine(self) -> None: if self.requires_output_allocator: self.create_output_allocator() + # --- TensorRT-RTX --- + + def _setup_runtime_config(self) -> None: + """Build an ``IRuntimeConfig`` with runtime cache and dynamic-shape strategy. + + The runtime cache stores JIT compilation results so kernel/graph + compilation is not repeated across inference runs. The dynamic-shape + kernel specialization strategy controls how the JIT compiles + shape-specialized kernels (``lazy``, ``eager``, or ``none``). + """ + self.runtime_config = self.cuda_engine.create_runtime_config() + alloc_strategy = ( + trt.ExecutionContextAllocationStrategy.USER_MANAGED + if self.resource_allocation_strategy + else trt.ExecutionContextAllocationStrategy.STATIC + ) + self.runtime_config.set_execution_context_allocation_strategy(alloc_strategy) + self.runtime_config.dynamic_shapes_kernel_specialization_strategy = ( + _get_dynamic_shapes_kernel_strategy( + self.settings.dynamic_shapes_kernel_specialization_strategy + ) + ) + logger.info( + "Dynamic shapes kernel specialization strategy: " + f"{self.settings.dynamic_shapes_kernel_specialization_strategy}" + ) + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + self.settings.cuda_graph_strategy + ) + logger.info(f"CUDA graph strategy: {self.settings.cuda_graph_strategy}") + self.runtime_cache = self.runtime_config.create_runtime_cache() + self._load_runtime_cache() + self.runtime_config.set_runtime_cache(self.runtime_cache) + logger.info("TensorRT-RTX runtime cache configured") + + def _load_runtime_cache(self) -> None: + """Load runtime cache from disk if it exists (with a shared file lock).""" + if self.runtime_cache is None: + return + cache_path = self.settings.runtime_cache_path + if not os.path.isfile(cache_path): + logger.debug(f"No existing runtime cache at {cache_path}") + return + try: + from filelock import FileLock + + lock = FileLock(cache_path + ".lock") + with lock.acquire(timeout=10): + with open(cache_path, "rb") as f: + data = f.read() + if data: + self.runtime_cache.deserialize(data) + logger.info(f"Loaded runtime cache from {cache_path}") + except Exception as e: + logger.warning(f"Failed to load runtime cache: {e}") + + def _save_runtime_cache(self) -> None: + """Save runtime cache to disk (with an exclusive file lock).""" + if self.runtime_cache is None: + return + try: + host_mem = self.runtime_cache.serialize() + if host_mem is None: + return + cache_path = self.settings.runtime_cache_path + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + + from filelock import FileLock + + lock = FileLock(cache_path + ".lock") + with lock.acquire(timeout=10): + with open(cache_path, "wb") as f: + f.write(memoryview(host_mem)) + logger.info(f"Saved runtime cache to {cache_path}") + except Exception as e: + logger.warning(f"Failed to save runtime cache: {e}") + + def _is_monolithic_capturable(self, stream: torch.cuda.Stream) -> bool: + """Return True iff manual ``torch.cuda.CUDAGraph`` capture is safe. + + On RTX, unsafe when the TRT-RTX context is not stream-capturable, or + when ``"lazy"`` kernel specialization can still fire (dynamic inputs). + """ + if not ENABLED_FEATURES.tensorrt_rtx: + return True + has_dynamic_input = any(DYNAMIC_DIM in shape for shape in self.input_shapes) + not_capturable = ( + not self.context.is_stream_capturable(stream.cuda_stream), + ( + self.settings.dynamic_shapes_kernel_specialization_strategy == "lazy" + and has_dynamic_input + ), + ) + return not any(not_capturable) + + def _enable_rtx_native_cudagraphs(self) -> None: + """Switch this engine to TRT-RTX native CUDA graphs. + + Sets the runtime config's ``cuda_graph_strategy`` to + ``WHOLE_GRAPH_CAPTURE`` and rebuilds the execution context so it + picks up the new strategy. No-op on non-RTX or when the runtime + config is not present. + """ + if self.runtime_config is None: + return + self.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy( + "whole_graph_capture" + ) + self.context = self._create_execution_context() + self._rtx_native_cudagraphs = True + logger.info("Switched to TRT-RTX native CUDA graphs") + # --- distributed / NCCL --- @property @@ -726,6 +888,26 @@ def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: + cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() + if ( + ENABLED_FEATURES.tensorrt_rtx + and cudagraphs_enabled + and not self._rtx_native_cudagraphs + ): + logger.warning( + "Manual CUDA graph capture is not guaranteed to work on " + "TRT-RTX (lazy kernel specialization or non-capturable " + "stream). Switching to TRT-RTX native CUDA graphs. Set " + 'cuda_graph_strategy="whole_graph_capture" at compile ' + "time to avoid this warning." + ) + self._enable_rtx_native_cudagraphs() + + # When RTX native is active, TRT-RTX handles capture/replay + # internally so the manual ``torch.cuda.CUDAGraph`` machinery is + # skipped. + effective_cudagraphs = cudagraphs_enabled and not self._rtx_native_cudagraphs + # Pick the engine stream BEFORE set_runtime_states so that any # stream-identity change observed this call flips # runtime_states.context_changed in time to trigger same-call @@ -738,7 +920,7 @@ def _execute_standard( can_use_pre_allocated_outputs, need_cudagraphs_reset, ) = self.runtime_states.set_runtime_states( - torch_tensorrt.runtime.get_cudagraphs_mode(), + effective_cudagraphs, self.use_pre_allocated_outputs, shape_changed, ) @@ -753,7 +935,7 @@ def _execute_standard( with self._profile_section("TRTEngine:ProcessInputs"): self.setup_input_tensors( contiguous_inputs, - torch_tensorrt.runtime.get_cudagraphs_mode(), + effective_cudagraphs, need_cudagraphs_record, ) if shape_changed: @@ -780,7 +962,7 @@ def _execute_standard( for o, output_name in enumerate(self.out_binding_names): if need_cudagraphs_record: self._output_buffers[o] = outputs[o].clone() - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: self.context.set_tensor_address( output_name, self._output_buffers[o].data_ptr() ) @@ -799,7 +981,7 @@ def _execute_standard( ) self.context.set_device_memory(self._dynamic_workspace.data_ptr()) - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: if need_cudagraphs_record: self.cudagraph = torch.cuda.CUDAGraph() if self._profile_execution: @@ -828,7 +1010,7 @@ def _execute_standard( ): self.pre_allocated_outputs = self.create_output_tensors() - if torch_tensorrt.runtime.get_cudagraphs_mode(): + if effective_cudagraphs: for idx, output in enumerate(outputs): output.copy_(self._output_buffers[idx]) @@ -929,8 +1111,13 @@ def execute( logger.debug("Using the dynamic allocator runtime mode.") return self._execute_output_allocator(contiguous_inputs) + effective_cudagraphs = ( + torch_tensorrt.runtime.get_cudagraphs_mode() + and not self._rtx_native_cudagraphs + ) logger.debug( - f"Using the standard execution runtime mode with cudagraphs={torch_tensorrt.runtime.get_cudagraphs_mode()}." + f"Using the standard execution runtime mode with cudagraphs={effective_cudagraphs}" + + (" (RTX native)" if self._rtx_native_cudagraphs else "") ) return self._execute_standard(contiguous_inputs) diff --git a/tests/py/dynamo/models/test_cuda_graph_strategy_models.py b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py new file mode 100644 index 0000000000..bce596d15f --- /dev/null +++ b/tests/py/dynamo/models/test_cuda_graph_strategy_models.py @@ -0,0 +1,186 @@ +import unittest + +import torch +import torch.nn.functional as F +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES + + +class ConvModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x): + return F.relu(self.conv(x)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyModels(TestCase): + """End-to-end model tests with cuda_graph_strategy.""" + + def _check_cosine_similarity(self, output, ref_output, threshold=0.99): + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > threshold, + f"Cosine similarity {cos_sim.item():.4f} below threshold {threshold}", + ) + + def test_resnet18_whole_graph_capture(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + def test_resnet18_disabled_strategy(self): + try: + from torchvision.models import resnet18 + except ImportError: + self.skipTest("torchvision not available") + + model = resnet18(weights=None).eval().cuda() + input_tensor = torch.randn(4, 3, 224, 224).cuda() + ref_output = model(input_tensor) + + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 224, 224), + opt_shape=(4, 3, 224, 224), + max_shape=(8, 3, 224, 224), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="disabled", + ) + torch._dynamo.reset() + + output = compiled(input_tensor) + self._check_cosine_similarity(output, ref_output) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy models require TensorRT-RTX", +) +class TestCudaGraphStrategyDynamic(TestCase): + """Tests with dynamic batch sizes and cudagraph mode integration.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_dynamic_batch_whole_graph_capture(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + def test_dynamic_batch_with_subgraph_cudagraphs(self): + model = ConvModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3, 32, 32), + opt_shape=(4, 3, 32, 32), + max_shape=(8, 3, 32, 32), + dtype=torch.float32, + ) + ] + compiled = torchtrt.compile( + model, + ir="dynamo", + inputs=inputs, + enabled_precisions={torch.float32}, + use_python_runtime=True, + min_block_size=1, + cuda_graph_strategy="whole_graph_capture", + ) + torch._dynamo.reset() + + torchtrt.runtime.set_cudagraphs_mode(True) + + for bs in (1, 4, 8): + input_tensor = torch.randn(bs, 3, 32, 32).cuda() + ref_output = model(input_tensor) + output = compiled(input_tensor) + cos_sim = F.cosine_similarity( + output.flatten().unsqueeze(0), + ref_output.flatten().unsqueeze(0), + ) + self.assertTrue( + cos_sim.item() > 0.99, + f"Batch size {bs}: cosine similarity {cos_sim.item():.4f} too low", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index 14700cbd14..05637a6146 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -73,10 +73,10 @@ def test_runtime_config_created(self): engine = _find_python_trt_engine(compiled) self.assertIsNotNone(engine, "No Python TRTEngine found in compiled model") self.assertIsNotNone( - engine._runtime_config, "runtime_config should be set for RTX" + engine.runtime_config, "runtime_config should be set for RTX" ) self.assertIsNotNone( - engine._runtime_cache, "runtime_cache should be set for RTX" + engine.runtime_cache, "runtime_cache should be set for RTX" ) def test_context_created_successfully(self): @@ -90,7 +90,7 @@ def test_context_created_successfully(self): def test_runtime_cache_path_default(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.runtime_cache_path, RUNTIME_CACHE_PATH) + self.assertEqual(engine.settings.runtime_cache_path, RUNTIME_CACHE_PATH) def test_runtime_cache_path_custom(self): cache_dir = tempfile.mkdtemp() @@ -98,7 +98,7 @@ def test_runtime_cache_path_custom(self): custom_path = os.path.join(cache_dir, "my_cache.bin") compiled, _ = _compile_simple(runtime_cache_path=custom_path) engine = _find_python_trt_engine(compiled) - self.assertEqual(engine.runtime_cache_path, custom_path) + self.assertEqual(engine.settings.runtime_cache_path, custom_path) finally: shutil.rmtree(cache_dir, ignore_errors=True) @@ -268,15 +268,15 @@ def test_no_runtime_config_for_standard_trt(self): compiled, _ = _compile_simple() engine = _find_python_trt_engine(compiled) if engine is not None: - # The TRT-RTX runtime cache machinery is exposed via the private - # ``_runtime_config``/``runtime_cache`` attributes on the Python + # The TRT-RTX runtime cache machinery is exposed via the + # ``runtime_config`` / ``runtime_cache`` attributes on the Python # engine. On non-RTX builds neither should be populated. self.assertIsNone( - engine._runtime_config, + engine.runtime_config, "runtime_config should be None for standard TRT", ) self.assertIsNone( - engine._runtime_cache, + engine.runtime_cache, "runtime_cache should be None for standard TRT", ) diff --git a/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py new file mode 100644 index 0000000000..8b534be362 --- /dev/null +++ b/tests/py/dynamo/runtime/test_001_cuda_graph_strategy.py @@ -0,0 +1,357 @@ +import unittest + +import torch +import torch_tensorrt as torchtrt +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES +from torch_tensorrt.dynamo._settings import CompilationSettings + + +class SimpleModel(torch.nn.Module): + def forward(self, x): + return torch.relu(x) + 1.0 + + +def _compile_simple(**extra_kwargs): + """Helper: compile SimpleModel with dynamic shapes and Python runtime.""" + model = SimpleModel().eval().cuda() + inputs = [ + torchtrt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(4, 3), + dtype=torch.float32, + ) + ] + kwargs = { + "ir": "dynamo", + "inputs": inputs, + "enabled_precisions": {torch.float32}, + "use_python_runtime": True, + "min_block_size": 1, + } + kwargs.update(extra_kwargs) + compiled = torchtrt.compile(model, **kwargs) + torch._dynamo.reset() + return compiled + + +def _find_python_trt_engine(compiled): + """Walk the compiled graph module and return the Python ``TRTEngine`` instance. + + The C++ and Python runtimes are now both driven through ``TorchTensorRTModule``; + ``use_python_runtime=True`` causes ``module.engine`` to be a Python ``TRTEngine``. + """ + from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule + from torch_tensorrt.dynamo.runtime._TRTEngine import TRTEngine + + for _, mod in compiled.named_modules(): + if isinstance(mod, TorchTensorRTModule) and isinstance(mod.engine, TRTEngine): + return mod.engine + return None + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy requires TensorRT-RTX", +) +class TestCudaGraphStrategySetup(TestCase): + """Tests that cuda_graph_strategy is correctly applied on TRT-RTX.""" + + def test_default_strategy_is_disabled(self): + import tensorrt as trt + + compiled = _compile_simple() + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine, "No Python TRTEngine found") + self.assertIsNotNone( + engine.runtime_config, "runtime_config should be set for RTX" + ) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.DISABLED, + ) + + def test_whole_graph_capture_strategy(self): + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_rtx_native_flag_set(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertTrue(engine._rtx_native_cudagraphs) + + def test_rtx_native_flag_disabled(self): + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertFalse(engine._rtx_native_cudagraphs) + + def test_inference_with_each_strategy(self): + for strategy in ("disabled", "whole_graph_capture"): + with self.subTest(strategy=strategy): + compiled = _compile_simple(cuda_graph_strategy=strategy) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone( + engine.context, + f"Execution context should be created for {strategy}", + ) + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_setting_in_compilation_settings(self): + for strategy in ("disabled", "whole_graph_capture"): + settings = CompilationSettings(cuda_graph_strategy=strategy) + self.assertEqual(settings.cuda_graph_strategy, strategy) + + def test_default_compilation_settings(self): + settings = CompilationSettings() + self.assertEqual(settings.cuda_graph_strategy, "disabled") + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "CUDA graph strategy integration requires TensorRT-RTX", +) +class TestCudaGraphStrategyWithSubgraphCudagraphs(TestCase): + """Tests integration with set_cudagraphs_mode().""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_rtx_native_bypasses_manual_capture(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference a few times to ensure capture would have happened + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Manual cudagraph should NOT have been recorded (RTX handles it natively) + self.assertFalse( + isinstance(engine.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded when RTX native is active", + ) + + def test_subgraph_mode_always_uses_rtx_native(self): + """Even with cuda_graph_strategy=disabled, SUBGRAPH mode on RTX + should override to RTX-native because manual capture is not safe.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + # Initially, _rtx_native_cudagraphs is False (disabled strategy) + self.assertFalse(engine._rtx_native_cudagraphs) + + torchtrt.runtime.set_cudagraphs_mode(True) + + # Run inference -- should trigger override to RTX-native + for _ in range(3): + compiled(torch.randn(2, 3).cuda()) + + # Should have been overridden to RTX-native + self.assertTrue( + engine._rtx_native_cudagraphs, + "RTX-native should be enabled automatically in SUBGRAPH mode", + ) + # Manual cudagraph should NOT have been recorded + self.assertFalse( + isinstance(engine.cudagraph, torch.cuda.CUDAGraph), + "Manual CUDA graph should not be recorded on RTX", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Monolithic capturability tests require TensorRT-RTX", +) +class TestMonolithicCapturability(TestCase): + """Tests for _is_monolithic_capturable() and related logic.""" + + def test_lazy_strategy_not_monolithic_capturable(self): + """Lazy kernel specialization makes monolithic capture unsafe.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="lazy", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + self.assertFalse(engine._is_monolithic_capturable(stream)) + + def test_eager_strategy_monolithic_capturable(self): + """Eager strategy with capturable stream should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="eager", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + # is_stream_capturable depends on engine properties. + # With eager strategy, the strategy check passes. + if engine.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(engine._is_monolithic_capturable(stream)) + + def test_none_strategy_monolithic_capturable(self): + """None strategy (always fallback) should be monolithic capturable.""" + compiled = _compile_simple( + cuda_graph_strategy="disabled", + dynamic_shapes_kernel_specialization_strategy="none", + ) + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + stream = torch.cuda.Stream() + if engine.context.is_stream_capturable(stream.cuda_stream): + self.assertTrue(engine._is_monolithic_capturable(stream)) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Context recreation tests require TensorRT-RTX", +) +class TestContextRecreation(TestCase): + """Tests for _enable_rtx_native_cudagraphs() context recreation.""" + + def test_enable_rtx_native_recreates_context(self): + """Calling _enable_rtx_native_cudagraphs recreates the execution context.""" + import tensorrt as trt + + compiled = _compile_simple(cuda_graph_strategy="disabled") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertFalse(engine._rtx_native_cudagraphs) + + old_context_id = id(engine.context) + engine._enable_rtx_native_cudagraphs() + + self.assertTrue(engine._rtx_native_cudagraphs) + self.assertNotEqual( + id(engine.context), + old_context_id, + "Context should be recreated", + ) + self.assertEqual( + engine.runtime_config.cuda_graph_strategy, + trt.CudaGraphStrategy.WHOLE_GRAPH_CAPTURE, + ) + + def test_explicit_whole_graph_capture_no_override_needed(self): + """With explicit whole_graph_capture, SUBGRAPH mode should not + need to override (already RTX-native).""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertIsNotNone(engine) + self.assertTrue(engine._rtx_native_cudagraphs) + + old_context_id = id(engine.context) + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) + torchtrt.runtime.set_cudagraphs_mode(False) + + # Context should NOT have been recreated (was already RTX-native) + self.assertEqual( + id(engine.context), + old_context_id, + "Context should not be recreated if already RTX-native", + ) + + +@unittest.skipIf( + not ENABLED_FEATURES.tensorrt_rtx, + "Cudagraph mode toggle tests require TensorRT-RTX", +) +class TestCudagraphModeToggle(TestCase): + """Tests for toggling cudagraph mode with RTX-native.""" + + def setUp(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def tearDown(self): + torchtrt.runtime.set_cudagraphs_mode(False) + + def test_cudagraphs_off_after_rtx_native_override(self): + """After RTX-native override, disabling cudagraphs should still + produce correct results (RTX-native continues transparently).""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + + torchtrt.runtime.set_cudagraphs_mode(True) + compiled(torch.randn(2, 3).cuda()) # triggers override + + torchtrt.runtime.set_cudagraphs_mode(False) + + # Should still work -- RTX-native is transparent + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_no_cudagraphs_with_whole_graph_capture(self): + """With cuda_graph_strategy='whole_graph_capture' but no + set_cudagraphs_mode, RTX-native runs transparently.""" + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + self.assertTrue(engine._rtx_native_cudagraphs) + + # No set_cudagraphs_mode(True) -- RTX-native still active transparently + for bs in (1, 2, 4): + output = compiled(torch.randn(bs, 3).cuda()) + self.assertEqual(output.shape, (bs, 3)) + + def test_toggle_on_off_on(self): + """Toggle cudagraphs on -> off -> on, verify correctness each time.""" + compiled = _compile_simple(cuda_graph_strategy="disabled") + inp = torch.randn(2, 3).cuda() + + # Phase 1: on + torchtrt.runtime.set_cudagraphs_mode(True) + out1 = compiled(inp) + self.assertEqual(out1.shape, (2, 3)) + + # Phase 2: off + torchtrt.runtime.set_cudagraphs_mode(False) + out2 = compiled(inp) + self.assertEqual(out2.shape, (2, 3)) + + # Phase 3: on again + torchtrt.runtime.set_cudagraphs_mode(True) + out3 = compiled(inp) + self.assertEqual(out3.shape, (2, 3)) + + +@unittest.skipIf( + ENABLED_FEATURES.tensorrt_rtx, + "This test verifies standard TRT behavior (non-RTX)", +) +class TestCudaGraphStrategyNonRTX(TestCase): + """Tests that the setting is ignored on non-RTX builds.""" + + def test_setting_ignored_on_non_rtx(self): + compiled = _compile_simple(cuda_graph_strategy="whole_graph_capture") + engine = _find_python_trt_engine(compiled) + if engine is not None: + self.assertIsNone( + engine.runtime_config, + "runtime_config should be None for standard TRT", + ) + self.assertFalse(engine._rtx_native_cudagraphs) + output = compiled(torch.randn(2, 3).cuda()) + self.assertEqual(output.shape, (2, 3)) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py index 359d6bbc9d..11998f9b7a 100644 --- a/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py +++ b/tests/py/dynamo/runtime/test_001_dynamic_shapes_kernel_strategy.py @@ -139,7 +139,7 @@ def test_setting_ignored_on_non_rtx(self): engine = _find_python_trt_engine(compiled) if engine is not None: self.assertIsNone( - getattr(engine, "_runtime_config", None), + engine.runtime_config, "runtime_config should be None for standard TRT", ) # Inference should still work diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py index 5b8b0b94b6..0e82cd4613 100644 --- a/tests/py/dynamo/runtime/test_004_weight_streaming.py +++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py @@ -6,6 +6,7 @@ import torch_tensorrt as torchtrt from parameterized import parameterized from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._utils import is_orin from torch_tensorrt.dynamo.utils import prepare_inputs @@ -215,6 +216,11 @@ def test_weight_streaming_multi_rt(self): torch._dynamo.reset() def test_weight_streaming_cudagraphs(self): + if ENABLED_FEATURES.tensorrt_rtx: + self.skipTest( + "Manual whole-graph CUDA graph capture (enable_cudagraphs) is " + "incompatible with weight streaming on TRT-RTX." + ) model = SampleModel().eval().cuda() input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()] exp_program = torch.export.export(model, tuple(input)) @@ -260,6 +266,12 @@ def test_weight_streaming_cudagraphs(self): is_orin(), "There is a bug on Orin platform, skip for now until bug is fixed" ) def test_runtime_state_change(self): + if ENABLED_FEATURES.tensorrt_rtx: + self.skipTest( + "Manual whole-graph CUDA graph capture (enable_cudagraphs) is " + "incompatible with weight streaming on TRT-RTX." + ) + class SampleModel(torch.nn.Module): def __init__(self): super().__init__()