-
Notifications
You must be signed in to change notification settings - Fork 634
Get rid of nvshmem dependency for cuBLASMp integration #2661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5b2f7cc
aadb720
80c2d1a
7ba0971
fd02d57
706d9d1
a98e550
2d3a3fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -283,20 +283,18 @@ endif() | |
| option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) | ||
| if (NVTE_WITH_CUBLASMP) | ||
| target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) | ||
| target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) | ||
| target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include) | ||
| find_library(CUBLASMP_LIB | ||
| NAMES cublasmp libcublasmp | ||
| PATHS ${CUBLASMP_DIR} | ||
| PATH_SUFFIXES lib | ||
| REQUIRED) | ||
| find_library(NVSHMEM_HOST_LIB | ||
| NAMES nvshmem_host libnvshmem_host.so.3 | ||
| PATHS ${NVSHMEM_DIR} | ||
| find_library(NCCL_LIB | ||
| NAMES nccl libnccl | ||
| PATH_SUFFIXES lib | ||
| REQUIRED) | ||
| target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) | ||
| target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) | ||
|
Comment on lines
+292
to
+296
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. NCCL library not discoverable
|
||
| message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") | ||
| message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") | ||
| endif() | ||
|
|
||
| # Hack to enable dynamic loading in cuDNN frontend | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,7 +8,6 @@ | |
|
|
||
| #include <cublasmp.h> | ||
| #include <cuda_runtime.h> | ||
| #include <nvshmem.h> | ||
|
|
||
| #include <map> | ||
| #include <memory> | ||
|
|
@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n | |
| ctx->grid_row_major.get(), ctx->d_desc.get())); | ||
|
|
||
| const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, | ||
| sizeof epilogue)); | ||
| } | ||
|
|
@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo | |
|
|
||
| const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; | ||
| const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, | ||
| sizeof trans_a)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, | ||
| sizeof trans_b)); | ||
| cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, | ||
| sizeof algo_attr)); | ||
|
|
||
| const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; | ||
| if (is_fp8_dtype(a->dtype())) { | ||
| NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, | ||
| sizeof scale_mode)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, | ||
| &a->scale_inv.dptr, sizeof(void*))); | ||
| } | ||
| if (is_fp8_dtype(b->dtype())) { | ||
| NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, | ||
| sizeof scale_mode)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, | ||
| &b->scale_inv.dptr, sizeof(void*))); | ||
| } | ||
| if (is_fp8_dtype(d->dtype())) { | ||
| NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, | ||
| sizeof scale_mode)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, | ||
| &d->scale.dptr, sizeof(void*))); | ||
| if (d->amax.dptr) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, | ||
| &d->amax.dptr, sizeof(void*))); | ||
| } | ||
|
|
@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo | |
| // Might be set to ALLREDUCE before, need to OR with the new flags to set. | ||
| cublasMpMatmulEpilogue_t epilogue{}; | ||
| size_t size_read{}; | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, | ||
| sizeof epilogue, &size_read)); | ||
| NVTE_CHECK(size_read == sizeof epilogue); | ||
|
|
@@ -339,55 +338,55 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo | |
| pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad}); | ||
| it != flags_to_epilogue.end()) { | ||
| epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, | ||
| sizeof epilogue)); | ||
| } | ||
|
|
||
| if (bias && bias->data.dptr) { | ||
| cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, | ||
| sizeof bias_type)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, | ||
| sizeof bias->data.dptr)); | ||
| } | ||
|
|
||
| if (pre_act_out && pre_act_out->data.dptr) { | ||
| cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, | ||
| &aux_type, sizeof aux_type)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, | ||
| &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, | ||
| sizeof ldd)); | ||
| if (is_fp8_dtype(pre_act_out->dtype())) { | ||
| NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, | ||
| &scale_mode, sizeof scale_mode)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, | ||
| &pre_act_out->scale.dptr, sizeof(void*))); | ||
| if (pre_act_out->amax.dptr) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, | ||
| &pre_act_out->amax.dptr, sizeof(void*))); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if (comm_sm_count) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( | ||
| NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( | ||
| ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, | ||
| &comm_sm_count, sizeof comm_sm_count)); | ||
| } | ||
|
|
||
| NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream)); | ||
|
|
||
| size_t wrksp_size_device{}; | ||
| size_t wrksp_size_host{}; | ||
|
|
@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo | |
|
|
||
| std::vector<uint8_t> workspace_host(wrksp_size_host); | ||
| if (ctx->workspace_size < wrksp_size_device) { | ||
| nvshmem_free(ctx->workspace); | ||
| ctx->workspace = nvshmem_malloc(wrksp_size_device); | ||
| if (ctx->workspace) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); | ||
| } | ||
| NVTE_CHECK_CUBLASMP( | ||
| cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); | ||
| NVTE_CHECK_CUBLASMP( | ||
| cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); | ||
| ctx->workspace_size = wrksp_size_device; | ||
| } | ||
|
|
||
|
|
@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank | |
|
|
||
| void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { | ||
| NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); | ||
| nvshmemx_sync_all_on_stream(ctx->stream.get()); | ||
| if (ctx->workspace) { | ||
| NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); | ||
| NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); | ||
| } | ||
| delete ctx; | ||
|
Comment on lines
479
to
485
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unsafe workspace free
This can become a use-after-free if a caller destroys the ctx shortly after launching a comm GEMM. Either synchronize the relevant stream(s) before freeing, or explicitly document (in the API reference) that callers must synchronize before destroying the context. |
||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also explicit nvshmem usage in transformer_engine/common/nvshmem_api - I assume this removal would result in failure of that functionality?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so, Nvshmem discovery / CMake knobs for cuBLASMp (which this PR removes) and for nvshmem_api is independent.