Reduce and Allreduce NVLS implementations for the cuda backend #6038

Open
nsarka wants to merge 24 commits into NVIDIA:main from nsarka:nsarka/reduce-nvls-cuda-backend

Conversation


@nsarka nsarka commented Mar 20, 2026

Built on top of #5620. Adds Reduce and Allreduce NVLS implementations. Both use the same ld_reduce kernel and synchronize using a symmetric integer tensor as a semaphore.

@nsarka nsarka requested review from samnordmann and wujingyue March 20, 2026 14:24
@nsarka nsarka force-pushed the nsarka/reduce-nvls-cuda-backend branch from 6cc186a to 2697b92 on March 20, 2026 14:25
greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR adds Reduce and Allreduce implementations for the CUDA/NVLS backend, built on top of the existing NVLink SHARP (multimem.ld_reduce) multicast infrastructure. Both operations share the same multimem_ld_reduce_sum_f32_kernel (a new NVRTC-compiled kernel in runtime/multicast_reduce.cu) and use a symmetric integer tensor as a semaphore barrier, mirroring the existing Allgather pattern.

Key changes:

  • New kernel (runtime/multicast_reduce.cu): multimem_ld_reduce_sum_f32_kernel performs a 4×f32 vectorised load-reduce from the multicast VA into a destination buffer. Requires SM90+/PTX ISA 8.0. The 16-byte alignment and size constraints are enforced by all callers.
  • SymMemForAllreduce / SymMemForReduce (ipc_handle.h/.cpp): New SymmetricMemoryHandle subclasses that allocate symmetric input buffers and per-rank semaphore tensors, exposing unicast and multicast pointers for use by the post/wait functions.
  • postAllreduceWithCudaBackend / waitAllreduceWithCudaBackend (cuda_p2p.cpp): Four-step stream-ordered barrier (write ready → wait ready → ld_reduce kernel → write complete) identical in structure to Allgather.
  • postReduceWithCudaBackend / waitReduceWithCudaBackend (cuda_p2p.cpp): Root-only pattern — all ranks signal readiness, root waits, runs ld_reduce, then signals completion to non-roots; non-roots wait for the done signal.
  • HostIrEvaluator (evaluator.cpp): Extended CUDA backend dispatch to handle Allreduce and Reduce types; cache key falls back to input_tensor when output_tensor is undefined (non-root Reduce ranks).
  • IpcHandleCache (ipc_handle.h): Changed from storing const ExpressionEvaluator& to const ExpressionEvaluator*. This is safe at all current call sites but loses the compile-time lifetime guarantee a reference would provide.
  • Introduced bug (std::array<void*, 9> in launchAlltoallvKernel): The refactor of raw C arrays to std::array introduced a size mismatch: 9 elements declared but only 8 initializers provided for an 8-parameter kernel. The extra nullptr happens to be harmless to the CUDA driver but is technically outside the API contract.
  • tests/cpp/utils.cpp: One-line fix to set found_seed = true inside the env-var seed branch, preventing the seed from being silently ignored.
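The four-step stream-ordered barrier described above can be simulated on the host. The sketch below is illustrative only: std::atomic<int> and std::thread stand in for cuStreamBatchMemOp and CUDA streams, and a plain elementwise sum stands in for multimem.ld_reduce. The names (allreduceRank, the kIdle/kInProgress values) are assumptions, not the PR's actual identifiers.

```cpp
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Host-side sketch of the four-step barrier: write ready -> wait ready ->
// ld_reduce -> write complete. Semaphore values are assumed, not the PR's.
constexpr int kIdle = 0;
constexpr int kInProgress = 1;

void allreduceRank(
    int rank,
    int world_size,
    const std::vector<std::vector<float>>& input_buffers, // symmetric inputs
    std::vector<std::atomic<int>>& semaphores, // one semaphore slot per rank
    std::vector<float>& output) {
  // Step 1: write ready -- announce this rank's input buffer is staged.
  semaphores[rank].store(kInProgress, std::memory_order_release);
  // Step 2: wait ready -- spin until every rank has announced.
  for (int peer = 0; peer < world_size; ++peer) {
    while (semaphores[peer].load(std::memory_order_acquire) != kInProgress) {
    }
  }
  // Step 3: the "ld_reduce" -- elementwise sum over all ranks' staged
  // buffers, which is what multimem.ld_reduce computes in one access.
  for (size_t i = 0; i < output.size(); ++i) {
    float acc = 0.0f;
    for (int peer = 0; peer < world_size; ++peer) {
      acc += input_buffers[peer][i];
    }
    output[i] = acc;
  }
  // Step 4: write complete -- in the real protocol each rank writes kIdle
  // to its peers' semaphores; in this sketch thread join plays that role.
}
```

In the actual implementation all four steps are enqueued on a CUDA stream, so the waits are executed by the GPU front end rather than by spinning host threads.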

Confidence Score: 3/5

  • The new allreduce/reduce protocols are algorithmically correct and consistent with the existing allgather pattern, but a size mismatch in the kernel args array and the removal of reference-based lifetime safety warrant attention before merge.
  • The core NVLS allreduce/reduce semaphore protocol is sound and correctly mirrors the pre-existing allgather implementation. However, std::array<void*, 9> with only 8 initializers in launchAlltoallvKernel is a clear off-by-one that breaks the CUDA driver API contract (even if the driver is lenient in practice). The IpcHandleCache raw-pointer change removes a compile-time lifetime guarantee. Neither issue should cause a production crash in the current codebase, but both should be fixed before merge.
  • csrc/multidevice/cuda_p2p.cpp (std::array size mismatch in launchAlltoallvKernel) and csrc/multidevice/ipc_handle.h (raw pointer lifetime concern in IpcHandleCache).

Important Files Changed

  • runtime/multicast_reduce.cu: New CUDA kernel for multimem ld_reduce sum over float4 vectors. Requires SM90+/PTX ISA 8.0. Tail-byte handling is safe in the normal path (callers enforce 16-byte multiple alignment before reaching the kernel).
  • csrc/multidevice/cuda_p2p.cpp: Adds postAllreduceWithCudaBackend, waitAllreduceWithCudaBackend, postReduceWithCudaBackend, waitReduceWithCudaBackend, and launchMulticastReduceKernel. Code cleanup (std::array, value-initialisation) is good, but std::array<void*, 9> has only 8 initializers for launchAlltoallvKernel, leaving a stray nullptr in the kernel args array.
  • csrc/multidevice/ipc_handle.h: Adds SymMemForAllreduce and SymMemForReduce classes; changes IpcHandleCache to take a raw pointer instead of a reference. Safe in current call sites but removes compile-time lifetime protection.
  • csrc/multidevice/ipc_handle.cpp: Implements the SymMemForAllreduce and SymMemForReduce constructors; correctly allocates symmetric tensors for input buffers and semaphores, sets up unicast IPC handles, and optionally sets up multicast for the Memcpy/Multimem protocols.
  • csrc/host_ir/evaluator.cpp: Extends CUDA backend dispatch to Allreduce and Reduce; the cache key correctly falls back to input_tensor when output_tensor is undefined (non-root Reduce ranks); passes the output tensor through to postWithCudaBackend.

Sequence Diagram

sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Rank 1
    participant Sym as Symmetric Buffer (NVLS Multicast VA)

    Note over R0,R1: postAllreduceWithCudaBackend / postReduceWithCudaBackend
    R0->>Sym: cudaMemcpyAsync(inputBuffer[0] ← input)
    R1->>Sym: cudaMemcpyAsync(inputBuffer[1] ← input)

    R0->>R1: cuStreamBatchMemOp WRITE kInProgress → R1.semaphore[0]
    R1->>R0: cuStreamBatchMemOp WRITE kInProgress → R0.semaphore[1]

    R0->>R0: cuStreamBatchMemOp WAIT kInProgress on R1.semaphore[0]
    R1->>R1: cuStreamBatchMemOp WAIT kInProgress on R0.semaphore[1]

    R0->>Sym: launchMulticastReduceKernel(mc_ptr → output)
    R1->>Sym: launchMulticastReduceKernel(mc_ptr → output)
    Note over Sym: multimem.ld_reduce aggregates all ranks' inputBuffers

    R0->>R1: cuStreamBatchMemOp WRITE kIdle → R1.semaphore[0]
    R1->>R0: cuStreamBatchMemOp WRITE kIdle → R0.semaphore[1]

    Note over R0,R1: waitAllreduceWithCudaBackend / waitReduceWithCudaBackend
    R0->>R0: cuStreamBatchMemOp WAIT kIdle on R0.semaphore[1]
    R1->>R1: cuStreamBatchMemOp WAIT kIdle on R1.semaphore[0]

Comments Outside Diff (2)

  1. csrc/multidevice/cuda_p2p.cpp, line 177-186 (link)

    std::array<void*, 9> has only 8 initializers

    The array is declared with size 9 but only 8 elements are provided. This means the 9th element is value-initialized to nullptr. The cuLaunchKernel API specification requires exactly N pointers where N is the number of kernel parameters — passing an extra nullptr is technically outside the contract (though the CUDA driver happens to read exactly the first 8 in practice). The original C-style array had 8 elements matching the 8 kernel parameters (send, recv_ptrs, send_offsets, send_sizes, recv_offsets, world_size, elem_size, max_send_bytes).
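The off-by-one can be reproduced in isolation. The names below are stand-ins for the PR's eight kernel parameters, not its actual variables; the point is that with eight initializers in a nine-element std::array, the trailing slot is value-initialized to nullptr and is exactly what gets appended to the kernel-args array:

```cpp
#include <array>

// Stand-in kernel arguments; the PR's real ones are send, recv_ptrs,
// send_offsets, send_sizes, recv_offsets, world_size, elem_size,
// max_send_bytes.
int a0 = 0, a1 = 0, a2 = 0, a3 = 0, a4 = 0, a5 = 0, a6 = 0, a7 = 0;

// Declared size 9, only 8 initializers: the ninth slot is
// value-initialized to nullptr and would be passed to cuLaunchKernel.
std::array<void*, 9> args = {&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7};

// The fix is simply to size the array to the parameter count:
std::array<void*, 8> fixed = {&a0, &a1, &a2, &a3, &a4, &a5, &a6, &a7};
```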

  2. csrc/multidevice/ipc_handle.h, line 108-110 (link)

    Raw pointer introduces silent lifetime risk

    The constructor was changed from accepting const ExpressionEvaluator& (a reference, which cannot be null and guarantees the referred-to object exists at construction time) to const ExpressionEvaluator* (a raw pointer). The stored pointer at line 167 provides no lifetime guarantee — if an IpcHandleCache is ever constructed with nullptr, or if HostIrEvaluator is moved (invalidating the address of its member expr_evaluator_ that was passed as &expr_evaluator_), all subsequent dereferences will be UB.

    All current call sites are safe (&expr_evaluator_ from a member, &expr_evaluator from a local), but a reference (or std::reference_wrapper / gsl::not_null) would enforce correctness at the type level. Consider using const ExpressionEvaluator& again, or adding an NVF_ERROR(expr_evaluator != nullptr) assertion in the constructor body.
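Both suggested remedies can be sketched as follows. The class and member names are illustrative, modelled on the review's suggestion rather than taken from the actual headers, and plain assert stands in for NVF_ERROR:

```cpp
#include <cassert>
#include <functional>

struct ExpressionEvaluator {}; // stand-in for the real class

// Option 1: keep the raw pointer but fail fast at construction,
// mirroring the suggested NVF_ERROR(expr_evaluator != nullptr).
class PtrCache {
 public:
  explicit PtrCache(const ExpressionEvaluator* ee) : ee_(ee) {
    assert(ee_ != nullptr && "IpcHandleCache needs a live evaluator");
  }
  const ExpressionEvaluator& evaluator() const { return *ee_; }

 private:
  const ExpressionEvaluator* ee_;
};

// Option 2: std::reference_wrapper restores the cannot-be-null guarantee
// at the type level while keeping the class copy-assignable, which a
// plain reference member would forbid.
class RefCache {
 public:
  explicit RefCache(const ExpressionEvaluator& ee) : ee_(ee) {}
  const ExpressionEvaluator& evaluator() const { return ee_.get(); }

 private:
  std::reference_wrapper<const ExpressionEvaluator> ee_;
};
```

Note that neither option addresses the move hazard the review mentions: a reference_wrapper dangles just as a pointer does if HostIrEvaluator is moved and the stored address of expr_evaluator_ becomes stale.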

Last reviewed commit: "lint"

Comment on lines +25 to +41
size_t n_vec = n_bytes / 16;

for (size_t i = idx; i < n_vec; i += stride) {
  float r0, r1, r2, r3;
  const void* addr = mc_src_c + i * 16;
  asm volatile(
      "multimem.ld_reduce.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
      : "=f"(r0), "=f"(r1), "=f"(r2), "=f"(r3)
      : "l"(addr)
      : "memory");
  float4 out;
  out.x = r0;
  out.y = r1;
  out.z = r2;
  out.w = r3;
  ((float4*)dst_c)[i] = out;
}
P2 Tail elements silently dropped when n_bytes is not a 16-byte multiple

The kernel computes n_vec = n_bytes / 16 using integer division, so any bytes in the range [n_vec*16, n_bytes) are silently ignored and never reduced. The caller (launchMulticastReduceKernelImpl) does check size % 16 == 0 via NVF_CHECK, so this cannot be triggered through the normal code path. However, the public launchMulticastReduceKernel wrapper (used by tests, per the header comment) does not perform that check, leaving callers free to pass a size that is not a multiple of 16 and silently get wrong results.

Consider adding the alignment assertion inside the kernel or inside launchMulticastReduceKernel itself so test callers also get a clear error.
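A minimal guard of the kind suggested could look like this; the helper names are hypothetical, and std::invalid_argument stands in for whatever NVF_CHECK would raise:

```cpp
#include <cstddef>
#include <stdexcept>

// How many tail bytes the kernel's `n_vec = n_bytes / 16` integer
// division would silently skip for a given transfer size.
size_t droppedTailBytes(size_t n_bytes) {
  return n_bytes % 16;
}

// Hypothetical guard for the public launchMulticastReduceKernel wrapper:
// reject sizes the kernel would silently truncate, so test callers get a
// clear error instead of wrong results.
void checkMulticastReduceSize(size_t n_bytes) {
  if (droppedTailBytes(n_bytes) != 0) {
    throw std::invalid_argument(
        "multicast reduce size must be a multiple of 16 bytes");
  }
}
```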

nsarka and others added 5 commits March 20, 2026 16:40
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

nsarka commented Mar 20, 2026

!test


nsarka commented Mar 20, 2026

!test


nsarka commented Mar 20, 2026

!test


nsarka commented Mar 20, 2026

!test

