diff --git a/setup.py b/setup.py
index 18bb736f24..3a66e624e3 100644
--- a/setup.py
+++ b/setup.py
@@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension:
             f"nvidia-cublasmp-cu{cuda_version()[0]}"
         ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
         cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
-        nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
-            f"nvidia-nvshmem-cu{cuda_version()[0]}"
-        ).locate_file("nvidia/nvshmem")
-        cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
-        print("CMAKE_FLAGS:", cmake_flags[-2:])

     # Add custom CMake arguments from environment variable
     nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index a83cbe3e30..17cdd0ec74 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -283,20 +283,18 @@ endif()
 option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
 if (NVTE_WITH_CUBLASMP)
   target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
-  target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
+  target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
   find_library(CUBLASMP_LIB
                NAMES cublasmp libcublasmp
                PATHS ${CUBLASMP_DIR}
                PATH_SUFFIXES lib
                REQUIRED)
-  find_library(NVSHMEM_HOST_LIB
-               NAMES nvshmem_host libnvshmem_host.so.3
-               PATHS ${NVSHMEM_DIR}
+  find_library(NCCL_LIB
+               NAMES nccl libnccl
                PATH_SUFFIXES lib
                REQUIRED)
-  target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
+  target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
   message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
-  message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
 endif()

 # Hack to enable dynamic loading in cuDNN frontend
diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp
index 66a3da55dd..7be3d1bb4d 100644
--- a/transformer_engine/common/comm_gemm/comm_gemm.cpp
+++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp
@@ -8,7 +8,6 @@
 #include
 #include

-#include

 #include
 #include
@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
                                       ctx->grid_row_major.get(), ctx->d_desc.get()));
   const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
       ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
       sizeof epilogue));
 }
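Note: the `cublasMpMatmulDescriptorAttributeSet` → `cublasMpMatmulDescriptorSetAttribute` rename is purely mechanical; every call site in this file passes the same four arguments. A minimal sketch of the call shape, assuming the cuBLASMp headers are on the include path and that the descriptor handle type is `cublasMpMatmulDescriptor_t`; the `set_transpose_a` helper and the include path are illustrative, not part of this PR:

```cpp
#include <cublasmp.h>

#include "util/logging.h"  // illustrative path; provides NVTE_CHECK_CUBLASMP

// Hypothetical helper showing the setter pattern used throughout this diff:
// descriptor handle, attribute enum, pointer to the value, and its size.
static void set_transpose_a(cublasMpMatmulDescriptor_t desc, bool transa) {
  const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      desc, CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a));
}
```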
@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
   const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
   const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
       ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
       sizeof trans_a));
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
       ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
       sizeof trans_b));
   cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
       ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
       sizeof algo_attr));

   const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
   if (is_fp8_dtype(a->dtype())) {
     NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
         sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
         &a->scale_inv.dptr, sizeof(void*)));
   }
   if (is_fp8_dtype(b->dtype())) {
     NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
         sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
         &b->scale_inv.dptr, sizeof(void*)));
   }
   if (is_fp8_dtype(d->dtype())) {
     NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
         sizeof scale_mode));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
         &d->scale.dptr, sizeof(void*)));
     if (d->amax.dptr) {
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
           ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
           &d->amax.dptr, sizeof(void*)));
     }
@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
   // Might be set to ALLREDUCE before, need to OR with the new flags to set.
   cublasMpMatmulEpilogue_t epilogue{};
   size_t size_read{};
-  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
+  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
       ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
       sizeof epilogue, &size_read));
   NVTE_CHECK(size_read == sizeof epilogue);
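The attribute read above exists because `CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE` may already hold `ALLREDUCE` (set in `GemmArInitMatrices`), so the hunk that follows ORs new flags into it rather than overwriting. A minimal sketch of that read-modify-write pattern, under the same assumptions as the previous sketch (`add_epilogue_flag` is hypothetical):

```cpp
// Hypothetical helper: fetch the current epilogue, OR in a new flag, write it
// back. Mirrors the Get/Set pair used in cublasmp_gemm.
static void add_epilogue_flag(cublasMpMatmulDescriptor_t desc, cublasMpMatmulEpilogue_t flag) {
  cublasMpMatmulEpilogue_t epilogue{};
  size_t size_read{};
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
      desc, CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue,
      &size_read));
  NVTE_CHECK(size_read == sizeof epilogue);  // the full attribute must have been read
  epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | flag);
  NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
      desc, CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue));
}
```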
@@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
                                          pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
       it != flags_to_epilogue.end()) {
     epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
         sizeof epilogue));
   }
   if (bias && bias->data.dptr) {
     cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
         sizeof bias_type));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER,
         &bias->data.dptr, sizeof bias->data.dptr));
   }
   if (pre_act_out && pre_act_out->data.dptr) {
     cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
         &aux_type, sizeof aux_type));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
         &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
         sizeof ldd));
     if (is_fp8_dtype(pre_act_out->dtype())) {
       NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
           ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
           &scale_mode, sizeof scale_mode));
-      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+      NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
          ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
          &pre_act_out->scale.dptr, sizeof(void*)));
       if (pre_act_out->amax.dptr) {
-        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+        NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
            ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
            &pre_act_out->amax.dptr, sizeof(void*)));
       }
@@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
   }

   if (comm_sm_count) {
-    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
+    NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
         ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
         &comm_sm_count, sizeof comm_sm_count));
   }

-  NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
+  NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));

   size_t wrksp_size_device{};
   size_t wrksp_size_host{};
@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
   std::vector workspace_host(wrksp_size_host);

   if (ctx->workspace_size < wrksp_size_device) {
-    nvshmem_free(ctx->workspace);
-    ctx->workspace = nvshmem_malloc(wrksp_size_device);
+    if (ctx->workspace) {
+      NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+      NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+    }
+    NVTE_CHECK_CUBLASMP(
+        cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
+    NVTE_CHECK_CUBLASMP(
+        cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
     ctx->workspace_size = wrksp_size_device;
   }
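The workspace hunk above is the core of this PR: NVSHMEM's symmetric-heap allocation (`nvshmem_malloc`/`nvshmem_free`) is replaced by cuBLASMp's own allocator, with the buffer allocated on the column-major grid, registered with the row-major grid, and torn down in reverse order. A condensed, hypothetical restatement of that grow-only resize logic, using only calls that appear in the diff; member names mirror `NVTECommGemmCtx`, whose definition is not part of this patch:

```cpp
// Hypothetical condensed form of the resize logic above.
static void ensure_workspace(NVTECommGemmCtx* ctx, size_t wrksp_size_device) {
  if (ctx->workspace_size >= wrksp_size_device) return;  // grow-only: keep larger buffers
  if (ctx->workspace) {
    // Tear down in reverse order of creation: deregister, then free.
    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
  }
  // Allocate on one grid, then register with the other, as in the hunk above,
  // presumably so the same buffer is usable from both grids.
  NVTE_CHECK_CUBLASMP(
      cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
  NVTE_CHECK_CUBLASMP(
      cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
  ctx->workspace_size = wrksp_size_device;
}
```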
@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank

 void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
   NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
-  nvshmemx_sync_all_on_stream(ctx->stream.get());
+  if (ctx->workspace) {
+    NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
+    NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
+  }
   delete ctx;
 }

diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h
index 06b56789a3..65d3aa5d9e 100644
--- a/transformer_engine/common/include/transformer_engine/comm_gemm.h
+++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h
@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
 /*! \brief Destroy a comm-gemm context.
  *
  * \param[in] ctx Context to destroy.
+ *
+ * It's the caller's responsibility to synchronize all streams involved before calling this function.
  */
 void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);

diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h
index c542afa393..8031e342e2 100644
--- a/transformer_engine/common/util/logging.h
+++ b/transformer_engine/common/util/logging.h
@@ -96,12 +96,12 @@

 #ifdef NVTE_WITH_CUBLASMP

-#define NVTE_CHECK_CUBLASMP(expr)                                \
-  do {                                                           \
-    const cublasMpStatus_t status = (expr);                      \
-    if (status != CUBLASMP_STATUS_SUCCESS) {                     \
-      NVTE_ERROR("cuBLASMp Error: ", std::to_string(status));    \
-    }                                                            \
+#define NVTE_CHECK_CUBLASMP(expr)                                        \
+  do {                                                                   \
+    const cublasMpStatus_t status = (expr);                              \
+    if (status != CUBLASMP_STATUS_SUCCESS) {                             \
+      NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status));   \
+    }                                                                    \
   } while (false)

 #endif  // NVTE_WITH_CUBLASMP
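With `nvshmemx_sync_all_on_stream` removed from `nvte_comm_gemm_ctx_destroy`, synchronization becomes the caller's job, as the new header comment states. A minimal caller-side sketch of the new contract, assuming a single stream carried all comm-gemm work; `shutdown`, `stream`, and the use of `NVTE_CHECK_CUDA` here are illustrative:

```cpp
#include <cuda_runtime.h>

#include <transformer_engine/comm_gemm.h>

void shutdown(NVTECommGemmCtx* ctx, cudaStream_t stream) {
  // The context no longer synchronizes on destruction, so drain all streams
  // that issued comm-gemm work before destroying it.
  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
  nvte_comm_gemm_ctx_destroy(ctx);
}
```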