Skip to content

Commit d69bdb2

Browse files
committed
patch for paddle
1 parent 3198c36 commit d69bdb2

7 files changed

Lines changed: 30 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ name = "sonic-moe"
77
dynamic = ["version"]
88
requires-python = ">=3.12"
99
dependencies = [
10-
"nvidia-cutlass-dsl==4.4.0",
10+
"nvidia-cutlass-dsl==4.4.1",
1111
"torch>=2.7.1,<=2.9.1",
12-
"quack-kernels==0.2.5",
12+
"quack-kernels @ git+https://github.com/PFCCLab/quack.git@954fe1638beef1b0a3d2fd463cfb6372c0b026cf",
1313
"ninja"
1414
]
1515

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# ********************************************************************************
44

55
# torch>=2.7.1
6-
nvidia-cutlass-dsl==4.3.0
6+
nvidia-cutlass-dsl==4.4.1
77
# quack-kernels @ git+https://github.com/Dao-AILab/quack.git@3d0ab3ec2164749caac8f269f771e66a40efd2de
8-
quack-kernels @ git+https://github.com/PFCCLab/quack.git@12783f5ddbb20a2dd65bf88813db644c0e227f93
8+
quack-kernels @ git+https://github.com/PFCCLab/quack.git@954fe1638beef1b0a3d2fd463cfb6372c0b026cf
99
pytest
1010
parameterized
1111
ninja

sonicmoe/count_cumsum/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def count_cumsum(x: torch.Tensor, E: int, do_cumsum: bool = True) -> torch.Tenso
2222

2323
count_output = torch.empty(E, dtype=torch.int32, device=x.device)
2424
cumsum_output = torch.empty(E, dtype=torch.int32, device=x.device) if do_cumsum else None
25-
stream = torch.cuda.current_stream(x.device).cuda_stream
25+
stream = torch.cuda.current_stream(x.device).stream_base.raw_stream
2626

2727
count_cumsum_cuda(x=x, count_output=count_output, cumsum_output=cumsum_output, stream=stream)
2828

sonicmoe/functional/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def forward(
167167
num_activated_expert_per_token_offset,
168168
)
169169

170+
ctx.has_num_activated_expert_per_token_offset = num_activated_expert_per_token_offset is not None
170171
ctx.mark_non_differentiable(y1)
171172
ctx.set_materialize_grads(False)
172173

@@ -260,7 +261,10 @@ def backward(ctx, _: None, dz: torch.Tensor):
260261
grads.extend([dx_reduced, dw1])
261262
if db1 is not None:
262263
grads.append(db1)
263-
grads.extend([None] * 5)
264+
if ctx.has_num_activated_expert_per_token_offset:
265+
grads.extend([None] * 5)
266+
else:
267+
grads.extend([None] * 4)
264268
return tuple(grads)
265269

266270

@@ -280,7 +284,7 @@ def forward(
280284
x_gather_idx: torch.Tensor,
281285
s_scatter_idx: torch.Tensor,
282286
s_reverse_scatter_idx: torch.Tensor,
283-
num_activated_expert_per_token_offset: torch.Tensor,
287+
num_activated_expert_per_token_offset: torch.Tensor | None,
284288
is_varlen_K: bool,
285289
activation_type: ActivationType,
286290
) -> torch.Tensor:
@@ -335,6 +339,7 @@ def forward(
335339
s_scatter_idx,
336340
s_reverse_scatter_idx,
337341
)
342+
ctx.has_num_activated_expert_per_token_offset = num_activated_expert_per_token_offset is None
338343

339344
return o
340345

@@ -436,7 +441,12 @@ def backward(ctx, dout: torch.Tensor):
436441
grads.extend([None, dz, dw2])
437442
if db2 is not None:
438443
grads.append(db2)
439-
grads.extend([ds, *[None] * 5])
444+
445+
if ctx.has_num_activated_expert_per_token_offset:
446+
grads.extend([ds, *[None] * 4])
447+
else:
448+
grads.extend([ds, *[None] * 5])
449+
440450
return tuple(grads)
441451

442452

sonicmoe/functional/grouped_gemm.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,6 @@ def load_A_gather(
592592
mA_cur_copy = cute.make_tensor(tPrAptr, ((copy_elems_per_thr_load, 1), 1))
593593

594594
cute.copy(A_g2s_thr_copy, mA_cur_copy, tAsA[None, None, i])
595-
else:
596-
zero_frag = cute.make_fragment_like(tAsA[None, None, i])
597-
zero_frag.fill(0.0)
598-
cute.basic_copy(zero_frag, tAsA[None, None, i])
599595

600596
else:
601597
MIdx = tmAIdx[i]

tests/moe_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ class MoETest(TestCommons):
4747
[paddle.device("cuda")],
4848
[torch.bfloat16],
4949
[
50-
((16384 + 512) * 16, 512, 512, 128, 8)
5150
(8192, 768, 256, 128, 8),
5251
(8192, 768, 512, 64, 4),
5352
(8192, 768, 1024, 32, 2),

tests/test_commons.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,18 @@
1111
import numpy as np
1212
import torch
1313
import torch.nn as nn
14-
from torch.testing import assert_close
14+
15+
16+
def assert_close(actual, expected, rtol=None, atol=None, **kwargs):
17+
"""Drop-in replacement for torch.testing.assert_close,
18+
since paddle.testing is not available in the current Paddle version."""
19+
a = actual.detach().cpu().float().numpy()
20+
b = expected.detach().cpu().float().numpy()
21+
if rtol is None:
22+
rtol = 1.3e-6
23+
if atol is None:
24+
atol = 1e-5
25+
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
1526

1627

1728
class TestCommons(TestCase):

0 commit comments

Comments (0)