11import socket
22import pytest
3-
43import paddle
54import paddle .distributed as dist_pp
65
76paddle .enable_compat ()
7+ from paddle .device .cuda .graphs import CUDAGraph
88import flashinfer .comm as comm
99
1010import os
@@ -36,15 +36,12 @@ def kernel(workspace_tensor, rank, world_size):
3636 residual_out = paddle .zeros (message_size , dtype = dtype , device = device )
3737 norm_out = paddle .zeros (message_size , dtype = dtype , device = device )
3838 quant_out = paddle .zeros (message_size , dtype = dtype , device = device )
39- scale_out = paddle .zeros (
40- message_size // 16 , dtype = dtype , device = device
41- ) # SF_VEC_SIZE = 16
39+ scale_out = paddle .zeros (message_size // 16 , dtype = dtype , device = device )
4240 rms_gamma = paddle .randn (hidden_dim , dtype = dtype , device = device )
4341 rms_eps = 1e-3
4442 scale_factor = paddle .tensor (0.5 , dtype = paddle .float32 , device = device )
4543
4644 # Run fusion operation
47- print ("Running fusion operation..." )
4845 comm .trtllm_allreduce_fusion (
4946 allreduce_in = allreduce_in ,
5047 world_size = world_size ,
@@ -69,7 +66,7 @@ def kernel(workspace_tensor, rank, world_size):
6966 layout_code = layout_code ,
7067 )
7168
72- paddle .cuda .synchronize ()
69+ # paddle.cuda.synchronize()
7370
7471 return allreduce_in , all_reduce_out
7572
@@ -91,6 +88,7 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
9188 os .environ ["PADDLE_CURRENT_ENDPOINT" ] = (
9289 f"127.0.0.1:{ distributed_init_port + rank + 10 } "
9390 )
91+
9492 # Set NCCL related environment variables (optional but recommended)
9593 os .environ ["FLAGS_SYNC_NCCL_ALLREDUCE" ] = "1"
9694
@@ -117,7 +115,25 @@ def _run_simple_worker(world_size, rank, distributed_init_port):
117115 dist_pp .barrier (group = group_pp )
118116
119117 # Run fusion operation
120- allreduce_in_clone , all_reduce_out = kernel (workspace_tensor , rank , world_size )
118+ loop = 5
119+ s = paddle .cuda .Stream ()
120+ s .wait_stream (paddle .cuda .current_stream ())
121+ with paddle .cuda .stream (s ):
122+ for _ in range (loop ):
123+ allreduce_in_clone , all_reduce_out = kernel (
124+ workspace_tensor , rank , world_size
125+ )
126+
127+ g = CUDAGraph ()
128+ g .capture_begin ()
129+ for _ in range (loop ):
130+ allreduce_in_clone , all_reduce_out = kernel (
131+ workspace_tensor , rank , world_size
132+ )
133+ g .capture_end ()
134+
135+ g .replay ()
136+ paddle .cuda .synchronize ()
121137
122138 # # Calculate reference result
123139 dist_pp .all_reduce (allreduce_in_clone , group = group_pp )
0 commit comments