From 5b2f7cc190ab76257e8e63d9f7e0a4d0940d77ee Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 27 Jan 2026 02:36:05 +0000 Subject: [PATCH 1/9] Remove nvshmem usage Signed-off-by: Vladimir Cherepanov --- setup.py | 5 ----- transformer_engine/common/CMakeLists.txt | 8 +------- transformer_engine/common/comm_gemm/comm_gemm.cpp | 8 ++++---- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 18bb736f24..3a66e624e3 100644 --- a/setup.py +++ b/setup.py @@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension: f"nvidia-cublasmp-cu{cuda_version()[0]}" ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") - nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - f"nvidia-nvshmem-cu{cuda_version()[0]}" - ).locate_file("nvidia/nvshmem") - cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") - print("CMAKE_FLAGS:", cmake_flags[-2:]) # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a83cbe3e30..0848f61656 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -289,14 +289,8 @@ if (NVTE_WITH_CUBLASMP) PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) - find_library(NVSHMEM_HOST_LIB - NAMES nvshmem_host libnvshmem_host.so.3 - PATHS ${NVSHMEM_DIR} - PATH_SUFFIXES lib - REQUIRED) - target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB}) message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") - message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") endif() # Hack to enable dynamic loading in cuDNN frontend diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 66a3da55dd..796660b054 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include @@ -423,8 +422,10 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::vector workspace_host(wrksp_size_host); if (ctx->workspace_size < wrksp_size_device) { - nvshmem_free(ctx->workspace); - ctx->workspace = nvshmem_malloc(wrksp_size_device); + NVTE_CHECK_CUBLASMP(cublasMpDeregister(ctx->grid_row_major, ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grod_col_major, ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major, &ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP(cublasMpRegister(ctx->grid_row_major, ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; } @@ -473,7 +474,6 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); - nvshmemx_sync_all_on_stream(ctx->stream.get()); delete ctx; } From aadb720a68321db61db335712dfa3901e28dfa96 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 27 Jan 2026 03:14:01 +0000 Subject: [PATCH 2/9] Renamings Signed-off-by: Vladimir Cherepanov --- .../common/comm_gemm/comm_gemm.cpp | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 796660b054..bc52f96a90 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -235,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n ctx->grid_row_major.get(), ctx->d_desc.get())); const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } @@ -272,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, sizeof trans_b)); cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, sizeof algo_attr)); const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; if (is_fp8_dtype(a->dtype())) { NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, &a->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(b->dtype())) { NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, &b->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(d->dtype())) { NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, &d->scale.dptr, sizeof(void*))); if (d->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, &d->amax.dptr, sizeof(void*))); } @@ -320,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo // Might be set to ALLREDUCE before, need to OR with the new flags to set. cublasMpMatmulEpilogue_t epilogue{}; size_t size_read{}; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue, &size_read)); NVTE_CHECK(size_read == sizeof epilogue); @@ -338,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad}); it != flags_to_epilogue.end()) { epilogue = static_cast(epilogue | it->second); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } if (bias && bias->data.dptr) { cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, sizeof bias_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, sizeof bias->data.dptr)); } if (pre_act_out && pre_act_out->data.dptr) { cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof aux_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, sizeof ldd)); if (is_fp8_dtype(pre_act_out->dtype())) { NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, &pre_act_out->scale.dptr, sizeof(void*))); if (pre_act_out->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, &pre_act_out->amax.dptr, sizeof(void*))); } @@ -381,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo } if (comm_sm_count) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, &comm_sm_count, sizeof comm_sm_count)); } - NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream)); size_t wrksp_size_device{}; size_t wrksp_size_host{}; @@ -422,10 +422,10 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::vector workspace_host(wrksp_size_host); if (ctx->workspace_size < wrksp_size_device) { - NVTE_CHECK_CUBLASMP(cublasMpDeregister(ctx->grid_row_major, ctx->workspace)); - NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grod_col_major, ctx->workspace)); - NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major, &ctx->workspace, wrksp_size_device)); - NVTE_CHECK_CUBLASMP(cublasMpRegister(ctx->grid_row_major, ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP(cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; } From 80c2d1aaba4c7a88a8a72efd7df34c352a659a37 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 27 Jan 2026 03:15:49 +0000 Subject: [PATCH 3/9] NCCL dependency Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 0848f61656..1a800daaf5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -289,7 +289,11 @@ if (NVTE_WITH_CUBLASMP) PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) - target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB}) + find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") endif() From 7ba09715cfab8c9e9ca652f5fdbd3b2a2970809f Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 2 Feb 2026 22:56:41 +0000 Subject: [PATCH 4/9] Check for not yet allocated workspace Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/comm_gemm/comm_gemm.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index bc52f96a90..c10798dd3c 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -422,8 +422,10 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::vector workspace_host(wrksp_size_host); if (ctx->workspace_size < wrksp_size_device) { - NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); - NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); NVTE_CHECK_CUBLASMP(cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; From fd02d5771ba0e22ac3050e9076102ddd3d519e81 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 21:30:34 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/comm_gemm/comm_gemm.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index c10798dd3c..2feb62b6ba 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -426,8 +426,10 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); } - NVTE_CHECK_CUBLASMP(cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); - NVTE_CHECK_CUBLASMP(cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP( + cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP( + cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; } From 706d9d1192364637e68660baf439d14eed14f0a6 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 9 Feb 2026 14:22:37 -0800 Subject: [PATCH 6/9] Address greptile comments Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/CMakeLists.txt | 2 +- transformer_engine/common/comm_gemm/comm_gemm.cpp | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 1a800daaf5..17cdd0ec74 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -283,7 +283,7 @@ endif() option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) if (NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) - target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include) find_library(CUBLASMP_LIB NAMES cublasmp libcublasmp PATHS ${CUBLASMP_DIR} diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 2feb62b6ba..7be3d1bb4d 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -478,6 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } delete ctx; } From a98e550e68ef7d9c8e88fd6a759418a70dda8c88 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Mon, 9 Feb 2026 14:40:34 -0800 Subject: [PATCH 7/9] Add a comment per greptile Signed-off-by: Vladimir Cherepanov --- .../common/include/transformer_engine/comm_gemm.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h index 06b56789a3..083a1f34a2 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank /*! \brief Destroy a comm-gemm context. * * \param[in] ctx Context to destroy. + * + * It's the caller's respondibility to synchronize all streams involved before calling this function. */ void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); From 2d3a3faa71c6a6d0c8eb102affe14f1d79031895 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Tue, 10 Feb 2026 13:32:38 -0800 Subject: [PATCH 8/9] Fix a typo Signed-off-by: Vladimir Cherepanov --- .../common/include/transformer_engine/comm_gemm.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h index 083a1f34a2..65d3aa5d9e 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -56,7 +56,7 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank * * \param[in] ctx Context to destroy. * - * It's the caller's respondibility to synchronize all streams involved before calling this function. + * It's the caller's responsibility to synchronize all streams involved before calling this function. */ void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); From 95a271c697d09d2247728b92c4ad27c2279fc89b Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov Date: Wed, 11 Feb 2026 11:17:49 -0800 Subject: [PATCH 9/9] Display human-readable cuBLASMp error message Signed-off-by: Vladimir Cherepanov --- transformer_engine/common/util/logging.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c542afa393..8031e342e2 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -96,12 +96,12 @@ #ifdef NVTE_WITH_CUBLASMP -#define NVTE_CHECK_CUBLASMP(expr) \ - do { \ - const cublasMpStatus_t status = (expr); \ - if (status != CUBLASMP_STATUS_SUCCESS) { \ - NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ - } \ +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \ + } \ } while (false) #endif // NVTE_WITH_CUBLASMP