diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e7f548c847..fd8ae99c79 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -900,10 +900,7 @@ def gmm( if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat: tokamax_group_sizes = group_sizes else: - tokamax_group_sizes = tokamax.RaggedDotGroupSizes( - group_sizes, - max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), - ) + tokamax_group_sizes = tokamax.RaggedDotGroupSizes(group_sizes, self.config.wi_tile_fwd_batch_seq) pad_length = self.config.wi_tile_fwd_batch_seq hs_shape = inputs.shape # pad length is the 1st dimension of tiling size in gmm call diff --git a/src/maxtext/models/deepseek_batchsplit.py b/src/maxtext/models/deepseek_batchsplit.py index 24ccf1c7b5..3a61d7476c 100644 --- a/src/maxtext/models/deepseek_batchsplit.py +++ b/src/maxtext/models/deepseek_batchsplit.py @@ -27,7 +27,6 @@ from maxtext.layers import attention_op from maxtext.layers import moe as moe_lib from maxtext.layers import quantizations -from maxtext.utils import max_utils import qwix.pallas as qpl import tokamax @@ -970,10 +969,7 @@ def gmm( output = tokamax.ragged_dot( lhs=inputs, rhs=kernel, - group_sizes=tokamax.RaggedDotGroupSizes( - group_sizes, - max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]), - ), + group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, config.wi_tile_fwd_batch_seq), precision=jax.lax.Precision.DEFAULT, preferred_element_type=preferred_element_type, implementation="mosaic", diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py index 0b7edc5a9e..03e22a2119 100644 --- a/src/maxtext/utils/max_utils.py +++ b/src/maxtext/utils/max_utils.py @@ -1078,13 +1078,3 @@ def transformer_engine_context(): yield except (ImportError, AttributeError): yield - - -def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]: - """Generate group sizes for a given target m.""" - np.random.seed(0) - repr_val = np.random.uniform(size=(g,)) - repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val - repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m) - repr_val[0] += target_m - np.sum(repr_val) - return tuple(map(int, repr_val))