
MD-TRT Support, Compile/Export, C++ and Python #4183

Open

narendasan wants to merge 13 commits into main from push-vqqzkszwrvyx

Conversation

@narendasan
Collaborator

Description

Opening this to test the CI

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

apbose and others added 11 commits April 12, 2026 11:41
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
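The graceful fallback chain in the conversion layer (native TRT DistCollective, then the TRT-LLM plugin) can be sketched as follows. This is a minimal pure-Python illustration; the function name and availability flags here are hypothetical, not this PR's actual API:

```python
def select_collective_backend(native_trt_collectives_available: bool,
                              trt_llm_plugin_available: bool) -> str:
    """Pick a backend for collective ops, preferring native TRT support.

    Mirrors the fallback chain described above:
    native TRT DistCollective -> TRT-LLM plugin -> no collective support.
    """
    if native_trt_collectives_available:
        # TRT DistCollective API: AllGather, ReduceScatter, AllReduce
        return "native_trt_collectives"
    if trt_llm_plugin_available:
        # plugin-based fallback
        return "trt_llm_plugin"
    # caller should refuse to compile graphs containing collectives
    return "unsupported"

print(select_collective_backend(False, True))  # -> trt_llm_plugin
```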
…hapes

Five interconnected fixes:

1. fold_get_attr_item_calls: fold scalar param .item() calls into Python
   scalars before AOT tracing. Inside FakeTensorMode, even real-tensor
   .item() calls raise DataDependentOutputException.

2. backends.py: three changes:
   - call fold_get_attr_item_calls before entering FakeTensorMode
   - detect vmap/higher-order ops and route them through aot_autograd
     instead of aot_export_joint_simple (which doesn't handle HOPs)
   - on TRT build failure, strip TRT-only kwargs (use_fp32_acc) from
     the fallback graph before returning it to PyTorch

3. _decompositions.py: prevent SDPA from leaking back into the decomp
   table via Core ATen Interchange ops even after being removed from
   TORCH_TRT_DECOMPOSITIONS.

4. partitioning/common.py: lower the default max dynamic shape from
   min*2^16 to min*2^12 — 65536 is too large for TRT to find kernel
   implementations for attention ops.

5. _TorchTensorRTModule.py: move CPU scalar inputs to CUDA before
   execution — aot_autograd lifts scalar attributes (e.g. head_dim^-0.5)
   as explicit graph inputs; TRT requires all inputs on CUDA.

Also fixes remove_sym_nodes to match tensor sources by equality rather
than local_name so that GetItemSource bases (from torch.compile
dynamic=True) are matched correctly, and updates register_sdpa.py to
handle aten.scaled_dot_product_attention.default (the form produced after
aot_autograd) in addition to the flash/efficient variants.
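The last step of fix 2 (stripping TRT-only kwargs from the fallback graph) reduces to filtering a known key set. A minimal sketch, assuming the kwargs are exposed as a plain dict; only use_fp32_acc is named in the commit, so the set below is illustrative:

```python
# Illustrative set: only use_fp32_acc is named in the commit message.
TRT_ONLY_KWARGS = {"use_fp32_acc"}

def strip_trt_only_kwargs(kwargs: dict) -> dict:
    """Drop TRT-specific kwargs so the fallback graph is valid for eager PyTorch."""
    return {k: v for k, v in kwargs.items() if k not in TRT_ONLY_KWARGS}

cleaned = strip_trt_only_kwargs({"use_fp32_acc": True, "dtype": "fp16"})
print(cleaned)  # -> {'dtype': 'fp16'}
```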
meta-cla bot added the cla signed label (Apr 12, 2026)
github-actions bot added labels: documentation, component: tests, component: lowering, component: conversion, component: core, component: converters, component: build system, component: api [Python], component: runtime, component: dynamo, component: torch_compile (Apr 12, 2026)
github-actions bot requested a review from zewenli98 (April 12, 2026 19:09)

@narendasan force-pushed the push-vqqzkszwrvyx branch 5 times, most recently from 473cff9 to 9022e03 (April 13, 2026 01:14)

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// All inputs are expected to be on CUDA. Warn and move any that are not.
for (auto& inp : inputs) {
narendasan (Collaborator, Author) commented:

I would like to remove this but didn't have time to check if the device operations in Python suppress this correctly

// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
narendasan (Collaborator, Author) commented:

Not entirely sure this is necessary

.def(
"set_group_name",
[](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
self->group_name = group_name;
narendasan (Collaborator, Author) commented:

We should have a test where we switch engines between process groups

// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
narendasan (Collaborator, Author) commented:

We should only do this if there is exactly one available group. If there are multiple NCCL groups available we should tell the user to select one manually.
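The single-group guard suggested here can be sketched against a mocked registry. This is a pure-Python illustration only; the real code probes the c10d registry in C++, and the function and registry shape below are hypothetical:

```python
def resolve_nccl_group(registry: dict) -> str:
    """Return the unique NCCL-backed process group name from a
    {name: backend} registry, raising if zero or multiple candidates exist.

    PyTorch assigns sequential numeric names ("0", "1", ...) to process
    groups, so keys here are strings like "0" and "1".
    """
    nccl_groups = [name for name, backend in registry.items() if backend == "nccl"]
    if len(nccl_groups) == 1:
        return nccl_groups[0]
    if not nccl_groups:
        raise RuntimeError("no NCCL process group found")
    raise RuntimeError(
        f"multiple NCCL process groups found ({nccl_groups}); "
        "call set_group_name to select one explicitly"
    )

print(resolve_nccl_group({"0": "gloo", "1": "nccl"}))  # -> 1
```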


def forward(self, x):
out = self.linear(x)
out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
narendasan (Collaborator, Author) commented:

Let's dig into this more after the PR lands

logger = logging.getLogger("torchtrtrun")


def _get_nccl_lib_dir() -> Optional[str]:
narendasan (Collaborator, Author) commented:

Move into its own file


self._nccl_comm: Optional[Any] = None
self._has_nccl_ops: bool = False

Collaborator commented:

this should be set before the self.setup_engine()

inspector = self.engine.create_engine_inspector()
engine_json = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
self._has_nccl_ops = "NCCL" in engine_json or "AllReduce" in engine_json

Collaborator commented:

something like this works:

```python
engine_json_lower = engine_json.lower()
self._has_nccl_ops = (
    "dist_collective" in engine_json_lower
    or "nccl" in engine_json_lower
    or "allreduce" in engine_json_lower
)
```
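The same substring check can be wrapped as a small standalone helper for testability. A sketch: the function name is hypothetical, and the marker strings follow the reviewer's suggestion:

```python
def has_collective_ops(engine_json: str) -> bool:
    """Case-insensitive scan of TRT engine-inspector JSON for collective-op markers."""
    text = engine_json.lower()
    return any(marker in text for marker in ("dist_collective", "nccl", "allreduce"))

print(has_collective_ops('{"Layers": [{"Name": "AllReduce_0"}]}'))  # -> True
```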
