Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/maxtext/integration/tunix/tunix_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __call__(
attention_mask: Optional[Array], # [B, L, L] or None
decoder_segment_ids: Optional[Array] = None,
output_hidden_states: bool = False, # ignored
forced_routed_experts: Optional[Array] = None,
) -> Tuple[Array, None]:
"""Forward compatible with Tunix Trainers default loss.
Returns logits, None.
Expand All @@ -67,6 +68,7 @@ def __call__(
decoder_input_tokens=input_tokens,
decoder_positions=positions,
decoder_segment_ids=decoder_segment_ids,
forced_routed_experts=forced_routed_experts,
)
return logits, None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,4 +323,4 @@ def load_weights(self, rng_key: jax.Array) -> None:
model = model_creation_utils.from_pretrained(
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
)
self.model = nnx.data(model)
self.model = nnx.data(model)
55 changes: 42 additions & 13 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,15 @@ def __call__(
kv_caches: list[jax.Array] | None = None,
attention_metadata=None,
deepstack_visual_embeds: None | list[jnp.ndarray] = None,
forced_routed_experts: jnp.ndarray | None = None,
):
cfg = self.config
mesh = self.mesh
assert decoder_input_tokens.ndim == 2 # [batch, len]

if cfg.scan_layers and forced_routed_experts is not None:
raise NotImplementedError("Forced routing with scanned layers is not supported yet.")

# [batch, length] -> [batch, length, emb_dim]
y = self._apply_embedding(
shared_embedding,
Expand Down Expand Up @@ -1061,6 +1065,10 @@ def __call__(
global_layer_idx = global_layer_idx_offset + index
kv_cache = kv_caches[index] if kv_caches is not None else None
input_tokens = decoder_input_tokens if cfg.engram_layers else None
current_forced_routed_experts = None
if forced_routed_experts is not None and layer_prefix == "moe_layers":
current_forced_routed_experts = forced_routed_experts[:, :, index, :]

y, kv_cache = layer(
config=cfg,
mesh=mesh,
Expand All @@ -1080,11 +1088,13 @@ def __call__(
kv_cache=kv_cache,
attention_metadata=attention_metadata,
decoder_input_tokens=input_tokens,
forced_routed_experts=current_forced_routed_experts,
)
if kv_caches is not None and kv_cache is not None:
kv_caches[index] = kv_cache
global_layer_idx_offset += num_layers
else:
moe_lyr_idx = 0
for lyr in range(cfg.num_decoder_layers):
RemattedBlockLayer = RemattedBlockLayers[0]
layer_kwargs = {}
Expand Down Expand Up @@ -1121,19 +1131,38 @@ def __call__(
layer = RemattedBlockLayer(
config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs
)
y, returned_cache = layer(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
kv_cache=kv_cache,
attention_metadata=attention_metadata,
**layer_call_kwargs,
)
current_forced_routed_experts = None
is_moe = False
if cfg.decoder_block in (
DecoderBlockType.MIXTRAL,
DecoderBlockType.QWEN3_MOE,
DecoderBlockType.QWEN3_NEXT,
DecoderBlockType.QWEN3_5,
DecoderBlockType.QWEN3_CUSTOM_MOE,
Comment on lines +1136 to +1141
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about DeepSeek or Kimi, which have both MoE and dense layers?

Also, Gemma4 has two models, MoE and dense, but both are handled by the same decoder block type, DecoderBlockType.GEMMA4.

For Gemma4, you can use the condition: DecoderBlockType.GEMMA4 and num_experts > 1

):
is_moe = True
elif cfg.decoder_block == DecoderBlockType.LLAMA4:
is_moe = llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step)

if is_moe and forced_routed_experts is not None:
current_forced_routed_experts = forced_routed_experts[:, :, moe_lyr_idx, :]
moe_lyr_idx += 1
elif is_moe:
moe_lyr_idx += 1

call_kwargs = {
"previous_chunk": previous_chunk,
"page_state": page_state,
"slot": slot,
"kv_cache": kv_cache,
"attention_metadata": attention_metadata,
}
call_kwargs.update(layer_call_kwargs)

if is_moe and current_forced_routed_experts is not None:
call_kwargs["forced_routed_experts"] = current_forced_routed_experts

y, returned_cache = layer(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, **call_kwargs)
if kv_caches is not None and returned_cache is not None:
if cfg.decoder_block not in (DecoderBlockType.QWEN3_NEXT, DecoderBlockType.QWEN3_5):
kv_caches[lyr] = returned_cache
Expand Down
Loading
Loading