feat: add TRT-RTX native CUDA graph support #4187
Draft
tp5uiuc wants to merge 5 commits into pytorch:main from
Conversation
Add runtime cache support for TensorRT-RTX JIT compilation results, replacing the timing cache, which is not used by RTX (no autotuning).

Changes:
- Skip timing cache creation/saving for TensorRT-RTX in _TRTInterpreter
- Add RUNTIME_CACHE_PATH default and runtime_cache_path setting
- Wire up IRuntimeCache in PythonTorchTensorRTModule (setup, load, save)
- Persist runtime cache to disk with filelock for concurrent access safety
- Thread runtime_cache_path through all compile functions
- Add unit tests (12 tests) and E2E model tests (6 tests)
- Update docstrings and RST documentation

Fixes pytorch#3817

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
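The disk-persistence step with concurrent-access safety can be sketched as follows. This is a stdlib-only illustration of the lock-then-atomic-replace pattern (the PR itself uses the `filelock` package); `save_runtime_cache` is a hypothetical helper name, not the PR's actual function.

```python
import os
import tempfile
import time


def save_runtime_cache(cache_bytes: bytes, path: str, timeout: float = 10.0) -> None:
    """Persist serialized runtime-cache bytes, guarding against concurrent writers.

    Sketch only: the real implementation uses the `filelock` package; here a
    lock file created with O_EXCL plays the same role.
    """
    lock_path = path + ".lock"
    deadline = time.monotonic() + timeout
    while True:
        try:
            # O_EXCL makes creation fail if another process holds the lock file
            lock_fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            break
        except FileExistsError:
            if time.monotonic() > deadline:
                raise TimeoutError(f"could not acquire {lock_path}")
            time.sleep(0.05)
    try:
        # Write to a temp file, then atomically replace, so concurrent readers
        # never observe a partially written cache.
        target_dir = os.path.dirname(path) or "."
        tmp_fd, tmp_path = tempfile.mkstemp(dir=target_dir)
        with os.fdopen(tmp_fd, "wb") as f:
            f.write(cache_bytes)
        os.replace(tmp_path, path)
    finally:
        os.close(lock_fd)
        os.unlink(lock_path)
```

The atomic `os.replace` matters independently of the lock: even a reader that ignores the lock file sees either the old cache or the new one, never a truncated file.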
Version provided by upstream torch; no pin needed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Expose IRuntimeConfig.setDynamicShapesKernelSpecializationStrategy()
through the Torch-TensorRT Python API. Users can now control how
shape-specialized kernels are compiled at runtime for dynamic shapes
on TensorRT-RTX via the new `dynamic_shapes_kernel_specialization_strategy`
compilation setting ("lazy", "eager", or "none").
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
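A setting like this is typically resolved by mapping the user-facing string onto a TensorRT enum. The sketch below illustrates that validation-and-mapping step; the member names (and the idea of returning the name rather than a real `trt` enum) are assumptions for illustration, not the actual `tensorrt` API surface.

```python
# Hypothetical mapping from the string setting to enum member names on a
# TensorRT-RTX strategy enum; the target attribute names are assumed here.
_STRATEGY_NAMES = {"lazy": "LAZY", "eager": "EAGER", "none": "NONE"}


def resolve_kernel_specialization_strategy(value: str) -> str:
    """Validate the user-supplied strategy string and return the enum name.

    Illustrative helper: a real implementation would return
    getattr(trt.<StrategyEnum>, name) instead of the bare name.
    """
    try:
        return _STRATEGY_NAMES[value]
    except KeyError:
        raise ValueError(
            "dynamic_shapes_kernel_specialization_strategy must be one of "
            f"{sorted(_STRATEGY_NAMES)}, got {value!r}"
        )
```

Failing fast on an unknown string at compile time keeps the error close to the user's setting rather than surfacing later from the runtime config.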
Address review feedback: compile with torchtrt.Input min/opt/max ranges so dynamic shapes are actually exercised. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
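The min/opt/max ranges mentioned above must satisfy a per-dimension ordering invariant. The check below is a minimal, hypothetical validator of that invariant; it is not part of the `torch_tensorrt.Input` API, just a sketch of what such ranges must look like.

```python
from typing import Sequence


def validate_shape_range(
    min_shape: Sequence[int],
    opt_shape: Sequence[int],
    max_shape: Sequence[int],
) -> bool:
    """Check that each dimension satisfies min <= opt <= max, the ordering a
    dynamic-shape optimization profile requires. Hypothetical helper."""
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min/opt/max shapes must have the same rank")
    for lo, mid, hi in zip(min_shape, opt_shape, max_shape):
        if not (lo <= mid <= hi):
            raise ValueError(f"invalid dim range: need {lo} <= {mid} <= {hi}")
    return True
```

A typical valid range would be `min=(1, 3, 224, 224)`, `opt=(8, 3, 224, 224)`, `max=(16, 3, 224, 224)`, varying only the batch dimension.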
Add cuda_graph_strategy compilation setting and automatic RTX-native
CUDA graph integration for the Python runtime path.
Key changes:
- New cuda_graph_strategy setting ("disabled" / "whole_graph_capture")
on CompilationSettings, mapped to trt.CudaGraphStrategy on
IRuntimeConfig (same pattern as dynamic_shapes_kernel_specialization)
- In SUBGRAPH cudagraph mode on RTX, always use RTX-native CUDA graphs
(manual torch.cuda.CUDAGraph capture is not safe due to lazy kernel
specialization and potential runtime allocation)
- _is_monolithic_capturable() check using context.is_stream_capturable()
and strategy != "lazy" for WHOLE_GRAPH mode safety validation
- _enable_rtx_native_cudagraphs() for runtime context recreation
- _check_monolithic_capturability() in CudaGraphsTorchTensorRTModule
for mixed TRT + PyTorch graph validation
- Comprehensive unit tests covering all code paths
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
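The capturability check described above combines two conditions: the execution context must report that it can run under stream capture, and kernel specialization must not be lazy (lazy JIT compilation during replay would invalidate the capture). A standalone sketch of that predicate, exercised here with a stub context since `is_stream_capturable` is a method the PR attributes to the TensorRT-RTX execution context:

```python
def is_monolithic_capturable(context, strategy: str, stream) -> bool:
    """Sketch of the _is_monolithic_capturable() check: an engine may join an
    outer monolithic torch.cuda.CUDAGraph capture only if its context reports
    stream capturability AND kernel specialization is not lazy."""
    return bool(context.is_stream_capturable(stream)) and strategy != "lazy"
```

Both conditions are necessary: a capturable stream with lazy specialization could still trigger compilation (and allocation) on a later replay with new shapes.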
Description
Add a `cuda_graph_strategy` compilation setting and automatic RTX-native CUDA graph integration for the Python runtime path (`PythonTorchTensorRTModule`).

TensorRT-RTX has native CUDA graph support via `IRuntimeConfig.cuda_graph_strategy`, where the JIT compiler handles capture/replay/invalidation internally. This is superior to manual `torch.cuda.CUDAGraph()` capture on RTX because lazy kernel specialization and potential runtime allocation can cause `cudaStreamBeginCapture` to fail.

Key changes

- `cuda_graph_strategy` setting on `CompilationSettings` (`"disabled"` / `"whole_graph_capture"`), mapped to `trt.CudaGraphStrategy` on `IRuntimeConfig` (same pattern as `dynamic_shapes_kernel_specialization_strategy`)
- `set_cudagraphs_mode(True)`: on RTX, always use RTX-native CUDA graphs; manual capture is bypassed. If `cuda_graph_strategy` was not explicitly set, the runtime overrides to `whole_graph_capture` and warns.
- `enable_cudagraphs()` with mixed TRT + PyTorch ops: validates that all TRT engines are monolithically capturable via `context.is_stream_capturable(stream)` and `strategy != "lazy"`. If capturable, proceeds with outer monolithic capture (RTX-native disabled per-engine). If not capturable, raises `RuntimeError`.
- `_is_monolithic_capturable()`: runtime check combining stream capturability and kernel specialization strategy
- `_enable_rtx_native_cudagraphs()`: recreates the execution context with `WHOLE_GRAPH_CAPTURE`
- `_check_monolithic_capturability()` in `CudaGraphsTorchTensorRTModule` for mixed graph validation

Behavior matrix
Depends on #4180 (runtime cache) and #4184 (dynamic shapes strategy).
Type of change
Checklist: