Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ option(ENABLE_GOOGLEBENCH "Enable GOOGLE-benchmark usage" OFF)
option(ENABLE_RAPIDJSON "Enable rapid-json usage" OFF)
option(ENABLE_CNPY "Enable cnpy usage" OFF)
option(ENABLE_CUSOLVERMP "Enable cusolvermp" OFF)
option(ENABLE_NCCL_PARALLEL_DEVICE "Enable NCCL-backed collectives in parallel_device" OFF)

if(NOT DEFINED NVHPC_ROOT_DIR AND DEFINED ENV{NVHPC_ROOT})
set(NVHPC_ROOT_DIR
Expand Down Expand Up @@ -451,6 +452,15 @@ if(USE_CUDA)
if (USE_OPENMP AND OpenMP_CXX_FOUND)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE)
endif()
# NCCL-backed collectives for parallel_device: NCCL bootstraps its ranks over
# an MPI communicator, so this option is only meaningful with ENABLE_MPI=ON.
if (ENABLE_NCCL_PARALLEL_DEVICE)
if (NOT ENABLE_MPI)
# Fail early at configure time instead of producing link/runtime errors.
message(FATAL_ERROR
"ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.")
endif()
# Compile-time switch consumed by device_check.h / parallel_device sources.
add_compile_definitions(__NCCL_PARALLEL_DEVICE)
# NCCL discovery/linking lives in a dedicated module (mirrors cuSOLVERMp setup).
include(cmake/SetupNccl.cmake)
abacus_setup_nccl(${ABACUS_BIN_NAME})
endif()
if (ENABLE_CUSOLVERMP)
# Keep cuSOLVERMp discovery/linking logic in a dedicated module.
include(cmake/SetupCuSolverMp.cmake)
Expand Down
49 changes: 49 additions & 0 deletions cmake/SetupNccl.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
include_guard(GLOBAL)

include(CheckIncludeFileCXX)

# abacus_setup_nccl(<target>)
#
# Locate NCCL and attach it to <target>.
# Search order: NCCL_PATH, then NVHPC_ROOT_DIR (NVHPC SDKs ship NCCL under
# comm_libs/nccl), then CUDAToolkit_ROOT for the header. When the library
# cannot be found explicitly, falls back to linking a bare "nccl" and relies
# on the toolchain's default linker search paths (common with NVHPC drivers);
# configuration only aborts when the header is unreachable as well.
function(abacus_setup_nccl target_name)
  find_library(NCCL_LIBRARY NAMES nccl
    HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR}
    PATH_SUFFIXES lib lib64 comm_libs/nccl/lib)
  find_path(NCCL_INCLUDE_DIR NAMES nccl.h
    HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR}
    PATHS ${CUDAToolkit_ROOT}
    PATH_SUFFIXES include comm_libs/nccl/include)
  # Cache entries are implementation detail; hide them from the basic GUI view.
  mark_as_advanced(NCCL_LIBRARY NCCL_INCLUDE_DIR)

  # Also probe the compiler's default include paths: NVHPC toolchains may
  # resolve <nccl.h> with no explicit -I flag.
  check_include_file_cxx("nccl.h" HAVE_NCCL_HEADER)

  if(NOT NCCL_LIBRARY)
    # Shadow the cached NOTFOUND with a bare library name so the linker's
    # default search paths resolve -lnccl.
    set(NCCL_LIBRARY nccl)
  endif()

  if(NOT NCCL_INCLUDE_DIR AND NOT HAVE_NCCL_HEADER)
    message(FATAL_ERROR
      "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.")
  endif()

  if(NCCL_INCLUDE_DIR)
    message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}")
  else()
    message(STATUS "Using default compiler/linker search paths for NCCL: ${NCCL_LIBRARY}")
  endif()

  # Wrap the result in an imported target so consumers get the link line and
  # (when known) the header path transitively.
  if(NOT TARGET NCCL::NCCL)
    # Documented keyword order is `<type> IMPORTED`, i.e. INTERFACE IMPORTED.
    add_library(NCCL::NCCL INTERFACE IMPORTED)
    set_target_properties(NCCL::NCCL PROPERTIES
      INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}")
    if(NCCL_INCLUDE_DIR)
      set_target_properties(NCCL::NCCL PROPERTIES
        INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}")
    endif()
  endif()

  if(NCCL_INCLUDE_DIR)
    # `parallel_device.cpp` is compiled inside the later `base` OBJECT library,
    # so the header path must also be visible to targets created in subdirs.
    include_directories(${NCCL_INCLUDE_DIR})
    target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR})
  endif()
  # NOTE(review): keyword-less signature kept deliberately — the rest of the
  # project appears to link this target with the plain signature, and CMake
  # errors out when plain and PRIVATE/PUBLIC forms are mixed on one target.
  # Confirm before converting to PRIVATE.
  target_link_libraries(${target_name} NCCL::NCCL)
endfunction()
31 changes: 17 additions & 14 deletions source/source_base/module_device/device_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,23 @@ static const char* _cufftGetErrorString(cufftResult_t error)
#define CHECK_CUDA_SYNC() do {} while (0)
#endif

// NCCL check macro: shared by cuSOLVER MP (non-CAL path) and parallel device
#if (defined(__CUSOLVERMP) && !defined(__USE_CAL)) || defined(__NCCL_PARALLEL_DEVICE)
#include <nccl.h>

// CHECK_NCCL(func): evaluate the NCCL API call exactly once; if it returns
// anything other than ncclSuccess, print the file, line, NCCL error string
// and numeric code to stderr, then terminate with EXIT_FAILURE.
// The do { } while (0) wrapper makes the macro expand to a single statement,
// so it stays safe inside unbraced if/else bodies.
#define CHECK_NCCL(func) \
do \
{ \
ncclResult_t status = (func); \
if (status != ncclSuccess) \
{ \
fprintf(stderr, "In File %s : NCCL API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
ncclGetErrorString(status), status); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif

// cuSOLVER MP support
#ifdef __CUSOLVERMP
#include <cusolverMp.h>
Expand Down Expand Up @@ -262,20 +279,6 @@ static const char* _calGetErrorString(calError_t error)
exit(EXIT_FAILURE); \
} \
} while (0)
#else // !__USE_CAL (use NCCL)
#include <nccl.h>

#define CHECK_NCCL(func) \
do \
{ \
ncclResult_t status = (func); \
if (status != ncclSuccess) \
{ \
fprintf(stderr, "In File %s : NCCL API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \
ncclGetErrorString(status), status); \
exit(EXIT_FAILURE); \
} \
} while (0)
#endif // __USE_CAL

#endif // __CUSOLVERMP
Expand Down
43 changes: 16 additions & 27 deletions source/source_base/para_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void PGemmCN<T, Device>::set_dimension(
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
{
resmem_dev_op()(C_local_tmp_, size_C_local);
#ifndef __CUDA_MPI
#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE)
C_global_tmp_.resize(size_C_global);
#endif
}
Expand Down Expand Up @@ -277,38 +277,27 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con

if (this->gatherC)
{
#ifdef __CUDA_MPI
T* Clocal_mpi = C_local;
T* Cglobal_mpi = C;
#else
T* Clocal_mpi = C_tmp_.data();
T* Cglobal_mpi = nullptr;
T* reduce_tmp = nullptr;
T* gather_tmp = nullptr;
#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE)
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
{
syncmem_d2h_op()(Clocal_mpi, C_local, size_C_local);
Cglobal_mpi = C_global_tmp_.data();
}
else
{
Cglobal_mpi = C;
reduce_tmp = C_tmp_.data();
gather_tmp = C_global_tmp_.data();
}
#endif
if (this->row_nproc > 1)
{
Parallel_Common::reduce_data(Clocal_mpi, size_C_local, row_world);
Parallel_Common::reduce_dev<T, Device>(C_local, size_C_local, row_world, reduce_tmp);
}
Parallel_Common::gatherv_data(Clocal_mpi,
size_C_local,
Cglobal_mpi,
recv_counts.data(),
displs.data(),
col_world);
#ifndef __CUDA_MPI
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
{
syncmem_h2d_op()(C, Cglobal_mpi, size_C_global);
}
#endif
Parallel_Common::gatherv_dev<T, Device>(C_local,
size_C_local,
C,
recv_counts.data(),
displs.data(),
col_world,
reduce_tmp,
gather_tmp);
}
else
{
Expand Down Expand Up @@ -409,4 +398,4 @@ template class PGemmCN<std::complex<double>, base_device::DEVICE_GPU>;
template class PGemmCN<std::complex<float>, base_device::DEVICE_GPU>;
#endif

} // namespace ModuleBase
} // namespace ModuleBase
Loading
Loading