diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c35274cd24..ff686068dd 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -563,6 +563,7 @@ class Attention(BaseModel): False, description="Whether to use the Tokamax library for GMM kernel implementation.", ) + tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.") ragged_block_size: int = Field(256, description="Block size for ragged attention.") enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.") use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e23c3eba9f..c4d576aefe 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -43,6 +43,8 @@ from qwix.contrib.sparsity import sparsity_module import qwix.pallas as qpl import tokamax +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config +from tokamax._src.ops.ragged_dot import base set_xla_metadata = xla_metadata.set_xla_metadata @@ -1080,9 +1082,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel): elif self.config.attention == "vllm_rpa": return group_sizes else: + ep = self.get_expert_parallelism_size() return tokamax.RaggedDotGroupSizes( group_sizes, - (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0], + (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0], ) def get_quantization_dtypes(): @@ -1121,14 +1124,47 @@ 
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a weight_gather_axes=weight_gather_axes, ) else: # tokamax (unquantized) - output = tokamax.ragged_dot( - lhs=inputs, - rhs=kernel, - group_sizes=tokamax_group_sizes, - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=self.dtype, - implementation="mosaic", - ) + if self.config.tokamax_gmm_autotune: + # 1. Create configs from flags for backward pass + dlhs_config = Config( + tile_m=tiling[3], + tile_k=tiling[4], + tile_n=tiling[5], + ) + drhs_config = Config( + tile_m=tiling[6], + tile_k=tiling[7], + tile_n=tiling[8], + ) + + # 2. Create custom ops for backward pass + dlhs_op = PallasMosaicTpuRaggedDot(config=dlhs_config) + drhs_op = PallasMosaicTpuRaggedDot(config=drhs_config) + + # 3. Create custom vjp function + custom_vjp = functools.partial(base.vjp, dlhs_ragged_dot=dlhs_op, drhs_ragged_dot=drhs_op) + + # 4. Create forward op with custom vjp and tiling from flags + fwd_config = Config(tile_m=tiling[0], tile_k=tiling[1], tile_n=tiling[2]) + fwd_op = PallasMosaicTpuRaggedDot(config=fwd_config, vjp=custom_vjp) + + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax_group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=self.dtype, + implementation=(fwd_op,), # Pass the configured op directly! 
+ ) + else: + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax_group_sizes, + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=self.dtype, + implementation="mosaic", + ) elif self.config.megablox: # Older forked megablox output = mblx.gmm( lhs=inputs, diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py index cef7c0646f..ee32fba127 100644 --- a/src/maxtext/models/deepseek_batchsplit_fp8.py +++ b/src/maxtext/models/deepseek_batchsplit_fp8.py @@ -29,6 +29,7 @@ from maxtext.layers import quantizations import qwix.pallas as qpl import tokamax +from tokamax import config as tokamax_config @functools.partial( @@ -962,14 +963,25 @@ def gmm( qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0], ) else: - output = tokamax.ragged_dot( - lhs=inputs, - rhs=kernel, - group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), - precision=jax.lax.Precision.DEFAULT, - preferred_element_type=preferred_element_type, - implementation="mosaic", - ) + if config.tokamax_gmm_autotune: + with tokamax_config.autotuning_cache_miss_fallback("autotune"): + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=preferred_element_type, + implementation="mosaic", + ) + else: + output = tokamax.ragged_dot( + lhs=inputs, + rhs=kernel, + group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)), + precision=jax.lax.Precision.DEFAULT, + preferred_element_type=preferred_element_type, + implementation="mosaic", + ) return output gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)