diff --git a/fastdeploy/model_executor/layers/normalization.py b/fastdeploy/model_executor/layers/normalization.py index a89b5c139c7..326c6694498 100644 --- a/fastdeploy/model_executor/layers/normalization.py +++ b/fastdeploy/model_executor/layers/normalization.py @@ -243,6 +243,13 @@ def forward( if residual_input is None: residual_out = x + use_allreduce_fused = ( + self.enable_all_reduce_fusion + and self.tp_size > 1 + and x.shape[0] <= 2048 + and residual_input is not None + and current_platform.is_cuda() + ) if proxy_rmsnorm is None: if current_platform.is_gcu(): if residual_input is None: @@ -250,7 +257,7 @@ def forward( return norm_out.astype(x_dtype), residual_out norm_out = self.norm_func(x, residual_input, self.weight, self.eps) # enable trtllm all reduce fusion - elif self.enable_all_reduce_fusion and x.shape[0] <= 2048: + elif use_allreduce_fused: norm_out = flashinfer_allreduce_residual_rmsnorm( fd_config=self.fd_config, input_tensor=x, residual=residual_input, weight=self.weight, eps=self.eps ) @@ -276,9 +283,19 @@ def forward( quant_min_bound=self.quant_min_bound, ) else: - if residual_input is not None: - x = x + residual_input - norm_out = proxy_rmsnorm(x, self.weight, self.eps), x + if use_allreduce_fused: + norm_out = flashinfer_allreduce_residual_rmsnorm( + fd_config=self.fd_config, + input_tensor=x, + residual=residual_input, + weight=self.weight, + eps=self.eps, + ) + assert norm_out[0] is not None, "Trtllm-all-reduce fusion failed!" + else: + if residual_input is not None: + x = x + residual_input + norm_out = proxy_rmsnorm(x, self.weight, self.eps), x out = norm_out[0].astype(x_dtype) if residual_input is not None: diff --git a/tests/layers/trtllm_allreduce_rms_fusion.py b/tests/layers/trtllm_allreduce_rms_fusion.py index e762130f6b7..0b3069b0953 100644 --- a/tests/layers/trtllm_allreduce_rms_fusion.py +++ b/tests/layers/trtllm_allreduce_rms_fusion.py @@ -543,6 +543,193 @@ def test_cleanup_workspace_function(self): mock_manager.cleanup.assert_called_once() +class TestRMSNormProxyAllreduceFused(unittest.TestCase): + @classmethod + def setUpClass(cls): + # The outer test_run_distributed in test_trtllm_allreduce_rms_fusion.py + # has already done paddle.set_device + init_parallel_env, so we don't + # repeat that here. (unittest.main runs in the same process.) + cls.tp_size = dist.get_world_size() + cls.tp_rank = dist.get_rank() + + def _make_fd_config(self, enable_fusion: bool): + """Mock fd_config with the minimal attributes RMSNorm.__init__ touches.""" + fd_config = Mock() + fd_config.parallel_config = Mock() + fd_config.parallel_config.tensor_parallel_size = self.tp_size + fd_config.parallel_config.tensor_parallel_rank = self.tp_rank + fd_config.parallel_config.tp_group = dist.get_group() + fd_config.parallel_config.expert_parallel_size = 1 + fd_config.parallel_config.enable_flashinfer_allreduce_fusion = enable_fusion + fd_config.parallel_config.use_sequence_parallel_moe = False + fd_config.model_config = Mock() + fd_config.model_config.moe_layer_start_index = -1 + fd_config.quant_config = None + return fd_config + + def _build_rmsnorm(self, enable_fusion: bool, hidden_size: int, layer_id: int = 1): + """Build a real RMSNorm whose enable_all_reduce_fusion resolves to + `enable_fusion` (use post_attention_layernorm prefix to ensure the + prefix-match in __init__ passes).""" + from fastdeploy.model_executor.layers.normalization import RMSNorm + + fd_config = self._make_fd_config(enable_fusion=enable_fusion) + norm = RMSNorm( + fd_config=fd_config, + hidden_size=hidden_size, + eps=1e-6, + prefix=f"model.layers.{layer_id}.post_attention_layernorm", + layer_id=layer_id, + dtype="bfloat16", + ) + # Initialize weight to a known reproducible value (constant=1.0 by default). + with paddle.no_grad(): + paddle.seed(2024) + new_w = paddle.randn([hidden_size], dtype=paddle.bfloat16) + dist.broadcast(new_w, src=0) + norm.weight.set_value(new_w) + return norm + + @staticmethod + def _proxy_rmsnorm_fn(x, weight, eps): + """Stand-in for phi rmsnorm used as proxy_rmsnorm — standard formula + in fp32 to keep reference numerics clean.""" + x_fp32 = x.astype("float32") + var = x_fp32.pow(2).mean(axis=-1, keepdim=True) + out = x_fp32 * paddle.rsqrt(var + eps) + out = out * weight.astype("float32") + return out.astype(x.dtype) + + def _reference(self, x_partial, residual, weight, eps): + """Manual: all_reduce(x_partial) + residual, then standard RMSNorm. + Mirrors what proxy path WOULD produce after explicit allreduce+add.""" + x = x_partial.clone() + dist.all_reduce(x, op=dist.ReduceOp.SUM) + residual_out = x + residual + norm_out = self._proxy_rmsnorm_fn(residual_out, weight, eps) + return norm_out, residual_out + + def _make_inputs(self, token_num, hidden_size, seed=123): + """Each rank gets a different x_partial (simulates RowParallelLinear's + un-reduced output); residual is identical across ranks.""" + paddle.seed(seed + self.tp_rank * 7919) + x_partial = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) * 0.1 + paddle.seed(seed + 99) + residual = paddle.randn([token_num, hidden_size], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + return x_partial, residual + + def _assert_close_bf16(self, a, b, rtol=5e-2, atol=5e-2, msg=""): + a32 = a.astype("float32").numpy() + b32 = b.astype("float32").numpy() + np.testing.assert_allclose(a32, b32, rtol=rtol, atol=atol, err_msg=msg) + + # ---------- Tests ---------- + + def test_proxy_path_takes_fused_branch(self): + """fusion=on, tp>1, shape<=2048, residual!=None + -> proxy branch picks flashinfer_allreduce_residual_rmsnorm. + Verify by patching the symbol and asserting it was called. + """ + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + self.assertTrue(norm.enable_all_reduce_fusion) + x_partial, residual = self._make_inputs(token_num=64, hidden_size=hidden) + + # Patch within the normalization module's namespace. + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm", + wraps=__import__( + "fastdeploy.model_executor.layers.normalization", fromlist=["flashinfer_allreduce_residual_rmsnorm"] + ).flashinfer_allreduce_residual_rmsnorm, + ) as spy: + out, res = norm.forward( + x_partial.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + spy.assert_called_once() + + # Numerics: must match reference (allreduce + add + std rmsnorm). + ref_norm, ref_res = self._reference(x_partial, residual, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="proxy fused-branch norm output mismatch") + self._assert_close_bf16(res, ref_res, msg="proxy fused-branch residual mismatch") + + def test_proxy_path_falls_back_when_fusion_disabled(self): + """fusion=off -> proxy branch must call proxy_rmsnorm directly, + no fused allreduce path used. Input is treated as already-reduced.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 512 + norm = self._build_rmsnorm(enable_fusion=False, hidden_size=hidden) + self.assertFalse(norm.enable_all_reduce_fusion) + + # Each rank uses the SAME x (already-reduced) — that's the contract + # when fusion is off (RowParallelLinear has done its own allreduce). + paddle.seed(777) + x = paddle.randn([64, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([64, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + proxy_called = {"n": 0} + + def proxy_spy(_x, _w, _eps): + proxy_called["n"] += 1 + return self._proxy_rmsnorm_fn(_x, _w, _eps) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=proxy_spy, + ) + fused_spy.assert_not_called() + + self.assertEqual(proxy_called["n"], 1, "proxy_rmsnorm must be invoked exactly once") + + # Reference: x is already full -> just add + rmsnorm, no allreduce. + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="fallback norm output mismatch") + self._assert_close_bf16(res, residual_full, msg="fallback residual mismatch") + + def test_proxy_path_falls_back_when_token_too_large(self): + """fusion=on but shape[0] > 2048 -> proxy branch must NOT call fused; + in this regime upstream RowParallelLinear didn't skip its own + all-reduce, so x is already full and proxy_rmsnorm is invoked directly.""" + if self.tp_size < 2: + self.skipTest("Requires tp_size >= 2") + hidden = 256 + norm = self._build_rmsnorm(enable_fusion=True, hidden_size=hidden) + # shape[0] > 2048 forces use_allreduce_fused=False + token_num = 2049 + paddle.seed(555) + x = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) * 0.1 + dist.broadcast(x, src=0) + residual = paddle.randn([token_num, hidden], dtype=paddle.bfloat16) + dist.broadcast(residual, src=0) + + with patch( + "fastdeploy.model_executor.layers.normalization.flashinfer_allreduce_residual_rmsnorm" + ) as fused_spy: + out, res = norm.forward( + x.clone(), + residual_input=residual.clone(), + proxy_rmsnorm=self._proxy_rmsnorm_fn, + ) + fused_spy.assert_not_called() + + residual_full = x + residual + ref_norm = self._proxy_rmsnorm_fn(residual_full, norm.weight, norm.eps) + self._assert_close_bf16(out, ref_norm, msg="large-shape fallback norm mismatch") + self._assert_close_bf16(res, residual_full, msg="large-shape fallback residual mismatch") + + if __name__ == "__main__": """Run tests directly (called by subprocess after distributed launch)""" unittest.main(verbosity=2)