Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def cross_compile_for_windows(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -171,6 +172,7 @@ def cross_compile_for_windows(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -318,6 +320,7 @@ def cross_compile_for_windows(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down Expand Up @@ -432,6 +435,7 @@ def compile(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -527,6 +531,7 @@ def compile(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -706,6 +711,7 @@ def compile(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down Expand Up @@ -1226,6 +1232,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
runtime_cache_path: str = _defaults.RUNTIME_CACHE_PATH,
dynamic_shapes_kernel_specialization_strategy: str = _defaults.DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY,
cuda_graph_strategy: str = _defaults.CUDA_GRAPH_STRATEGY,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
Expand Down Expand Up @@ -1302,6 +1309,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX.
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for dynamic shape kernel specialization at runtime (TensorRT-RTX only). Options: "lazy", "eager", "none". Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (manual capture), "whole_graph_capture" (TRT-RTX handles internally). Default: "disabled".
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@narendasan Should we add a guard to remind the user that this is RTX only? It is likely to be mixed up by users

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was a followup (I have a task filed) after the C++ changes landed, but unfortunately we never got to that point. I have the new changes ported over at tp5uiuc#2 and I will make the PR after this one merges in.

We are also making these strategies context managers (as is the recommended approach, see) #4310. Given these changes, I will leave this to a future MR.

lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
Expand Down Expand Up @@ -1458,6 +1466,7 @@ def convert_exported_program_to_serialized_trt_engine(
"timing_cache_path": timing_cache_path,
"runtime_cache_path": runtime_cache_path,
"dynamic_shapes_kernel_specialization_strategy": dynamic_shapes_kernel_specialization_strategy,
"cuda_graph_strategy": cuda_graph_strategy,
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"
CUDA_GRAPH_STRATEGY = "disabled"
USE_PYTHON_RUNTIME = False

if platform.system() == "Linux":
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AUTOCAST_MAX_OUTPUT_THRESHOLD,
CACHE_BUILT_ENGINES,
CPU_MEMORY_BUDGET,
CUDA_GRAPH_STRATEGY,
DECOMPOSE_ATTENTION,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
Expand Down Expand Up @@ -95,6 +96,7 @@ class CompilationSettings:
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation. Not used for TensorRT-RTX (no autotuning).
runtime_cache_path (str): Path to the runtime cache for TensorRT-RTX JIT compilation results. The cache is loaded on engine setup and saved on module cleanup. Uses file locking for concurrent access safety. Not used for standard TensorRT.
dynamic_shapes_kernel_specialization_strategy (str): Strategy for compiling shape-specialized kernels at runtime for dynamic shapes (TensorRT-RTX only). Options: "lazy" (compile in background, use fallback until ready), "eager" (compile immediately, blocking), "none" (always use fallback kernels). Default: "lazy".
cuda_graph_strategy (str): Strategy for CUDA graph capture/replay (TensorRT-RTX only). Options: "disabled" (no native CUDA graphs, uses manual capture if cudagraphs mode is enabled), "whole_graph_capture" (TRT-RTX handles CUDA graph capture internally). When set to "whole_graph_capture", the manual torch CUDA graph capture/replay in forward() is bypassed. Default: "disabled".
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32.
Expand Down Expand Up @@ -150,6 +152,7 @@ class CompilationSettings:
dynamic_shapes_kernel_specialization_strategy: str = (
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY
)
cuda_graph_strategy: str = CUDA_GRAPH_STRATEGY
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
Expand Down
54 changes: 54 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,59 @@ def __del__(self) -> None:
def set_use_output_allocator(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable

def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whats the difference between monolithic capture and the "whole graph capture mode"?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Naren, thanks for the Q. What I meant is the following:

  • Monolithic capture: This path is similar to standard TRT. It relies on the outer torch.cuda.CUDAGraph recorded by _CudaGraphsTorchTensorRTModule.forward (after the user calls
    torchtrt.runtime.enable_cudagraphs()) capturing TRT-RTX engines. Similar to standard TRT, it wraps the entire compiled subgraph (every TRT engine and any intervening PyTorch glue ops) into one monolithic torch-side graph. Scope = whole pytorch graph. This code-path gets triggered for CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS

  • Whole graph capture: This is a TRT-RTX exclusive path and obtained when setting cuda_graph_strategy="whole_graph_capture" per engine. This relies on TRT-RTX to manage captures : internally the JIT compiler captures/replays/invalidates the graph inside execute_async_v3, per-engine. The name "whole graph" is for the "whole" TRT-RTX forwarded graph. Scope = one TRT engine. This code-path gets triggered for CudaGraphsMode.SUBGRAPH_CUDAGRAPHS

When to use what

These two paths are mutually exclusive, even when running through enable_cuda_graphs().

The guidance would be to prefer (2) whole graph capture, as this would get best perf (per an TRT-RTX engine). However the condition here is that there are no intervening pytorch ops (as this will break TRT-RTX's internal graph capture status), which is consistent with it getting triggered for CudaGraphsMode.SUBGRAPH_CUDAGRAPHS.

However, in case there are PyTorch ops in between (because of graph breaks, op incompleteness), the enable_cuda_graphs() will choose the CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS mode. This means we will attempt to wrap N TRT-RTX engines + PyTorch ops in one outer torch.cuda.CUDAGraph and each engine has to leave its own (internal RTX-cudagraph) capture off, otherwise the RTX-native capture would interfere with the outer torch capture. That's what _check_monolithic_capturability enforces : it asserts every engine can be stream-captured (is_stream_capturable, no lazy-with-dynamic-shapes), and forces RTX-native off on each one (cuda_graph_strategy=DISABLED + context rebuild) before the outer torch.cuda.graph() block runs.

Name sharpening

Now, naming-wise, "monolithic" was picked to avoid the literal collision with "whole_graph_capture" since the latter is a TRT-RTX-defined enum name but also confusing with Torch-TRT's CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS. The difference lies in the distinction of "whole" for TRT-RTX vs torch-TRT. I am open to renaming either side if you have a preference.

Perhaps a good middle ground is to rename the enum in the newly introduced cuda_graph_strategy to disabled and per_engine_capture. per_engine_capture can internally map to TRT-RTX's whole_graph_capture mode. This is easier for the users at least IMO.

P.S. I will also add these best practices/guidance to the documentation. Also the behavior matrix (of how cudagraph mode interacts with TRT-RTX) is documented more in the PR description.

"""Verify every TRT submodule is safe for monolithic stream capture.

Whole-graph CUDA graph mode wraps mixed TRT + PyTorch ops in a
single outer ``torch.cuda.CUDAGraph`` capture. On TRT-RTX, each
engine must opt out of RTX-native CUDA graphs (which would
interfere with the outer capture) and must pass the
``IExecutionContext.is_stream_capturable`` check. Raises
``RuntimeError`` if any TRT engine is not monolithically
capturable. No-op on non-RTX builds.
"""
from torch_tensorrt._features import ENABLED_FEATURES

if not ENABLED_FEATURES.tensorrt_rtx:
return
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime._TRTEngine import (
TRTEngine,
_get_cuda_graph_strategy,
)

for name, mod in self.compiled_module.named_modules():
if not (
isinstance(mod, TorchTensorRTModule)
and isinstance(mod.engine, TRTEngine)
):
continue
engine = mod.engine
if not engine._is_monolithic_capturable(stream):
raise RuntimeError(
f"CUDA graph capture failed: TRT submodule '{name}' is "
"not monolithically capturable (lazy kernel "
"specialization or non-capturable stream). Whole-graph "
"CUDA graph mode with mixed TRT + PyTorch ops requires "
"all TRT engines to be capturable. Consider using "
"cuda_graph_strategy='whole_graph_capture' with "
"set_cudagraphs_mode(True) instead of enable_cudagraphs()."
)
# Disable RTX-native CUDA graphs on this engine so they don't
# interfere with the outer monolithic capture.
if engine._rtx_native_cudagraphs:
engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
"disabled"
)
engine.context = engine._create_execution_context()
engine._rtx_native_cudagraphs = False
logger.info(
f"Disabled RTX-native CUDA graphs for '{name}' "
"(using outer monolithic capture instead)"
)

def forward(
self, *args: Any, **kwargs: Any
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
Expand Down Expand Up @@ -212,6 +265,7 @@ def forward(

with torch.cuda.stream(self._engine_stream):
if need_cudagraphs_record:
self._check_monolithic_capturability(self._engine_stream)
self.cudagraph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
self._output_buffers = self.compiled_module(*args, **kwargs)
Expand Down
Loading
Loading