MD-TRT Support, Compile/Export, C++ and Python #4183
narendasan wants to merge 13 commits into main from
Conversation
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
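The "fuse c10d_functional collectives + wait_tensor" lowering step can be sketched on a toy graph. This is not the actual torch_tensorrt pass (which operates on a real torch.fx.Graph); the `Node` type and op names below are stand-ins chosen to mirror the pattern described in the summary.

```python
# Hypothetical sketch of the collective + wait_tensor fusion described above.
# A simplified node list stands in for a torch.fx graph; names are illustrative.

from dataclasses import dataclass


@dataclass
class Node:
    op: str           # e.g. "all_reduce", "wait_tensor", "linear"
    args: tuple = ()


def fuse_collectives(nodes):
    """Replace each (collective, wait_tensor) pair with one fused node."""
    collectives = {"all_reduce", "all_gather_into_tensor", "reduce_scatter_tensor"}
    fused, skip = [], set()
    for i, n in enumerate(nodes):
        if i in skip:
            continue
        nxt = nodes[i + 1] if i + 1 < len(nodes) else None
        if n.op in collectives and nxt is not None and nxt.op == "wait_tensor":
            # Collapse the async collective and its wait into a single op.
            fused.append(Node(op=f"fused_{n.op}", args=n.args))
            skip.add(i + 1)
        else:
            fused.append(n)
    return fused
```

For example, `[linear, all_reduce, wait_tensor]` becomes `[linear, fused_all_reduce]`, which is the single-op shape the TRT converter can match.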
…g and enable DTensor decomposition
…hapes
Five interconnected fixes:
1. fold_get_attr_item_calls: fold scalar param .item() calls into Python
scalars before AOT tracing. Inside FakeTensorMode, even real-tensor
.item() calls raise DataDependentOutputException.
2. backends.py: three changes:
- call fold_get_attr_item_calls before entering FakeTensorMode
- detect vmap/higher-order ops and route them through aot_autograd
instead of aot_export_joint_simple (which doesn't handle HOPs)
- on TRT build failure, strip TRT-only kwargs (use_fp32_acc) from
the fallback graph before returning it to PyTorch
3. _decompositions.py: prevent SDPA from leaking back into the decomp
table via Core ATen Interchange ops even after being removed from
TORCH_TRT_DECOMPOSITIONS.
4. partitioning/common.py: lower the default max dynamic shape from
min*2^16 to min*2^12 — 65536 is too large for TRT to find kernel
implementations for attention ops.
5. _TorchTensorRTModule.py: move CPU scalar inputs to CUDA before
execution — aot_autograd lifts scalar attributes (e.g. head_dim^-0.5)
as explicit graph inputs; TRT requires all inputs on CUDA.
Also fixes remove_sym_nodes to match tensor sources by equality rather
than local_name so that GetItemSource bases (from torch.compile
dynamic=True) are matched correctly, and updates register_sdpa.py to
handle aten.scaled_dot_product_attention.default (the form produced after
aot_autograd) in addition to the flash/efficient variants.
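Fix 1 above can be illustrated with a toy graph. The real `fold_get_attr_item_calls` walks a torch.fx graph; here a minimal `Node` type stands in, and the attribute names are hypothetical.

```python
# Hypothetical sketch of fix 1: before AOT tracing, fold `.item()` calls on
# scalar module attributes into Python scalars, so FakeTensorMode never sees a
# data-dependent `.item()` (which would raise DataDependentOutputException).

from dataclasses import dataclass


@dataclass
class Node:
    op: str          # "get_attr" | "call_method" | ...
    target: str      # attribute name or method name
    args: tuple = ()


def fold_get_attr_item_calls(nodes, module_attrs):
    """Fold get_attr -> .item() chains into literal Python scalars."""
    out = []
    for n in nodes:
        if (
            n.op == "call_method"
            and n.target == "item"
            and n.args
            and isinstance(n.args[0], Node)
            and n.args[0].op == "get_attr"
        ):
            # Emit the attribute's scalar value instead of the .item() call.
            out.append(module_attrs[n.args[0].target])
        else:
            out.append(n)
    return out
```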
```cpp
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
  // All inputs are expected to be on CUDA. Warn and move any that are not.
  for (auto& inp : inputs) {
```
I would like to remove this, but I didn't have time to check whether the device operations in Python suppress this correctly.
```cpp
// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
```
Not entirely sure this is necessary
```cpp
.def(
    "set_group_name",
    [](c10::intrusive_ptr<TRTEngine> self, std::string group_name) {
      self->group_name = group_name;
```
We should have a test where we switch engines between process groups
```cpp
// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
```
We should only do this if there is exactly one group available. If there are multiple NCCL groups, we should tell the user to select one manually.
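The probe-plus-guard behavior suggested in this thread can be sketched as follows. The dict stands in for the real c10d process-group registry, and the function name is illustrative, not from the codebase.

```python
# Hypothetical sketch: PyTorch names process groups sequentially ("0", "1", ...),
# so probe the registry for groups with an NCCL backend. Per the review comment,
# auto-select only when exactly one NCCL group exists; otherwise ask the user
# to call set_group_name() explicitly.


def find_default_nccl_group(registry):
    """Return the sole NCCL group name, or raise if the choice is ambiguous."""
    nccl_groups = [name for name, backend in registry.items() if backend == "nccl"]
    if not nccl_groups:
        raise RuntimeError("no NCCL process group registered")
    if len(nccl_groups) > 1:
        raise RuntimeError(
            f"multiple NCCL groups {nccl_groups}; call set_group_name() explicitly"
        )
    return nccl_groups[0]
```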
```python
def forward(self, x):
    out = self.linear(x)
    out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
```
Let's dig into this more after the PR lands.
```python
logger = logging.getLogger("torchtrtrun")


def _get_nccl_lib_dir() -> Optional[str]:
```
Move into its own file
```python
self._nccl_comm: Optional[Any] = None
self._has_nccl_ops: bool = False
```
This should be set before the `self.setup_engine()` call.
```python
inspector = self.engine.create_engine_inspector()
engine_json = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
self._has_nccl_ops = "NCCL" in engine_json or "AllReduce" in engine_json
```
Something like this works:

```python
engine_json_lower = engine_json.lower()
self._has_nccl_ops = (
    "dist_collective" in engine_json_lower
    or "nccl" in engine_json_lower
    or "allreduce" in engine_json_lower
)
```
Description
Opening this to test the CI
Fixes # (issue)
Type of change

Checklist: