27 changes: 27 additions & 0 deletions fastdeploy/worker/xpu_model_runner.py
@@ -1585,13 +1585,40 @@ class at the server level, which is too granular for ModelRunner.
sampler_output = None
if not self.speculative_decoding:
sampler_output = self.sampler(logits, self.sampling_metadata)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
sampler_output.sampled_token_ids,
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
Comment on lines +1588 to +1593

Copilot AI Mar 31, 2026

The expression computing src (the root rank) is repeated in several places here, and is duplicated again in the speculative branch below. Consider storing it in a local variable first (e.g. tp_src_rank = data_parallel_rank * tensor_parallel_size) and passing that to broadcast, to avoid the maintenance risk that comes with copy-paste.
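The reviewer's first suggestion can be sketched as follows. This is a minimal illustration, not the PR's actual code: `tp_src_rank` is the hypothetical helper name from the review comment, and in the real code its result would be passed as the `src` argument of `paddle.distributed.broadcast`.

```python
# Sketch of the suggested refactor: compute the broadcast source rank once
# and reuse it, instead of repeating the expression at every call site.
# The name tp_src_rank comes from the review comment and is hypothetical.

def tp_src_rank(data_parallel_rank: int, tensor_parallel_size: int) -> int:
    """Root (first) global rank of this data-parallel group's TP group."""
    return data_parallel_rank * tensor_parallel_size

# e.g. with 2 TP ranks per DP group, DP group 1 broadcasts from global rank 2
print(tp_src_rank(1, 2))  # -> 2
```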
else:
sampler_output = self.sampler(
logits,
self.sampling_metadata,
self.model_config.max_model_len,
self.share_inputs,
)
if self.parallel_config.tensor_parallel_size > 1:
paddle.distributed.broadcast(
self.share_inputs["accept_tokens"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
Comment on lines +1601 to +1606

Copilot AI Mar 31, 2026

The speculative branch calls broadcast several times in a row, repeating the same src computation each time. Consider reusing the same tp_src_rank variable, and broadcasting these tensors (accept_tokens/accept_num/step_idx/stop_flags) in a loop over a list of keys, to reduce the chance of missing one when fields are added or changed later.
paddle.distributed.broadcast(
self.share_inputs["accept_num"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["step_idx"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
paddle.distributed.broadcast(
self.share_inputs["stop_flags"],
self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
group=self.parallel_config.tp_group,
)
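The second suggestion, broadcasting the speculative-decoding tensors in a loop over a key list, can be sketched as below. `fake_broadcast` is a stand-in for `paddle.distributed.broadcast` so the snippet runs without paddle, and the `share_inputs` values are dummies; the key list and `tp_src_rank` name are taken from the review comment, not from the PR itself.

```python
# Sketch of looping over a key list instead of four copy-pasted broadcast
# calls. fake_broadcast records its arguments in place of the real
# paddle.distributed.broadcast collective.
calls = []

def fake_broadcast(tensor, src, group=None):
    calls.append((tensor, src))

# Dummy stand-ins for the real share_inputs tensors and parallel config.
share_inputs = {"accept_tokens": "t0", "accept_num": "t1",
                "step_idx": "t2", "stop_flags": "t3"}
data_parallel_rank, tensor_parallel_size = 1, 2
tp_src_rank = data_parallel_rank * tensor_parallel_size  # computed once

SPECULATIVE_BROADCAST_KEYS = ("accept_tokens", "accept_num",
                              "step_idx", "stop_flags")
for key in SPECULATIVE_BROADCAST_KEYS:
    fake_broadcast(share_inputs[key], tp_src_rank)

print(len(calls))  # -> 4
```

Centralizing the key list means a new speculative-decoding field only needs to be added in one place to be broadcast alongside the others.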

prompt_logprobs_list = None
if not self.speculative_decoding: