diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e7f548c847..c7f4b29643 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -897,7 +897,8 @@ def gmm( inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes ): # TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm - if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: + # TODO (b/493621965) Expert parallelism fails tokamax gmm when vLLM uses MaxText models + if (self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat) or self.config.rollout_expert_parallelism > 1: tokamax_group_sizes = group_sizes else: tokamax_group_sizes = tokamax.RaggedDotGroupSizes(