Commit 4bb9f30

enable all_reduce_fusion kernel + cudagraph function (#6)
1 parent: 062990f

1 file changed: tests/comm/test_trtllm_allreduce_fusion_paddle.py (23 additions & 7 deletions)
@@ -1,10 +1,10 @@
 import socket
 import pytest
-
 import paddle
 import paddle.distributed as dist_pp
 
 paddle.enable_compat()
+from paddle.device.cuda.graphs import CUDAGraph
 import flashinfer.comm as comm
 
 import os
@@ -36,15 +36,12 @@ def kernel(workspace_tensor, rank, world_size):
     residual_out = paddle.zeros(message_size, dtype=dtype, device=device)
     norm_out = paddle.zeros(message_size, dtype=dtype, device=device)
     quant_out = paddle.zeros(message_size, dtype=dtype, device=device)
-    scale_out = paddle.zeros(
-        message_size // 16, dtype=dtype, device=device
-    )  # SF_VEC_SIZE = 16
+    scale_out = paddle.zeros(message_size // 16, dtype=dtype, device=device)
     rms_gamma = paddle.randn(hidden_dim, dtype=dtype, device=device)
     rms_eps = 1e-3
     scale_factor = paddle.tensor(0.5, dtype=paddle.float32, device=device)
 
     # Run fusion operation
-    print("Running fusion operation...")
     comm.trtllm_allreduce_fusion(
         allreduce_in=allreduce_in,
         world_size=world_size,
@@ -69,7 +66,7 @@ def kernel(workspace_tensor, rank, world_size):
         layout_code=layout_code,
     )
 
-    paddle.cuda.synchronize()
+    # paddle.cuda.synchronize()
 
     return allreduce_in, all_reduce_out
 
@@ -91,6 +88,7 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
     os.environ["PADDLE_CURRENT_ENDPOINT"] = (
         f"127.0.0.1:{distributed_init_port + rank + 10}"
     )
+
     # Set NCCL related environment variables (optional but recommended)
     os.environ["FLAGS_SYNC_NCCL_ALLREDUCE"] = "1"
 
@@ -117,7 +115,25 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
     dist_pp.barrier(group=group_pp)
 
     # Run fusion operation
-    allreduce_in_clone, all_reduce_out = kernel(workspace_tensor, rank, world_size)
+    loop = 5
+    s = paddle.cuda.Stream()
+    s.wait_stream(paddle.cuda.current_stream())
+    with paddle.cuda.stream(s):
+        for _ in range(loop):
+            allreduce_in_clone, all_reduce_out = kernel(
+                workspace_tensor, rank, world_size
+            )
+
+    g = CUDAGraph()
+    g.capture_begin()
+    for _ in range(loop):
+        allreduce_in_clone, all_reduce_out = kernel(
+            workspace_tensor, rank, world_size
+        )
+    g.capture_end()
+
+    g.replay()
+    paddle.cuda.synchronize()
 
     # # Calculate reference result
     dist_pp.all_reduce(allreduce_in_clone, group=group_pp)
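
The change follows the standard CUDA-graph workflow: warm up the kernel on a side stream so one-time setup (memory pools, communicator initialization) runs outside capture, record the launches between capture_begin() and capture_end(), then replay() the recorded graph. Below is a minimal single-GPU sketch of that pattern, not the test itself: work() is a hypothetical stand-in for kernel(workspace_tensor, rank, world_size), and the paddle.cuda.* stream aliases are assumed to come from paddle.enable_compat(), as they do in this file.

import paddle
paddle.enable_compat()  # enables the torch-style paddle.cuda.* aliases used below
from paddle.device.cuda.graphs import CUDAGraph

def work(x):
    # hypothetical stand-in for the fused allreduce kernel; any capturable GPU op works
    return x * 2.0 + 1.0

x = paddle.randn([1024], dtype=paddle.float32)

# 1) Warm up on a side stream so lazy initialization is not recorded into the graph.
s = paddle.cuda.Stream()
s.wait_stream(paddle.cuda.current_stream())
with paddle.cuda.stream(s):
    for _ in range(5):
        y = work(x)
paddle.cuda.current_stream().wait_stream(s)

# 2) Capture: kernels launched between begin/end are recorded, not executed.
g = CUDAGraph()
g.capture_begin()
y = work(x)
g.capture_end()

# 3) Replay the whole recorded sequence in a single launch, then synchronize.
g.replay()
paddle.cuda.synchronize()

This also explains why the diff comments out the paddle.cuda.synchronize() inside kernel(): device-wide synchronization is not permitted while a stream is being captured and would invalidate the capture, so the sync moves to after g.replay().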
