
Commit d6f3030

Authored by JohannesGaessler, IMbackK, gaugarg-nv, ggerganov
ggml: backend-agnostic tensor parallelism (experimental) (ggml-org#19378)
* ggml: backend-agnostic tensor parallelism
* support for GPT-OSS, Qwen 3 MoE
* partial Vulkan fix
* add support for 4/8 GPUs
* unconditional peer access
* re-use buffers + ggml contexts
* fix output pattern
* NCCL support
* GGML: HIP: add RCCL support
* Remove shfl and AllReduce from backend interface
* move allocation workaround out of ggml-alloc.c
* 2d tensor set/get support
* Fix the seg fault without NCCL
* Apply suggestion from JohannesGaessler
* support for tensor dims % n_devs != 0
* fix view_offs scaling
* arbitrary num. of GPUs/tensor split
* fix compilation
* better granularity estimate
* Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA. Fix compilation errors.
* partial Qwen 3 Next support
* Fix qwen3 30b (ggml-org#8)
* Fix crash with Qwen-30B-A3B Q4_0: Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation.
* Decide block size based on tensor quantization type
* Fix crashes due to KV cache serialization (ggml-org#9): KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.
* metal : fix build (ggml-org#7)
* static memory allocations, fix usage count
* fix tensor granularity
* more even memory distribution
* use BF16 for allreduce
* rebase fixup
* better error message for unsupported architectures
* Fix device mismatch during scatter of allReduce (ggml-org#11): there is a mismatch between the dst buffer device and the backend device, causing the use of sync copies.
* Enable the previous allreduce implementation. It is better in both perf and stability (ggml-org#12)
* delay AllReduce for MoE for less I/O
* build : clean-up compile warnings
* backend : move most of the meta backend API to ggml-backend-impl.h
* cont : hide unused public API in the implementation
* llama : use llama_device + remove ggml_backend_dev_is_meta()
* ggml-backend : remove unused alloc include
* minor : remove regex include
* ggml : introduce ggml-ext.h for staging new APIs
* rebase fixup
* fix tests
* llama : more robust logic for determining Meta devices (ggml-org#16)
* cont : fix devs size check
* cont : fix log type
* disable roundtrip for meta backend
* fix arch selection
* Qwen 3.5 support
* fix Gemma 4 MoE
* fix OpenVINO, SYCL
* fix test-llama-archs for CPU-only builds
* Fix Qwen 3.5 MoE
* disable meta backend tests for WebGPU
* tests : filter CPU-based devices from the Meta backend tests (ggml-org#17)
* meta : formatting, naming, indentation (ggml-org#18)
* formatting : llama-model.cpp
* formatting : ggml-ext.h
* formatting : ggml-backend-meta.cpp
* meta : add TODO
* add documentation
* better error messages
* fix GPT-OSS

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 009a113 commit d6f3030


48 files changed: +3197 −341 lines

common/arg.cpp

Lines changed: 9 additions & 7 deletions
@@ -2348,19 +2348,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
     add_opt(common_arg(
-        {"-sm", "--split-mode"}, "{none,layer,row}",
+        {"-sm", "--split-mode"}, "{none,layer,row,tensor}",
         "how to split the model across multiple GPUs, one of:\n"
         "- none: use one GPU only\n"
-        "- layer (default): split layers and KV across GPUs\n"
-        "- row: split rows across GPUs",
+        "- layer (default): split layers and KV across GPUs (pipelined)\n"
+        "- row: split weight across GPUs by rows (parallelized)\n"
+        "- tensor: split weights and KV across GPUs (parallelized)",
         [](common_params & params, const std::string & value) {
-            std::string arg_next = value;
-            if (arg_next == "none") {
+            if (value == "none") {
                 params.split_mode = LLAMA_SPLIT_MODE_NONE;
-            } else if (arg_next == "layer") {
+            } else if (value == "layer") {
                 params.split_mode = LLAMA_SPLIT_MODE_LAYER;
-            } else if (arg_next == "row") {
+            } else if (value == "row") {
                 params.split_mode = LLAMA_SPLIT_MODE_ROW;
+            } else if (value == "tensor") {
+                params.split_mode = LLAMA_SPLIT_MODE_TENSOR;
             } else {
                 throw std::invalid_argument("invalid value");
             }
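For reference, the new mode is selected the same way as the existing ones via -sm/--split-mode; a hypothetical invocation (the model path and -ngl value are placeholders):

    llama-cli -m model.gguf -ngl 99 -sm tensor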

ggml/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9)
 set(GGML_VERSION_PATCH 11)
 set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
 
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+
 find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
 if(GIT_EXE)
     # Get current git commit hash
@@ -204,12 +206,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
 option(GGML_CUDA_FA             "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS  "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS         "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
+option(GGML_CUDA_NCCL           "ggml: use NVIDIA Collective Comm. Library"       ON)
 set   (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
                                 "ggml: cuda link binary compression mode; requires cuda 12.8+")
 set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")
 
 option(GGML_HIP                 "ggml: use HIP"                                   OFF)
 option(GGML_HIP_GRAPHS          "ggml: use HIP graph, experimental, slow"         OFF)
+option(GGML_HIP_RCCL            "ggml: use ROCm Collective Comm. Library"         OFF)
 option(GGML_HIP_NO_VMM          "ggml: do not try to use HIP VMM"                 ON)
 option(GGML_HIP_ROCWMMA_FATTN   "ggml: enable rocWMMA for FlashAttention"         OFF)
 option(GGML_HIP_MMQ_MFMA        "ggml: enable MFMA MMA for CDNA in MMQ"           ON)
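These options only take effect together with the matching GPU backend; plausible configure lines, given the defaults above (NCCL ON, RCCL OFF):

    # CUDA build with NCCL-based collectives
    cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_NCCL=ON
    # HIP build with RCCL-based collectives (off by default)
    cmake -B build -DGGML_HIP=ON -DGGML_HIP_RCCL=ON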

ggml/cmake/FindNCCL.cmake

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# cmake/FindNCCL.cmake
+
+# NVIDIA does not distribute CMake files with NCCL, therefore use this file to find it instead.
+
+find_path(NCCL_INCLUDE_DIR
+    NAMES nccl.h
+    HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
+    PATH_SUFFIXES include
+)
+
+find_library(NCCL_LIBRARY
+    NAMES nccl
+    HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
+    PATH_SUFFIXES lib lib64
+)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL
+    DEFAULT_MSG
+    NCCL_LIBRARY NCCL_INCLUDE_DIR
+)
+
+if(NCCL_FOUND)
+    set(NCCL_LIBRARIES ${NCCL_LIBRARY})
+    set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR})
+
+    if(NOT TARGET NCCL::NCCL)
+        add_library(NCCL::NCCL UNKNOWN IMPORTED)
+        set_target_properties(NCCL::NCCL PROPERTIES
+            IMPORTED_LOCATION "${NCCL_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}"
+        )
+    endif()
+endif()
+
+mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
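Since ggml/CMakeLists.txt appends this cmake/ directory to CMAKE_MODULE_PATH (see above), the module is consumed through the standard find-module pattern, and the search can be steered via NCCL_ROOT or CUDA_HOME per the HINTS. A minimal consumer sketch; the ggml-cuda target name and the GGML_USE_NCCL define are illustrative assumptions, not taken from this commit:

    find_package(NCCL)  # dispatches to cmake/FindNCCL.cmake
    if(NCCL_FOUND)
        target_compile_definitions(ggml-cuda PRIVATE GGML_USE_NCCL)  # hypothetical define
        target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL)
    endif()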

ggml/include/ggml-backend.h

Lines changed: 17 additions & 9 deletions
@@ -68,7 +68,7 @@ extern "C" {
     GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
 
     // tensor copy between different backends
-    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst);
 
     //
     // Backend (stream)
@@ -83,13 +83,17 @@ extern "C" {
     GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
     GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);
 
-    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_async   (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async   (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
 
     // "offset" refers to the offset in tensor->data for setting/getting data
-    GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set   (      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get   (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_2d(      struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_memset(      struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
 
     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
 
@@ -109,7 +113,7 @@ extern "C" {
     // the copy is performed after all the currently queued operations in backend_src
     // backend_dst will wait for the copy to complete before performing other operations
     // automatic fallback to sync copy if async is not supported
-    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
 
     GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
 
@@ -135,7 +139,9 @@ extern "C" {
         // integrated GPU device using host memory
         GGML_BACKEND_DEVICE_TYPE_IGPU,
         // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
-        GGML_BACKEND_DEVICE_TYPE_ACCEL
+        GGML_BACKEND_DEVICE_TYPE_ACCEL,
+        // "meta" device wrapping multiple other devices for tensor parallelism
+        GGML_BACKEND_DEVICE_TYPE_META,
     };
 
     // functionality supported by the device
@@ -196,7 +202,9 @@ extern "C" {
 
     // Common functions that may be obtained using ggml_backend_reg_get_proc_address
 
-    // Split buffer type for tensor parallelism
+    // AllReduce operation for tensor parallelism (meta backend)
+    typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
+    // Split buffer type for tensor parallelism (old)
    typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
     // Set the number of threads for the backend
     typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
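The new _2d variants batch n_copies strided transfers into a single call: copy i moves `size` bytes from `data + i*stride_data` to tensor offset `offset + i*stride_tensor`. A minimal sketch of that equivalence in terms of the plain setter (the fallback function itself is an assumption for illustration, not code from this commit):

    // Sketch: emulate ggml_backend_tensor_set_2d with n_copies 1D copies.
    static void tensor_set_2d_fallback(struct ggml_tensor * tensor, const void * data,
                                       size_t offset, size_t size, size_t n_copies,
                                       size_t stride_tensor, size_t stride_data) {
        for (size_t i = 0; i < n_copies; i++) {
            // copy i: `size` bytes at tensor offset `offset + i*stride_tensor`
            ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data,
                                    offset + i*stride_tensor, size);
        }
    }

Batching the loop into one backend call lets an implementation presumably use a single strided transfer (a cudaMemcpy2DAsync-style copy) instead of n_copies separate ones.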

ggml/include/ggml-cuda.h

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
 // device buffer
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
 
+// conduct allreduce operation between devices
+GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
+
 // split tensor buffer that splits matrices by rows across multiple devices
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
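This entry point matches the ggml_backend_allreduce_tensor_t shape declared in ggml-backend.h above, presumably obtained by the meta backend via ggml_backend_reg_get_proc_address. A hypothetical call site, assuming two CUDA backends that each hold a partial result for the same logical tensor; all variable names are placeholders:

    // Hypothetical: per-device partial sums to be reduced in place.
    ggml_backend_t       backends[2] = { backend_gpu0, backend_gpu1 };
    struct ggml_tensor * tensors[2]  = { partial_gpu0, partial_gpu1 };

    // On success each tensors[i] should hold the elementwise sum across devices;
    // a false return is read here as "operation not supported" (assumption).
    if (!ggml_backend_cuda_allreduce_tensor(backends, tensors, 2)) {
        // fall back, e.g. gather to one device and broadcast (assumption)
    }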

ggml/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -200,6 +200,7 @@ add_library(ggml-base
             ggml.cpp
             ggml-alloc.c
             ggml-backend.cpp
+            ggml-backend-meta.cpp
             ggml-opt.cpp
             ggml-threading.cpp
             ggml-threading.h

ggml/src/ggml-alloc.c

Lines changed: 3 additions & 0 deletions
@@ -1236,6 +1236,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx,
 
 ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
     size_t nbytes_total = 0;
+    if (ggml_backend_buft_is_meta(buft)) {
+        return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft);
+    }
     return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false);
 }
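Caller code does not change: tensors created in a no_alloc context are still placed with the same call, and the meta path is taken transparently when buft wraps multiple devices. A minimal sketch using public ggml APIs (where buft comes from is left open; it is assumed to be the buffer type of the selected device):

    // Create tensor metadata only (no_alloc), then allocate storage in one call.
    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor  * w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 4096);
    (void) w;

    // Dispatches to ggml_backend_meta_alloc_ctx_tensors_from_buft() for meta buffer types.
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);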

ggml/src/ggml-backend-impl.h

Lines changed: 22 additions & 2 deletions
@@ -49,6 +49,10 @@ extern "C" {
         void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
         void (*set_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        // (optional) 2d data copies
+        void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+        void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+
         // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
         bool (*cpy_tensor)   (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
         // clear the entire buffer
@@ -80,6 +84,20 @@ extern "C" {
     GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
     GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
 
+    //
+    // Backend (meta)
+    //
+
+    GGML_API bool ggml_backend_is_meta       (ggml_backend_t backend);
+    GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf);
+    GGML_API bool ggml_backend_buft_is_meta  (ggml_backend_buffer_type_t buft);
+
+    GGML_API size_t         ggml_backend_meta_n_backends    (ggml_backend_t meta_backend);
+    GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index);
+
+    // temporary workaround to statically allocate tensors from a context in a deduplicated way:
+    GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+
     //
     // Backend (stream)
     //
@@ -90,8 +108,10 @@ extern "C" {
         void (*free)(ggml_backend_t backend);
 
         // (optional) asynchronous tensor data access
-        void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        void (*set_tensor_async)   (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async)   (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+        void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
         bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
 
         // (optional) complete all pending operations (required if the backend supports async operations)
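The accessors above expose the per-device ("simple") backends wrapped by a meta backend; a short sketch of how backend-internal code might iterate them, here to synchronize every device (illustrative only, not code from the commit):

    // Iterate the simple backends wrapped by a meta backend.
    if (ggml_backend_is_meta(backend)) {
        const size_t n = ggml_backend_meta_n_backends(backend);
        for (size_t i = 0; i < n; i++) {
            ggml_backend_t simple = ggml_backend_meta_simple_backend(backend, i);
            ggml_backend_synchronize(simple); // wait for this device's queued work
        }
    }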
