multidevice: enable NCCL Copy Engine allgather via CTAPolicy=ZERO #6046

saivishal1999 wants to merge 1 commit into main
Conversation
When `SymmetricMemoryBackend::PyTorchNccl` is active, configure the ProcessGroupNCCL with `NCCL_CTA_POLICY_ZERO` at creation time. Combined with NCCL-window-registered buffers (allocated via `empty_strided_p2p` and rendezvous'd through `SymmetricTensor::setupRemoteHandles`), this causes NCCL to select the Copy Engine (DMA) path for allgather, freeing SM/CTA resources.

Add `SymmetricTensorTest.CopyEngineAllgather` to verify correctness of an allgather over symm_mem buffers on all 8 ranks. Confirmed via NCCL logs:

- `Comm config CTA policy flags set to 2` on each rank
- `Inserted window ... into address map` for both input/output buffers
- `AllGather [Copy Engine]: ... -> cudaMemcpy; CE synchronization with NVLS`
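For orientation, here is a minimal sketch of where the policy lands at process-group creation; `pg_opts` and the guarded assignment mirror the diff reviewed below, while the `Options::create()` call and surrounding shape are illustrative rather than the full `createBackend` body:

```cpp
// Sketch: stamp the CTA policy onto the NCCL process-group options at
// creation time. NCCL_HAS_CTA_POLICY guards builds against NCCL versions
// that predate the config field.
auto pg_opts = c10d::ProcessGroupNCCL::Options::create();
#ifdef NCCL_HAS_CTA_POLICY
if (getSymmetricMemoryBackend() == SymmetricMemoryBackend::PyTorchNccl) {
  // Zero CTAs for collectives: with window-registered buffers, NCCL then
  // prefers the Copy Engine (DMA) path and leaves SMs free for compute.
  pg_opts->config.CTAPolicy = NCCL_CTA_POLICY_ZERO;
}
#endif
```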
Greptile Summary

Configures each newly created NCCL process group with `CTAPolicy=ZERO` while the `PyTorchNccl` symmetric-memory backend is active, enabling the Copy Engine (DMA) allgather path over NCCL-window-registered buffers.
Confidence Score: 4/5

Safe to merge; the CTA policy change is opt-in via the PyTorchNccl backend flag and correctness is preserved. The change is small and well-guarded by `#ifdef NCCL_HAS_CTA_POLICY`. The one thing worth a second look is that CTAPolicy=ZERO is stamped onto every NCCL PG created while PyTorchNccl is active, not only the world communicator. Workloads mixing symmetric-memory allgather with sub-team collectives may see performance changes on those sub-team groups.

Important Files Changed

- `csrc/multidevice/communicator.cpp` — the scope of which PGs receive CTAPolicy=ZERO deserves a brief look before merging into workloads that use multi-team parallelism strategies.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as CopyEngineAllgather Test
    participant Alloc as SymmetricTensor::allocate
    participant Backend as createBackend (NCCL PG)
    participant Rdv as symmetric_memory::rendezvous
    participant NCCL as _allgather_base (CE path)
    Test->>Alloc: allocate(input shape)
    Alloc->>Backend: initSymmMemBackendAndGetGroup
    Backend-->>Backend: CTAPolicy=ZERO (NCCL_HAS_CTA_POLICY guard)
    Test->>Alloc: allocate(output shape)
    Note over Alloc: returns cached PG
    Test->>Rdv: input_sym.setupRemoteHandles
    Note over Rdv: NCCL window-registers input buffer
    Test->>Rdv: output_sym.setupRemoteHandles
    Note over Rdv: NCCL window-registers output buffer
    Test->>Test: input.fill_(rank + 1)
    Test->>NCCL: _allgather_base(output, input)
    Note over NCCL: CTAPolicy=ZERO + window-registered = CE DMA path
    NCCL-->>Test: work->wait()
    Test->>Test: validate output slices
```
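The test body the diagram walks through reduces to fill → `_allgather_base` → validate. A hedged C++ sketch of that loop, assuming `input` and `output` are already symm_mem-backed and window-registered (the helper name and shapes are illustrative, not the verbatim test):

```cpp
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>

// Hypothetical helper mirroring the diagram's steps; `input` and `output`
// are assumed to be symm_mem-backed, window-registered buffers.
void copyEngineAllgather(
    c10d::Backend* backend,
    at::Tensor& input,
    at::Tensor& output,
    int64_t rank,
    int64_t world_size) {
  // Each rank writes a distinct value so the gathered slices are checkable.
  input.fill_(static_cast<double>(rank + 1));

  // With CTAPolicy=ZERO and both buffers window-registered, NCCL takes the
  // Copy Engine path for this allgather.
  auto work = backend->_allgather_base(output, input);
  work->wait();

  // Slice r of the output must hold rank r's fill value on every rank.
  for (int64_t r = 0; r < world_size; ++r) {
    at::Tensor slice =
        output.slice(/*dim=*/0, r * input.size(0), (r + 1) * input.size(0));
    TORCH_CHECK(at::equal(slice, at::full_like(slice, r + 1)));
  }
}
```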
Reviews (1): Last reviewed commit: "multidevice: enable NCCL Copy Engine all..."
csrc/multidevice/communicator.cpp

```cpp
#ifdef NCCL_HAS_CTA_POLICY
  if (getSymmetricMemoryBackend() == SymmetricMemoryBackend::PyTorchNccl) {
    pg_opts->config.CTAPolicy = NCCL_CTA_POLICY_ZERO;
  }
```
**`CTAPolicy=ZERO` applied to every NCCL PG, not just the world PG**

`createBackend` is called lazily by `getBackendForTeam` whenever any new team key is created — including sub-team PGs (e.g., tensor-parallel or pipeline-parallel groups). When `PyTorchNccl` is active, every future NCCL process group will be created with `CTAPolicy=ZERO`, not only the world communicator used for the symmetric-memory allgather. Operations on those sub-team groups (allreduce, reduce-scatter, broadcast, etc.) that cannot use the CE path will be executed under the ZERO-CTA policy, which may meaningfully change their latency/throughput relative to the default policy.
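If narrower scoping is desired, one illustrative option — not part of this PR — is to gate the assignment on the team spanning all ranks, so sub-team PGs keep the default policy; `team` and `size_` here are assumed names for the surrounding `createBackend` scope:

```cpp
#ifdef NCCL_HAS_CTA_POLICY
  // Hypothetical narrowing: only the world communicator gets CTAPolicy=ZERO;
  // sub-team PGs (TP/PP groups) keep the default policy.
  const bool is_world_team = static_cast<int64_t>(team.size()) == size_;
  if (is_world_team &&
      getSymmetricMemoryBackend() == SymmetricMemoryBackend::PyTorchNccl) {
    pg_opts->config.CTAPolicy = NCCL_CTA_POLICY_ZERO;
  }
#endif
```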
```cpp
Team all_ranks(world_size);
std::iota(all_ranks.begin(), all_ranks.end(), 0);
c10d::Backend* backend =
    communicator_->getBackendForTeam(all_ranks, CommunicatorBackend::kNccl);
```
The test manually replicates the logic inside `Communicator::getWorld()` — constructing a full-rank `Team` and calling `getBackendForTeam`. Using `getWorld` directly is cleaner and is the existing idiom in the codebase.
```diff
- Team all_ranks(world_size);
- std::iota(all_ranks.begin(), all_ranks.end(), 0);
- c10d::Backend* backend =
-     communicator_->getBackendForTeam(all_ranks, CommunicatorBackend::kNccl);
+ c10d::Backend* backend = communicator_->getWorld(CommunicatorBackend::kNccl);
```