diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index a984e8788c4..eda4714a58b 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -1585,6 +1585,12 @@ class at the server level, which is too granular for ModelRunner.
         sampler_output = None
         if not self.speculative_decoding:
             sampler_output = self.sampler(logits, self.sampling_metadata)
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    sampler_output.sampled_token_ids,
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
         else:
             sampler_output = self.sampler(
                 logits,
@@ -1592,6 +1598,27 @@ class at the server level, which is too granular for ModelRunner.
                 self.model_config.max_model_len,
                 self.share_inputs,
             )
+            if self.parallel_config.tensor_parallel_size > 1:
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_tokens"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["accept_num"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["step_idx"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
+                paddle.distributed.broadcast(
+                    self.share_inputs["stop_flags"],
+                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                    group=self.parallel_config.tp_group,
+                )
         prompt_logprobs_list = None
         if not self.speculative_decoding:
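
Note (not part of the patch): in both branches the broadcast root is computed as `data_parallel_rank * tensor_parallel_size`, which is the global rank of the first rank in each tensor-parallel group, assuming ranks are laid out TP-contiguously within each data-parallel group. A minimal sketch of that arithmetic; the helper name `tp_group_root` is illustrative, not from the patch:

```python
# Sketch only: assumes global ranks are assigned TP-contiguously, i.e.
# DP group d owns global ranks [d * tp_size, (d + 1) * tp_size), so the
# root of that group's TP broadcast is its first global rank.
def tp_group_root(dp_rank: int, tp_size: int) -> int:
    """Global rank of the root of a DP group's tensor-parallel group."""
    return dp_rank * tp_size

# With tp_size=4: DP group 0 spans ranks 0..3 (root 0),
# DP group 1 spans ranks 4..7 (root 4).
assert tp_group_root(0, 4) == 0
assert tp_group_root(1, 4) == 4
```

Broadcasting the sampled tokens (or, in the speculative path, the accept/step/stop buffers) from that root keeps every rank in the TP group working from identical sampling results.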