Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,21 @@ def forward(

if residual_input is None:
residual_out = x
use_allreduce_fused = (
self.enable_all_reduce_fusion
and self.tp_size > 1
and x.shape[0] <= 2048
and residual_input is not None
and current_platform.is_cuda()
)
if proxy_rmsnorm is None:
if current_platform.is_gcu():
if residual_input is None:
norm_out = rms_norm(x, self.weight, self.eps)
return norm_out.astype(x_dtype), residual_out
norm_out = self.norm_func(x, residual_input, self.weight, self.eps)
# enable trtllm all reduce fusion
elif self.enable_all_reduce_fusion and x.shape[0] <= 2048:
elif use_allreduce_fused:
norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps
)
Expand All @@ -276,9 +283,19 @@ def forward(
quant_min_bound=self.quant_min_bound,
)
else:
if residual_input is not None:
x = x + residual_input
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x
if use_allreduce_fused:
norm_out = flashinfer_allreduce_residual_rmsnorm(
fd_config=self.fd_config,
input_tensor=x,
residual=residual_input,
weight=self.weight,
eps=self.eps,
)
assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!"
else:
if residual_input is not None:
x = x + residual_input
norm_out = proxy_rmsnorm(x, self.weight, self.eps), x

out = norm_out[0].astype(x_dtype)
if residual_input is not None:
Expand Down
187 changes: 187 additions & 0 deletions tests/layers/trtllm_allreduce_rms_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,193 @@ def test_cleanup_workspace_function(self):
mock_manager.cleanup.assert_called_once()


class TestRMSNormProxyAllreduceFused(unittest.TestCase):
    """Distributed tests for the ``proxy_rmsnorm`` branch of ``RMSNorm.forward``.

    The tests check which implementation the proxy branch dispatches to
    (the fused ``flashinfer_allreduce_residual_rmsnorm`` or the supplied
    ``proxy_rmsnorm`` callable) and compare the outputs against a manual
    reference. Every test needs at least two tensor-parallel ranks and is
    skipped otherwise.
    """

    # Name of the fused kernel as looked up inside the normalization module;
    # this is the symbol the tests patch or spy on.
    _FUSED_FN_NAME = "flashinfer_allreduce_residual_rmsnorm"
    _FUSED_FN_TARGET = f"fastdeploy.model_executor.layers.normalization.{_FUSED_FN_NAME}"

    @classmethod
    def setUpClass(cls):
        # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py
        # has already done paddle.set_device + init_parallel_env, so we don't
        # repeat that here. (unittest.main runs in the same process.)
        cls.tp_size = dist.get_world_size()
        cls.tp_rank = dist.get_rank()

    def setUp(self):
        """Skip each test unless at least two ranks are available."""
        if self.tp_size < 2:
            self.skipTest("Requires tp_size >= 2")

    # ---------- Helpers ----------

    def _make_fd_config(self, enable_fusion: bool):
        """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.

        Args:
            enable_fusion: Value stored in
                ``parallel_config.enable_flashinfer_allreduce_fusion``.

        Returns:
            A ``Mock`` exposing ``parallel_config``, ``model_config`` and
            ``quant_config``.
        """
        parallel_config = Mock()
        parallel_config.tensor_parallel_size = self.tp_size
        parallel_config.tensor_parallel_rank = self.tp_rank
        parallel_config.tp_group = dist.get_group()
        parallel_config.expert_parallel_size = 1
        parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion
        parallel_config.use_sequence_parallel_moe = False

        model_config = Mock()
        model_config.moe_layer_start_index = -1

        fd_config = Mock()
        fd_config.parallel_config = parallel_config
        fd_config.model_config = model_config
        fd_config.quant_config = None
        return fd_config

    def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1):
        """Build a real RMSNorm whose enable_all_reduce_fusion resolves to
        `enable_fusion` (use post_attention_layernorm prefix to ensure the
        prefix-match in __init__ passes).

        Args:
            enable_fusion: Fusion flag forwarded to the mocked config.
            hidden_size: Size of the normalized (last) dimension.
            layer_id: Layer index used for both ``layer_id`` and the prefix.

        Returns:
            The constructed ``RMSNorm`` with its weight overwritten.
        """
        from fastdeploy.model_executor.layers.normalization import RMSNorm

        norm = RMSNorm(
            fd_config=self._make_fd_config(enable_fusion=enable_fusion),
            hidden_size=hidden_size,
            eps=1e-6,
            prefix=f"model.layers.{layer_id}.post_attention_layernorm",
            layer_id=layer_id,
            dtype="bfloat16",
        )
        # Replace the initial weight (reportedly constant 1.0 by default) with
        # seeded random values, broadcast from rank 0 so all ranks agree.
        with paddle.no_grad():
            paddle.seed(2024)
            new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16)
            dist.broadcast(new_w, src=0)
            norm.weight.set_value(new_w)
        return norm

    @staticmethod
    def _proxy_rmsnorm_fn(x, weight, eps):
        """Stand-in for phi rmsnorm used as proxy_rmsnorm.

        Computes ``x * rsqrt(mean(x**2, axis=-1) + eps) * weight`` in float32
        to keep reference numerics clean, then casts back to ``x.dtype``.
        """
        x_fp32 = x.astype("float32")
        var = x_fp32.pow(2).mean(axis=-1, keepdim=True)
        out = x_fp32 * paddle.rsqrt(var + eps)
        out = out * weight.astype("float32")
        return out.astype(x.dtype)

    def _reference(self, x_partial, residual, weight, eps):
        """Manual: all_reduce(x_partial) + residual, then standard RMSNorm.

        Mirrors what proxy path WOULD produce after explicit allreduce+add.

        Returns:
            Tuple ``(norm_out, residual_out)``.
        """
        # Clone first: all_reduce is applied to the copy so the caller's
        # tensor is left untouched.
        x = x_partial.clone()
        dist.all_reduce(x, op=dist.ReduceOp.SUM)
        residual_out = x + residual
        norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps)
        return norm_out, residual_out

    def _make_inputs(self, token_num, hidden_size, seed=123):
        """Each rank gets a different x_partial (simulates RowParallelLinear's
        un-reduced output); residual is identical across ranks.

        Returns:
            Tuple ``(x_partial, residual)`` of bfloat16 tensors shaped
            ``[token_num, hidden_size]``.
        """
        # Rank-dependent seed -> a different partial input on every rank.
        paddle.seed(seed + self.tp_rank * 7919)
        x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
        paddle.seed(seed + 99)
        residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)
        return x_partial, residual

    def _make_replicated_inputs(self, token_num, hidden_size, seed):
        """Create ``x`` and ``residual`` that are identical on every rank.

        Both tensors are broadcast from rank 0; this models the case where
        the input has already been reduced upstream.

        Returns:
            Tuple ``(x, residual)`` of bfloat16 tensors shaped
            ``[token_num, hidden_size]``.
        """
        paddle.seed(seed)
        x = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1
        dist.broadcast(x, src=0)
        residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16)
        dist.broadcast(residual, src=0)
        return x, residual

    def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""):
        """Assert two tensors are close after converting both to float32 numpy arrays."""
        a32 = a.astype("float32").numpy()
        b32 = b.astype("float32").numpy()
        np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg)

    # ---------- Tests ----------

    def test_proxy_path_takes_fused_branch(self):
        """fusion=on, tp>1, shape<=2048, residual!=None
        -> proxy branch picks flashinfer_allreduce_residual_rmsnorm.
        Verify by patching the symbol and asserting it was called.
        """
        from fastdeploy.model_executor.layers import normalization

        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        self.assertTrue(norm.enable_all_reduce_fusion)
        x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden)

        # Spy on the symbol inside the normalization module's namespace while
        # still delegating to the real implementation.
        with patch.object(
            normalization,
            self._FUSED_FN_NAME,
            wraps=getattr(normalization, self._FUSED_FN_NAME),
        ) as spy:
            out, res = norm.forward(
                x_partial.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
        spy.assert_called_once()

        # Numerics: must match reference (allreduce + add + std rmsnorm).
        ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch")
        self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch")

    def test_proxy_path_falls_back_when_fusion_disabled(self):
        """fusion=off -> proxy branch must call proxy_rmsnorm directly,
        no fused allreduce path used. Input is treated as already-reduced."""
        hidden = 512
        norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden)
        self.assertFalse(norm.enable_all_reduce_fusion)

        # Each rank uses the SAME x (already-reduced) — that's the contract
        # when fusion is off (RowParallelLinear has done its own allreduce).
        x, residual = self._make_replicated_inputs(token_num=64, hidden_size=hidden, seed=777)

        proxy_called = {"n": 0}

        def proxy_spy(_x, _w, _eps):
            proxy_called["n"] += 1
            return self._proxy_rmsnorm_fn(_x, _w, _eps)

        with patch(self._FUSED_FN_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=proxy_spy,
            )
        fused_spy.assert_not_called()

        self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once")

        # Reference: x is already full -> just add + rmsnorm, no allreduce.
        residual_full = x + residual
        ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch")
        self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch")

    def test_proxy_path_falls_back_when_token_too_large(self):
        """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused;
        in this regime upstream RowParallelLinear didn't skip its own
        all-reduce, so x is already full and proxy_rmsnorm is invoked directly."""
        hidden = 256
        norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden)
        # Without this check the test would pass vacuously if the fusion flag
        # resolved to False, and would not exercise the token-count limit.
        self.assertTrue(norm.enable_all_reduce_fusion)

        # shape[0] > 2048 forces use_allreduce_fused=False
        token_num = 2049
        x, residual = self._make_replicated_inputs(token_num=token_num, hidden_size=hidden, seed=555)

        with patch(self._FUSED_FN_TARGET) as fused_spy:
            out, res = norm.forward(
                x.clone(),
                residual_input=residual.clone(),
                proxy_rmsnorm=self._proxy_rmsnorm_fn,
            )
        fused_spy.assert_not_called()

        residual_full = x + residual
        ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps)
        self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch")
        self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch")


if __name__ == "__main__":
    # Run tests directly (called by subprocess after distributed launch).
    unittest.main(verbosity=2)
Loading