1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
@@ -563,6 +563,7 @@ class Attention(BaseModel):
False,
description="Whether to use the Tokamax library for GMM kernel implementation.",
)
tokamax_gmm_autotune: bool = Field(False, description="Whether to use the Tokamax auto-tuner for GMM.")
ragged_block_size: int = Field(256, description="Block size for ragged attention.")
enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
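The new tokamax_gmm_autotune flag defaults to False, so existing configs keep the plain "mosaic" ragged-dot path unchanged. A minimal sketch of the switch it introduces (the attribute path follows the moe.py hunk below; build_autotuned_gmm_op is a hypothetical helper standing in for the op construction shown there):

if config.tokamax_gmm_autotune:
    # Autotuned path: pass a pre-configured PallasMosaicTpuRaggedDot op (sketch only).
    implementation = (build_autotuned_gmm_op(tiling),)
else:
    # Default path: let tokamax pick its own mosaic kernel configuration.
    implementation = "mosaic"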
61 changes: 52 additions & 9 deletions src/maxtext/layers/moe.py
@@ -19,6 +19,7 @@
import functools
import math
import random
import os
from typing import Iterable, Optional, Tuple, Union

from aqt.jax.v2 import aqt_tensor as aqt
@@ -43,6 +44,14 @@
from qwix.contrib.sparsity import sparsity_module
import qwix.pallas as qpl
import tokamax
from tokamax import config as tokamax_config
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config
from tokamax._src.autotuning.autotuner import AutotuningData
from tokamax._src.benchmarking import BenchmarkData
from tokamax._src.autotuning.api import AutotuningResult
import immutabledict
from tokamax._src.ops.ragged_dot import base

set_xla_metadata = xla_metadata.set_xla_metadata

@@ -1080,9 +1089,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
elif self.config.attention == "vllm_rpa":
return group_sizes
else:
ep = self.get_expert_parallelism_size()
return tokamax.RaggedDotGroupSizes(
group_sizes,
(inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
(inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0],
)

def get_quantization_dtypes():
@@ -1121,14 +1131,47 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
weight_gather_axes=weight_gather_axes,
)
else: # tokamax (unquantized)
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
)
if self.config.tokamax_gmm_autotune:
# 1. Create configs from the tiling flags for the backward pass
dlhs_config = Config(
tile_m=tiling[3],
tile_k=tiling[4],
tile_n=tiling[5],
)
drhs_config = Config(
tile_m=tiling[6],
tile_k=tiling[7],
tile_n=tiling[8],
)

# 2. Create custom ops for the backward pass
dlhs_op = PallasMosaicTpuRaggedDot(config=dlhs_config)
drhs_op = PallasMosaicTpuRaggedDot(config=drhs_config)

# 3. Create custom vjp function
custom_vjp = functools.partial(base.vjp, dlhs_ragged_dot=dlhs_op, drhs_ragged_dot=drhs_op)

# 4. Create the forward op with the custom vjp and tiling from flags
fwd_config = Config(tile_m=tiling[0], tile_k=tiling[1], tile_n=tiling[2])
fwd_op = PallasMosaicTpuRaggedDot(config=fwd_config, vjp=custom_vjp)

output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation=(fwd_op,),  # Pass the pre-configured op as the implementation.
)
else:
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax_group_sizes,
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=self.dtype,
implementation="mosaic",
)
elif self.config.megablox: # Older forked megablox
output = mblx.gmm(
lhs=inputs,
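The autotune branch above assumes tiling carries nine values, three (m, k, n) tile sizes per kernel: forward, dlhs (gradient w.r.t. the inputs), and drhs (gradient w.r.t. the weights). A minimal sketch of that assumed layout, using a hypothetical split_tiling helper; Config is the Pallas Mosaic TPU config class imported above:

# Assumed layout:
#   tiling = (fwd_m, fwd_k, fwd_n, dlhs_m, dlhs_k, dlhs_n, drhs_m, drhs_k, drhs_n)
def split_tiling(tiling):
    fwd_cfg = Config(tile_m=tiling[0], tile_k=tiling[1], tile_n=tiling[2])   # forward GMM tiles
    dlhs_cfg = Config(tile_m=tiling[3], tile_k=tiling[4], tile_n=tiling[5])  # grad w.r.t. inputs
    drhs_cfg = Config(tile_m=tiling[6], tile_k=tiling[7], tile_n=tiling[8])  # grad w.r.t. weights
    return fwd_cfg, dlhs_cfg, drhs_cfg

Note that in the hunk above the backward configs reach the kernel through the custom vjp attached to the forward op, while the forward op itself is passed as the one-element implementation tuple.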
29 changes: 21 additions & 8 deletions src/maxtext/models/deepseek_batchsplit_fp8.py
@@ -29,6 +29,8 @@
from maxtext.layers import quantizations
import qwix.pallas as qpl
import tokamax
from tokamax import config as tokamax_config
from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config


@functools.partial(
@@ -962,14 +964,25 @@ def gmm(
qwix_rule=quantizations.get_fp8_full_qwix_rule_w_sparsity(config)[0],
)
else:
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
)
if config.tokamax_gmm_autotune:
with tokamax_config.autotuning_cache_miss_fallback("autotune"):
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
)
else:
output = tokamax.ragged_dot(
lhs=inputs,
rhs=kernel,
group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
precision=jax.lax.Precision.DEFAULT,
preferred_element_type=preferred_element_type,
implementation="mosaic",
)
return output

gmm_fn = functools.partial(gmm, group_sizes=group_sizes, preferred_element_type=config.dtype)
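The two branches in the hunk above are identical apart from the autotuning-cache-miss fallback context. A possible simplification, sketched here as a suggestion rather than what the PR implements, keeps a single ragged_dot call and uses contextlib.nullcontext when autotuning is disabled:

import contextlib

autotune_ctx = (
    tokamax_config.autotuning_cache_miss_fallback("autotune")
    if config.tokamax_gmm_autotune
    else contextlib.nullcontext()
)
with autotune_ctx:
    output = tokamax.ragged_dot(
        lhs=inputs,
        rhs=kernel,
        group_sizes=tokamax.RaggedDotGroupSizes(group_sizes, len(inputs)),
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=preferred_element_type,
        implementation="mosaic",
    )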