feat: reintroduce TRT-RTX runtime cache, dynamic shapes, and native CUDA graph support #4294
```diff
@@ -127,6 +127,59 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable

+    def _check_monolithic_capturability(self, stream: torch.cuda.Stream) -> None:
```
|
Collaborator
What's the difference between monolithic capture and the "whole graph capture mode"?
Collaborator
Author
Hi Naren, thanks for the question. What I meant is the following:
When to use what

These two paths are mutually exclusive, even when running through

The guidance would be to prefer (2) whole graph capture, as this gets the best performance (per TRT-RTX engine). However, the condition is that there are no intervening PyTorch ops (as these would break TRT-RTX's internal graph capture status), which is consistent with it getting triggered for

However, in case there are PyTorch ops in between (because of graph breaks or op incompleteness), the

Name sharpening

Naming-wise, "monolithic" was picked to avoid the literal collision with "whole_graph_capture", since the latter is a TRT-RTX-defined enum name but is also easily confused with Torch-TRT's

Perhaps a good middle ground is to rename the enum in the newly introduced

P.S. I will also add these best practices and guidance to the documentation. The behavior matrix (how cudagraph mode interacts with TRT-RTX) is documented in more detail in the PR description. |
```diff
+        """Verify every TRT submodule is safe for monolithic stream capture.
+
+        Whole-graph CUDA graph mode wraps mixed TRT + PyTorch ops in a
+        single outer ``torch.cuda.CUDAGraph`` capture. On TRT-RTX, each
+        engine must opt out of RTX-native CUDA graphs (which would
+        interfere with the outer capture) and must pass the
+        ``IExecutionContext.is_stream_capturable`` check. Raises
+        ``RuntimeError`` if any TRT engine is not monolithically
+        capturable. No-op on non-RTX builds.
+        """
+        from torch_tensorrt._features import ENABLED_FEATURES
+
+        if not ENABLED_FEATURES.tensorrt_rtx:
+            return
+        from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
+            TorchTensorRTModule,
+        )
+        from torch_tensorrt.dynamo.runtime._TRTEngine import (
+            TRTEngine,
+            _get_cuda_graph_strategy,
+        )
+
+        for name, mod in self.compiled_module.named_modules():
+            if not (
+                isinstance(mod, TorchTensorRTModule)
+                and isinstance(mod.engine, TRTEngine)
+            ):
+                continue
+            engine = mod.engine
+            if not engine._is_monolithic_capturable(stream):
+                raise RuntimeError(
+                    f"CUDA graph capture failed: TRT submodule '{name}' is "
+                    "not monolithically capturable (lazy kernel "
+                    "specialization or non-capturable stream). Whole-graph "
+                    "CUDA graph mode with mixed TRT + PyTorch ops requires "
+                    "all TRT engines to be capturable. Consider using "
+                    "cuda_graph_strategy='whole_graph_capture' with "
+                    "set_cudagraphs_mode(True) instead of enable_cudagraphs()."
+                )
+            # Disable RTX-native CUDA graphs on this engine so they don't
+            # interfere with the outer monolithic capture.
+            if engine._rtx_native_cudagraphs:
+                engine.runtime_config.cuda_graph_strategy = _get_cuda_graph_strategy(
+                    "disabled"
+                )
+                engine.context = engine._create_execution_context()
+                engine._rtx_native_cudagraphs = False
+                logger.info(
+                    f"Disabled RTX-native CUDA graphs for '{name}' "
+                    "(using outer monolithic capture instead)"
+                )
+
     def forward(
         self, *args: Any, **kwargs: Any
     ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
```
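The control flow of `_check_monolithic_capturability` can be exercised without a GPU. The following is a minimal sketch with stand-in objects: `FakeEngine` and `check_capturability` are hypothetical names used only for illustration and are not part of Torch-TensorRT. It mirrors the two outcomes described in the diff: raise if an engine is not capturable, otherwise switch off that engine's native CUDA graphs.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeEngine:
    """Stand-in for a TRT engine (hypothetical; not the real TRTEngine)."""

    capturable: bool
    rtx_native_cudagraphs: bool = True
    strategy: str = "whole_graph_capture"


def check_capturability(engines: Dict[str, FakeEngine]) -> List[str]:
    """Fail fast on a non-capturable engine, else disable native graphs.

    Returns the names of engines whose native graphs were disabled.
    """
    disabled: List[str] = []
    for name, engine in engines.items():
        if not engine.capturable:
            raise RuntimeError(
                f"TRT submodule '{name}' is not monolithically capturable"
            )
        if engine.rtx_native_cudagraphs:
            # Same order as the diff: set the strategy, then clear the flag.
            engine.strategy = "disabled"
            engine.rtx_native_cudagraphs = False
            disabled.append(name)
    return disabled


engines = {
    "a": FakeEngine(capturable=True),
    "b": FakeEngine(capturable=True, rtx_native_cudagraphs=False, strategy="disabled"),
}
disabled = check_capturability(engines)
```

Because the check and the disable step share one loop, an engine visited early may already have been switched to `"disabled"` by the time a later engine raises. The sketch reproduces that ordering rather than validating all engines first.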
```diff
@@ -212,6 +265,7 @@ def forward(

         with torch.cuda.stream(self._engine_stream):
             if need_cudagraphs_record:
+                self._check_monolithic_capturability(self._engine_stream)
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
                     self._output_buffers = self.compiled_module(*args, **kwargs)
```
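The one-line change places the check inside the `need_cudagraphs_record` branch, so it runs when a graph is recorded and not on every replay. A minimal, GPU-free sketch of that record-or-replay pattern follows. `RecordOnceRunner`, `precheck`, and the shape-based re-record condition are assumptions made for illustration, not the module's actual logic.

```python
from typing import Callable, List, Optional, Tuple

Shape = Tuple[int, ...]


class RecordOnceRunner:
    """Record on the first call or a shape change; replay otherwise."""

    def __init__(self, fn: Callable[[int], int], precheck: Callable[[], None]) -> None:
        self.fn = fn
        self.precheck = precheck
        self._recorded_shape: Optional[Shape] = None
        self.events: List[str] = []

    def __call__(self, shape: Shape, value: int) -> int:
        need_record = shape != self._recorded_shape
        if need_record:
            # The capturability check runs only on the record path.
            self.precheck()
            self._recorded_shape = shape
            self.events.append("record")
        else:
            self.events.append("replay")
        return self.fn(value)


checks: List[str] = []
runner = RecordOnceRunner(lambda v: v * 2, lambda: checks.append("checked"))
results = [runner((1, 3), 1), runner((1, 3), 2), runner((2, 3), 3)]
```

Since `precheck` is called before the recorded shape is updated, a failing check leaves the runner in its previous state, and the next call attempts to record again.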
@narendasan Should we add a guard to remind the user that this is RTX only? Users are likely to mix this up.
Yes, this was a follow-up (I have a task filed) for after the C++ changes landed, but unfortunately we never got to that point. I have the new changes ported over at tp5uiuc#2 and I will make the PR after this one merges in.
We are also making these strategies context managers (as is the recommended approach, see) #4310. Given these changes, I will leave this to a future MR.
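To illustrate the context-manager direction mentioned above, here is a minimal sketch built on `contextlib.contextmanager`. The names `RuntimeState`, `cuda_graph_strategy`, and the `is_rtx` flag are hypothetical; the actual API is being worked out in #4310 and may look different. The sketch also shows one possible shape for the RTX-only guard raised in this thread.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Iterator


@dataclass
class RuntimeState:
    """Hypothetical holder for runtime settings (illustration only)."""

    is_rtx: bool
    strategy: str = "disabled"


@contextmanager
def cuda_graph_strategy(state: RuntimeState, strategy: str) -> Iterator[RuntimeState]:
    """Temporarily set a strategy and restore the previous one on exit."""
    if not state.is_rtx:
        # Guard: fail loudly instead of silently ignoring the setting.
        raise RuntimeError("cuda_graph_strategy requires a TensorRT-RTX build")
    previous = state.strategy
    state.strategy = strategy
    try:
        yield state
    finally:
        # Restore even if the body raised.
        state.strategy = previous


state = RuntimeState(is_rtx=True)
with cuda_graph_strategy(state, "whole_graph_capture") as active:
    inside = active.strategy
after = state.strategy
```

Scoping the strategy to a `with` block means the previous setting is restored automatically, which avoids leaving an engine in a modified state after an exception.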