From 835feb46a6f38ceccfb71601597d5f3e853164cb Mon Sep 17 00:00:00 2001
From: Darisoy
Date: Tue, 5 May 2026 17:18:43 +0000
Subject: [PATCH 1/3] Enable Tokamax GMM with autotuning fallback in MaxText

---
 src/maxtext/configs/types.py                  |  1 +
 src/maxtext/layers/moe.py                     | 29 ++++++++++++++-----
 src/maxtext/models/deepseek_batchsplit_fp8.py | 29 ++++++++++++++-----
 3 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index c35274cd24..ff686068dd 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -563,6 +563,7 @@ class Attention(BaseModel):
       False,
       description="Whether to use the Tokamax library for GMM kernel implementation.",
   )
+  tokamax_gmm_autotune: bool = Field(False, description="Whether to use tokamax auto-tuner for GMM.")
   ragged_block_size: int = Field(256, description="Block size for ragged attention.")
   enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index e23c3eba9f..3a6a1a919d 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -43,6 +43,8 @@
 from qwix.contrib.sparsity import sparsity_module
 import qwix.pallas as qpl
 import tokamax
+from tokamax import config as tokamax_config
+from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
 
 set_xla_metadata = xla_metadata.set_xla_metadata
 
@@ -1121,14 +1123,25 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
               weight_gather_axes=weight_gather_axes,
           )
         else:  # tokamax (unquantized)
-          output = tokamax.ragged_dot(
-              lhs=inputs,
-              rhs=kernel,
-              group_sizes=tokamax_group_sizes,
-              precision=jax.lax.Precision.DEFAULT,
-              preferred_element_type=self.dtype,
-              implementation="mosaic",
-          )
+          if self.config.tokamax_gmm_autotune:
+            with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+              output = tokamax.ragged_dot(
+                  lhs=inputs,
+                  rhs=kernel,
+                  group_sizes=tokamax_group_sizes,
+                  precision=jax.lax.Precision.DEFAULT,
+                  preferred_element_type=self.dtype,
+                  implementation="mosaic",
+              )
+          else:
+            output = tokamax.ragged_dot(
+                lhs=inputs,
+                rhs=kernel,
+                group_sizes=tokamax_group_sizes,
+                precision=jax.lax.Precision.DEFAULT,
+                preferred_element_type=self.dtype,
+                implementation="mosaic",
+            )
       elif self.config.megablox:  # Older forked megablox
         output = mblx.gmm(
             lhs=inputs,
diff --git a/src/maxtext/models/deepseek_batchsplit_fp8.py b/src/maxtext/models/deepseek_batchsplit_fp8.py
index cef7c0646f..ee32fba127 100644
--- a/src/maxtext/models/deepseek_batchsplit_fp8.py
+++ b/src/maxtext/models/deepseek_batchsplit_fp8.py
@@ -29,6 +29,8 @@
 from maxtext.layers import quantizations
 import qwix.pallas as qpl
 import tokamax
+from tokamax import config as tokamax_config
+from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
 
 
 @functools.partial(
@@ -962,14 +964,25 @@ def gmm(
             qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
         )
       else:
-        output = tokamax.ragged_dot(
-            lhs=inputs,
-            rhs=kernel,
-            group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
-            precision=jax.lax.Precision.DEFAULT,
-            preferred_element_type=preferred_element_type,
-            implementation="mosaic",
-        )
+        if config.tokamax_gmm_autotune:
+          with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+            output = tokamax.ragged_dot(
+                lhs=inputs,
+                rhs=kernel,
+                group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
+                precision=jax.lax.Precision.DEFAULT,
+                preferred_element_type=preferred_element_type,
+                implementation="mosaic",
+            )
+        else:
+          output = tokamax.ragged_dot(
+              lhs=inputs,
+              rhs=kernel,
+              group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
+              precision=jax.lax.Precision.DEFAULT,
+              preferred_element_type=preferred_element_type,
+              implementation="mosaic",
+          )
       return output
 
     gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)

From 79b9fd01e1c80f274fbf2ddf5e31e9a9dd64975f Mon Sep 17 00:00:00 2001
From: Darisoy
Date: Fri, 8 May 2026 23:17:10 +0000
Subject: [PATCH 2/3] Update get_tokamax_group_sizes in moe.py

---
 src/maxtext/layers/moe.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 3a6a1a919d..3c5f10b5ea 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -19,6 +19,7 @@
 import functools
 import math
 import random
+import os
 from typing import Iterable, Optional, Tuple, Union
 
 from aqt.jax.v2 import aqt_tensor as aqt
@@ -1082,9 +1083,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
       elif self.config.attention == "vllm_rpa":
         return group_sizes
       else:
+        ep = self.get_expert_parallelism_size()
         return tokamax.RaggedDotGroupSizes(
             group_sizes,
-            (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
+            (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0],
         )
 
     def get_quantization_dtypes():
@@ -1124,14 +1126,23 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
           )
         else:  # tokamax (unquantized)
           if self.config.tokamax_gmm_autotune:
-            with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+            cache_file = "tokamax_autotune_cache.json"
+            if os.path.exists(cache_file):
+              with open(cache_file, "r") as f:
+                autotune_result_json = f.read()
+              autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)
+              autotune_context = autotune_result
+            else:
+              autotune_context = tokamax_config.autotuning_cache_miss_fallback("heuristics")
+
+            with autotune_context:
               output = tokamax.ragged_dot(
                   lhs=inputs,
                   rhs=kernel,
                   group_sizes=tokamax_group_sizes,
                   precision=jax.lax.Precision.DEFAULT,
                   preferred_element_type=self.dtype,
-                  implementation="mosaic",
+                  implementation=None,
               )
           else:
             output = tokamax.ragged_dot(

From cea66bc02436de7b45b98a7d3b14ab8ecc117844 Mon Sep 17 00:00:00 2001
From: darisoy
Date: Mon, 11 May 2026 21:58:06 +0000
Subject: [PATCH 3/3] Add custom tiling logic for Tokamax GMM

---
 src/maxtext/layers/moe.py | 55 ++++++++++++++++++++++++++-------------
 1 file changed, 37 insertions(+), 18 deletions(-)

diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py
index 3c5f10b5ea..c4d576aefe 100644
--- a/src/maxtext/layers/moe.py
+++ b/src/maxtext/layers/moe.py
@@ -46,6 +46,12 @@
 import tokamax
 from tokamax import config as tokamax_config
 from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
+from tokamax._src.autotuning.autotuner import AutotuningData
+from tokamax._src.benchmarking import BenchmarkData
+from tokamax._src.autotuning.api import AutotuningResult
+import immutabledict
+from tokamax._src.ops.ragged_dot import base
+import functools
 
 set_xla_metadata = xla_metadata.set_xla_metadata
 
@@ -1126,24 +1132,37 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
           )
         else:  # tokamax (unquantized)
           if self.config.tokamax_gmm_autotune:
-            cache_file = "tokamax_autotune_cache.json"
-            if os.path.exists(cache_file):
-              with open(cache_file, "r") as f:
-                autotune_result_json = f.read()
-              autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)
-              autotune_context = autotune_result
-            else:
-              autotune_context = tokamax_config.autotuning_cache_miss_fallback("heuristics")
-
-            with autotune_context:
-              output = tokamax.ragged_dot(
-                  lhs=inputs,
-                  rhs=kernel,
-                  group_sizes=tokamax_group_sizes,
-                  precision=jax.lax.Precision.DEFAULT,
-                  preferred_element_type=self.dtype,
-                  implementation=None,
-              )
+            # 1. Create configs from flags for backward pass
+            dlhs_config = Config(
+                tile_m=tiling[3],
+                tile_k=tiling[4],
+                tile_n=tiling[5],
+            )
+            drhs_config = Config(
+                tile_m=tiling[6],
+                tile_k=tiling[7],
+                tile_n=tiling[8],
+            )
+
+            # 2. Create custom ops for backward pass
+            dlhs_op = PallasMosaicTpuRaggedDot(config=dlhs_config)
+            drhs_op = PallasMosaicTpuRaggedDot(config=drhs_config)
+
+            # 3. Create custom vjp function
+            custom_vjp = functools.partial(base.vjp, dlhs_ragged_dot=dlhs_op, drhs_ragged_dot=drhs_op)
+
+            # 4. Create forward op with custom vjp and tiling from flags
+            fwd_config = Config(tile_m=tiling[0], tile_k=tiling[1], tile_n=tiling[2])
+            fwd_op = PallasMosaicTpuRaggedDot(config=fwd_config, vjp=custom_vjp)
+
+            output = tokamax.ragged_dot(
+                lhs=inputs,
+                rhs=kernel,
+                group_sizes=tokamax_group_sizes,
+                precision=jax.lax.Precision.DEFAULT,
+                preferred_element_type=self.dtype,
+                implementation=(fwd_op,),  # Pass the configured op directly!
+            )
           else:
             output = tokamax.ragged_dot(
                 lhs=inputs,
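
Note on the tiling layout assumed by PATCH 3/3: the existing tiling argument of gmm() is read as nine values, one (tile_m, tile_k, tile_n) triple for the forward kernel and one each for the dlhs and drhs backward kernels. A minimal sketch of that assumed layout, using placeholder numbers rather than values taken from the patches:

    # Hypothetical tile sizes; the real values come from the MaxText tiling flags.
    tiling = (
        512, 1024, 1024,  # forward kernel -> tiling[0:3] = (tile_m, tile_k, tile_n)
        512, 1024, 1024,  # dlhs backward  -> tiling[3:6] = (tile_m, tile_k, tile_n)
        512, 1024, 1024,  # drhs backward  -> tiling[6:9] = (tile_m, tile_k, tile_n)
    )
    fwd_tiles, dlhs_tiles, drhs_tiles = tiling[0:3], tiling[3:6], tiling[6:9]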