5 changes: 0 additions & 5 deletions setup.py
@@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension:
f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
Member
There is also explicit nvshmem usage in transformer_engine/common/nvshmem_api - I assume this removal would result in failure of that functionality?

Collaborator Author
I don't think so. The NVSHMEM discovery / CMake knobs for cuBLASMp (which this PR removes) are independent from those for nvshmem_api.

f"nvidia-nvshmem-cu{cuda_version()[0]}"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
10 changes: 4 additions & 6 deletions transformer_engine/common/CMakeLists.txt
@@ -283,20 +283,18 @@ endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NVSHMEM_HOST_LIB
NAMES nvshmem_host libnvshmem_host.so.3
PATHS ${NVSHMEM_DIR}
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
Comment on lines +292 to +296
Contributor
NCCL library not discoverable

find_library(NCCL_LIB ...) is missing a PATHS/hint variable (unlike CUBLASMP_LIB, which uses ${CUBLASMP_DIR}). Unless NCCL is already in the default linker search paths, enabling NVTE_WITH_CUBLASMP will fail at configure time with NCCL_LIB not found. This PR should add a way to point CMake at NCCL (env var / CMake cache var) and pass it from setup.py similarly to CUBLASMP_DIR.
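
For illustration, a minimal sketch of such a hint, mirroring the existing CUBLASMP_DIR handling (NCCL_DIR here is a hypothetical cache variable, not something this PR defines):

    # Sketch: accept an NCCL_DIR hint (e.g. passed as -DNCCL_DIR=... from setup.py)
    # so NCCL is discoverable outside the default linker search paths.
    set(NCCL_DIR "" CACHE PATH "Root of the NCCL installation")
    find_library(NCCL_LIB
                 NAMES nccl libnccl
                 PATHS ${NCCL_DIR}
                 PATH_SUFFIXES lib
                 REQUIRED)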

message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()

# Hack to enable dynamic loading in cuDNN frontend
62 changes: 35 additions & 27 deletions transformer_engine/common/comm_gemm/comm_gemm.cpp
@@ -8,7 +8,6 @@

#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>

#include <map>
#include <memory>
@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
ctx->grid_row_major.get(), ctx->d_desc.get()));

const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo

const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
sizeof trans_a));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
sizeof trans_b));
cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
sizeof algo_attr));

const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a->dtype())) {
NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(b->dtype())) {
NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(d->dtype())) {
NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
&d->scale.dptr, sizeof(void*)));
if (d->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
&d->amax.dptr, sizeof(void*)));
}
@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
// Might be set to ALLREDUCE before, need to OR with the new flags to set.
cublasMpMatmulEpilogue_t epilogue{};
size_t size_read{};
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue, &size_read));
NVTE_CHECK(size_read == sizeof epilogue);
@@ -339,55 +338,55 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
it != flags_to_epilogue.end()) {
epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}

if (bias && bias->data.dptr) {
cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
sizeof bias_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
sizeof bias->data.dptr));
}

if (pre_act_out && pre_act_out->data.dptr) {
cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
&aux_type, sizeof aux_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
&pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
sizeof ldd));
if (is_fp8_dtype(pre_act_out->dtype())) {
NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
&scale_mode, sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
&pre_act_out->scale.dptr, sizeof(void*)));
if (pre_act_out->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
&pre_act_out->amax.dptr, sizeof(void*)));
}
}
}

if (comm_sm_count) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
&comm_sm_count, sizeof comm_sm_count));
}

NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));

size_t wrksp_size_device{};
size_t wrksp_size_host{};
@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo

std::vector<uint8_t> workspace_host(wrksp_size_host);
if (ctx->workspace_size < wrksp_size_device) {
nvshmem_free(ctx->workspace);
ctx->workspace = nvshmem_malloc(wrksp_size_device);
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
NVTE_CHECK_CUBLASMP(
cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
NVTE_CHECK_CUBLASMP(
cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
ctx->workspace_size = wrksp_size_device;
}

@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
nvshmemx_sync_all_on_stream(ctx->stream.get());
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
delete ctx;
Comment on lines 479 to 485
Contributor
Unsafe workspace free

nvte_comm_gemm_ctx_destroy can deregister/free ctx->workspace while cublasMpMatmul work on the user-provided main_stream is still in flight. cublasmp_gemm sets the cuBLASMp handle stream to main_stream (comm_gemm.cpp:389) and uses ctx->workspace in the enqueue (comm_gemm.cpp:436-439), but destroy() does not synchronize main_stream (or otherwise ensure completion) before calling cublasMpBufferDeregister/cublasMpFree.

This can become a use-after-free if a caller destroys the ctx shortly after launching a comm GEMM. Either synchronize the relevant stream(s) before freeing, or explicitly document (in comm_gemm.h) that callers must synchronize main_stream before calling destroy().
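
For illustration, a minimal sketch of the first option, a device-wide sync inside destroy (assuming NVTE_CHECK_CUDA is this codebase's usual CUDA error-checking macro, and that a blunt full-device wait is acceptable at teardown):

    // Sketch, not this PR's code: wait for all outstanding device work so no
    // in-flight cublasMpMatmul can still be touching ctx->workspace.
    NVTE_CHECK_CUDA(cudaDeviceSynchronize());
    if (ctx->workspace) {
      NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
      NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
    }
    delete ctx;

A per-stream cudaStreamSynchronize on the stream(s) actually used would be cheaper, but the context does not currently record the caller's main_stream.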

}

comm_gemm.h
@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*
* It's the caller's responsibility to synchronize all streams involved before calling this function.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
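
For illustration, a minimal caller-side sketch of this contract (main_stream stands in for whichever stream the caller passed into the comm-gemm calls):

    // Sketch: ensure no comm-gemm work is still in flight, then destroy the context.
    cudaStreamSynchronize(main_stream);  // stream(s) previously passed to comm-gemm calls
    nvte_comm_gemm_ctx_destroy(ctx);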
