Commit 062990f

support all reduce fusion kernel (#5)

1 parent 5f2f442 commit 062990f

4 files changed

Lines changed: 188 additions & 11 deletions

flashinfer/comm/cuda_ipc.py

Lines changed: 12 additions & 6 deletions
@@ -17,9 +17,8 @@
 import ctypes
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
-
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup

 # NOTE(Zihao): we should use cuda-python instead of ctypes cuda runtime bindings.
 # However, cuda-python's API is not stable yet, so we use ctypes bindings instead.

@@ -205,11 +204,18 @@ def create_shared_buffer(
     handle = cudart.cudaIpcGetMemHandle(pointer)
     if group is None:
         group = dist.group.WORLD
-    world_size = dist.get_world_size(group=group)
+    # world_size = dist.get_world_size(group=group)
     rank = dist.get_rank(group=group)
-    handles = [None] * world_size
-    dist.all_gather_object(handles, handle, group=group)
-    handles = [None] * world_size
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+    # handles = [None] * world_size
+    # dist.all_gather_object(handles, handle, group=group)
+
+    # The behavior of the paddle framework and the torch framework is inconsistent,
+    # so the following code is used instead.
+    # TODO(bingoo): this has been fixed by
+    # https://github.com/PaddlePaddle/Paddle/pull/77152.
+    handles = [None]
     dist.all_gather_object(handles, handle, group=group)

     pointers: List[int] = []
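
For context on the inconsistency the comment above mentions: torch's `all_gather_object` expects the output list to be pre-sized to `world_size` and fills it in place, which is what the deleted lines did; this commit's Paddle-compat path instead passes a single-element list and lets the framework populate it. The snippet below is only a sketch of those two call patterns, assuming an already-initialized process group under a distributed launcher (the handle string is made up).

# Sketch only; assumes dist.init_process_group() (or the Paddle equivalent) has
# already run and that this executes on every rank under a distributed launcher.
import torch.distributed as dist

handle = f"ipc-handle-from-rank-{dist.get_rank()}"  # stand-in for a CUDA IPC handle

# torch convention (the deleted code above): pre-size the output list to
# world_size; all_gather_object fills it in place.
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, handle)

# Paddle-compat pattern used by this commit (kept commented out here, since it
# does not satisfy torch's pre-sized-list contract): a single-element list,
# populated by the framework, pending the upstream fix referenced in the TODO.
# handles = [None]
# dist.all_gather_object(handles, handle)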

flashinfer/comm/nvshmem_allreduce.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 from typing import Optional

 import torch
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup

 from .nvshmem import get_nvshmem_module

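The same one-line swap appears in cuda_ipc.py above and trtllm_ar.py below. As far as these diffs show, the imported name is only used in type hints on `group=` parameters, so the change is about what the annotation refers to rather than about any call. A minimal illustrative sketch (the function below is hypothetical, not part of the module):

# Illustrative only: a hypothetical function mirroring how these modules annotate
# their `group=` parameters after the swap.
from typing import Optional

import paddle.distributed as dist_pp
from paddle.base.core import ProcessGroup  # the binding this commit imports instead


def example_collective(group: Optional[ProcessGroup] = None) -> int:
    # The real functions forward `group` to the framework's collectives; the new
    # test in this commit passes the group returned by dist_pp.get_group().
    return dist_pp.get_rank(group=group)


print(example_collective())  # prints 0 when run outside a distributed launch
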
flashinfer/comm/trtllm_ar.py

Lines changed: 8 additions & 4 deletions
@@ -22,7 +22,7 @@

 import torch
 import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from paddle.base.core import ProcessGroup

 from ..jit.comm import gen_trtllm_comm_module
 from ..utils import register_custom_op, round_up

@@ -602,9 +602,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
         print(f"Rank {tp_rank} workspace[{i}] {hex(workspace[i])}")

     # Store workspace pointers in device tensor
-    workspace_tensor = torch.tensor(
-        workspace, dtype=torch.int64, device=torch.device("cuda")
-    )
+    # workspace_tensor = torch.tensor(
+    #     workspace, dtype=torch.int64, device=torch.device("cuda")
+    # )
+
+    # There is a bug in the paddle framework when device="cuda";
+    # it is currently worked around by changing the source code.
+    workspace_tensor = torch.tensor(workspace, dtype=torch.int64)

     dist.barrier(group=group)  # must sync after create_workspace
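To make the workaround above concrete: `workspace` at this point is a plain Python list of raw device-pointer addresses (the values printed with `hex(workspace[i])` above), and the tensor only packages those integers for the fusion kernel. The sketch below uses made-up pointer values, and whether a host-resident int64 tensor suffices for every code path depends on the other source changes the comment alludes to.

import torch

# Hypothetical raw CUDA pointer addresses, standing in for the IPC-mapped
# buffers set up earlier in this function; the real values come from the CUDA
# IPC workspace creation above.
workspace = [0x7F3A40000000, 0x7F3A50000000, 0x7F3A60000000]

# Original behavior: place the pointer table directly on the GPU.
# workspace_tensor = torch.tensor(workspace, dtype=torch.int64, device="cuda")

# This commit's workaround: build the tensor on the host to sidestep the Paddle
# device-string issue, and let the patched source handle device placement.
workspace_tensor = torch.tensor(workspace, dtype=torch.int64)
print(workspace_tensor.dtype, workspace_tensor.shape)  # torch.int64 torch.Size([3])
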
Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
import socket
import pytest

import paddle
import paddle.distributed as dist_pp

paddle.enable_compat()
import flashinfer.comm as comm

import os
import numpy as np

# test parameters
token_num = 128
hidden_dim = 1024
dtype = paddle.float16
pattern_code = comm.AllReduceFusionPattern.kAllReduce
layout_code = comm.QuantizationSFLayout.LINEAR
launch_with_pdl = False
use_oneshot = True
trigger_completion_at_end = True
fp32_acc = False


def kernel(workspace_tensor, rank, world_size):
    device = f"cuda:{rank}"
    message_size = token_num * hidden_dim
    dtype = paddle.float16
    # Create input data
    allreduce_in = paddle.randn(message_size, dtype=dtype, device=device)
    # allreduce_in_clone = allreduce_in.clone()
    all_reduce_out = paddle.zeros(message_size, dtype=dtype, device=device)

    # Add missing required parameters
    residual_in = paddle.randn(message_size, dtype=dtype, device=device)
    residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
    norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
    quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
    scale_out = paddle.zeros(
        message_size // 16, dtype=dtype, device=device
    )  # SF_VEC_SIZE = 16
    rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
    rms_eps = 1e-3
    scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)

    # Run fusion operation
    print("Running fusion operation...")
    comm.trtllm_allreduce_fusion(
        allreduce_in=allreduce_in,
        world_size=world_size,
        world_rank=rank,
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=workspace_tensor,
        launch_with_pdl=launch_with_pdl,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=pattern_code,
        allreduce_out=all_reduce_out,
        residual_in=residual_in,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=quant_out,
        scale_out=scale_out,
        rms_gamma=rms_gamma,
        rms_eps=rms_eps,
        scale_factor=scale_factor,
        layout_code=layout_code,
    )

    paddle.cuda.synchronize()

    return allreduce_in, all_reduce_out


def _run_simple_worker(world_size, rank, distributed_init_port):
    # Create workspace
    # paddle.compat.enable_torch_proxy()
    # Set all required environment variables
    os.environ["FLAGS_SELECTED_GPUS"] = str(rank)  # Key: set GPU ID
    os.environ["PADDLE_TRAINER_ID"] = str(rank)
    os.environ["PADDLE_TRAINERS_NUM"] = str(world_size)
    os.environ["PADDLE_RANK_IN_NODE"] = str(rank)

    # Build endpoint list
    endpoints = ",".join(
        [f"127.0.0.1:{distributed_init_port + i + 10}" for i in range(world_size)]
    )
    os.environ["PADDLE_TRAINER_ENDPOINTS"] = endpoints
    os.environ["PADDLE_CURRENT_ENDPOINT"] = (
        f"127.0.0.1:{distributed_init_port + rank + 10}"
    )
    # Set NCCL related environment variables (optional but recommended)
    os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1"

    # Set device
    paddle.set_device(f"gpu:{rank}")

    # Initialize distributed environment
    dist_pp.init_parallel_env()
    group_pp = dist_pp.get_group()

    try:
        # Create workspace
        ipc_handles, workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                token_num,
                hidden_dim,
                group=group_pp,
                use_fp32_lamport=False,
            )
        )

        dist_pp.barrier(group=group_pp)

        # Run fusion operation
        allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)

        # Calculate reference result
        dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
        ref_allreduce_out = allreduce_in_clone.clone()

        # Verify results
        tolerance = 8e-2
        np.testing.assert_allclose(
            all_reduce_out.numpy(), ref_allreduce_out.numpy(), atol=tolerance, rtol=1e-2
        )

        print(f"Rank {rank}: Test passed!")

    finally:
        dist_pp.barrier(group=group_pp)
        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group_pp)
        dist_pp.destroy_process_group(group=group_pp)


def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def test_trtllm_allreduce_fusion_simple():
    # Fixed test parameters
    world_size = 2

    paddle.manual_seed(42)
    paddle.cuda.manual_seed_all(42)

    available_gpus = paddle.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(f"Requires {world_size} GPUs, but only {available_gpus} available")

    distributed_init_port = get_open_port()
    rank = dist_pp.get_rank()
    _run_simple_worker(world_size, rank, distributed_init_port)

    print("Simple allreduce fusion test: passed")


# test cmd: python -m paddle.distributed.launch --log_dir=log --devices=0,1
# ./test_torch_pp_launch.py
if __name__ == "__main__":
    test_trtllm_allreduce_fusion_simple()
