-
Notifications
You must be signed in to change notification settings - Fork 632
Get rid of nvshmem dependency for cuBLASMp integration #2661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile OverviewGreptile SummaryThis PR updates TransformerEngine’s cuBLASMp integration to reflect cuBLASMp >= 0.8.0 moving symmetric-memory support from nvshmem to NCCL. It removes the nvshmem build dependency, updates CMake linkage to bring in NCCL for the cuBLASMp path, and adjusts the comm GEMM implementation and its public header to use the updated cuBLASMp APIs. I did not find additional fix-required issues in the changed files beyond the two items already raised in prior threads (NCCL discoverability in CMake, and workspace lifetime/synchronization in ctx destroy). Confidence Score: 2/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant Py as setup.py (build)
participant CM as CMakeLists.txt
participant CG as comm_gemm.cpp
participant MP as cuBLASMp
participant NC as NCCL
Py->>CM: Configure NVTE_WITH_CUBLASMP and dependency paths
CM->>CM: Link cuBLASMp + NCCL (nvshmem removed)
CG->>MP: Create/Configure cuBLASMp handle
CG->>NC: Initialize/associate NCCL-backed symmetric memory
CG->>MP: Register/deregister workspace buffer (new API)
CG->>MP: Enqueue cublasMpMatmul on main_stream
CG-->>MP: Destroy ctx/handle (must ensure stream work complete)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 4 comments
| find_library(NCCL_LIB | ||
| NAMES nccl libnccl | ||
| PATH_SUFFIXES lib | ||
| REQUIRED) | ||
| target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) | ||
| target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NCCL library not discoverable
find_library(NCCL_LIB ...) is missing a PATHS/hint variable (unlike CUBLASMP_LIB, which uses ${CUBLASMP_DIR}). Unless NCCL is already in the default linker search paths, enabling NVTE_WITH_CUBLASMP will fail at configure time with NCCL_LIB not found. This PR should add a way to point CMake at NCCL (env var / CMake cache var) and pass it from setup.py similarly to CUBLASMP_DIR.
Additional Comments (3)
In the
After switching from |
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 1 comment
| void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { | ||
| NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); | ||
| nvshmemx_sync_all_on_stream(ctx->stream.get()); | ||
| if (ctx->workspace) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); | ||
| } | ||
| delete ctx; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unsafe workspace free
nvte_comm_gemm_ctx_destroy can deregister/free ctx->workspace while cublasMpMatmul work on the user-provided main_stream is still in flight. cublasmp_gemm sets the cuBLASMp handle stream to main_stream (comm_gemm.cpp:389) and uses ctx->workspace in the enqueue (comm_gemm.cpp:436-439), but destroy() does not synchronize main_stream (or otherwise ensure completion) before calling cublasMpBufferDeregister/cublasMpFree.
This can become a use-after-free if a caller destroys the ctx shortly after launching a comm GEMM. Either synchronize the relevant stream(s) before freeing, or explicitly document (in comm_gemm.h) that callers must synchronize main_stream before calling destroy().
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4 files reviewed, no comments
Description
Starting with cuBLASMp 0.8.0, they're moving away from using nvshmem for symmetric memory, use NCCL instead.
This change adapts the to changed API.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: