Skip to content

Commit 80ebfcb

Browse files
Fix unit test
1 parent 46dd5a6 commit 80ebfcb

File tree

5 files changed

+116
-56
lines changed

5 files changed

+116
-56
lines changed

src/maxtext/layers/initializers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
9494
out_sharding = metadata["sharding"]
9595

9696
if out_sharding is not None:
97+
if nnx.PARTITION_NAME in metadata:
98+
partition_name = metadata[nnx.PARTITION_NAME]
99+
scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0
100+
101+
sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
102+
if partition_name not in sharding_list:
103+
sharding_list.insert(scan_axis, partition_name)
104+
105+
out_sharding = tuple(sharding_list)
106+
97107
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
98108
variable.value,
99109
out_sharding, # type: ignore[arg-type]

src/maxtext/layers/nnx_decoders.py

Lines changed: 102 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171

7272
class NNXDecoderLayer(nnx.Module):
7373
"""
74-
Transformer decoder layer converted to NNX.
74+
Transformer decoder layer converted to NNX
7575
"""
7676

7777
def __init__(
@@ -307,11 +307,10 @@ def __init__(
307307
dense_cls, moe_cls = decoder_block_classes
308308

309309
num_dense = config.first_num_dense_layers
310-
self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs)
311-
310+
self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, metadata_axis_name="dense_layers", rngs=rngs)
312311
num_moe = config.num_decoder_layers - config.first_num_dense_layers
313-
314-
self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs)
312+
self.moe_layers = self._create_scanned_layers(moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs)
313+
315314
elif self.is_gemma3:
316315
attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN)
317316
scan_length = config.num_decoder_layers // attention_pattern_length
@@ -323,7 +322,9 @@ def __init__(
323322
RemattedGemma3Block = gemma3.Gemma3ScannableBlock
324323

325324
if scan_length > 0:
326-
self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs)
325+
self.layers = self._create_scanned_layers(
326+
RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs
327+
)
327328
self.layers_remainder = RemattedGemma3Block(
328329
config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs
329330
) # pytype: disable=wrong-keyword-args
@@ -338,7 +339,9 @@ def __init__(
338339
}
339340

340341
if num_layers > 0:
341-
self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs)
342+
self.layers = self._create_scanned_layers(
343+
layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs
344+
)
342345
else:
343346
self.layers = nnx.List([])
344347

@@ -390,34 +393,80 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs):
390393
)
391394
return nnx_wrappers.ToNNX(layer_linen, rngs=rngs)
392395

393-
def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs):
394-
"""Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
396+
def _create_scanned_layers(self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs):
397+
"""Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization.
395398
396-
def create_layer_fn(rng):
397-
layer = decoder_layer_class(
398-
config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs
399-
)
400-
401-
return layer
399+
Uses jax.lax.scan instead of nnx.vmap to reduce peak memory during initialization.
400+
With vmap, all layers' parameters are created simultaneously (O(N) peak memory).
401+
With scan, parameters are created one layer at a time (O(1) peak intermediate memory),
402+
which prevents OOM on memory-constrained devices like TPU v6e-4.
403+
"""
404+
scan_axis = self.config.param_scan_axis
402405

403-
# Workaround for Deepseek MTP test failure.
404-
# TODO: Handle this properly.
406+
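# Fork rngs to get per-layer RNG states for scanning.
# NOTE(review): if rngs.fork() raises, the bare except below swallows the error and leaves
# forked_rngs unbound, so the later nnx.split(forked_rngs) fails with a NameError instead.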
405407
try:
406408
forked_rngs = rngs.fork(split=length)
407-
408409
except: # pylint: disable=bare-except
409410
pass
410411

411-
out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0})
412-
layers_vmapped = nnx.vmap(
413-
create_layer_fn,
414-
in_axes=0,
415-
out_axes=out_axes,
416-
axis_name="layers",
417-
transform_metadata={nnx.PARTITION_NAME: "layers"},
418-
)(forked_rngs)
412+
rngs_graphdef, rngs_state = nnx.split(forked_rngs)
413+
414+
# Create a reference layer to capture the module graph structure (graphdef).
415+
# This layer's params are discarded — only the structure is kept.
416+
# Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the
417+
# graphdef has the same number of RNG state leaves as the scan-created layers.
418+
first_rng_state = jax.tree.map(lambda x: x[0], rngs_state)
419+
ref_rngs = nnx.merge(rngs_graphdef, first_rng_state)
420+
ref_layer = decoder_layer_class(
421+
config=self.config, mesh=self.mesh, quant=self.quant,
422+
model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs
423+
)
424+
layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...)
425+
del ref_layer
426+
427+
# Sequentially create each layer's parameters via jax.lax.scan.
428+
# The scan body is traced once; XLA executes it N times with different RNG keys,
429+
# keeping only one layer's intermediate state alive at a time.
430+
def scan_body(carry, rng_state_slice):
431+
layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice)
432+
layer = decoder_layer_class(
433+
config=self.config, mesh=self.mesh, quant=self.quant,
434+
model_mode=self.model_mode, rngs=layer_rngs, **layer_kwargs
435+
)
436+
_, params, rest = nnx.split(layer, nnx.Param, ...)
437+
return carry, (params, rest)
438+
439+
_, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state)
419440

420-
return layers_vmapped
441+
# jax.lax.scan stacks outputs along axis 0. Move params to the configured scan axis.
442+
if scan_axis != 0:
443+
stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params)
444+
445+
# Add partition metadata that nnx.vmap's transform_metadata would normally set.
446+
# This metadata is read by variable_to_logically_partitioned() in initializers.py
447+
# and by nnx.get_partition_spec() (via the updated out_sharding) to produce
448+
# correct sharding specs that include the scan axis dimension.
449+
def _add_scan_metadata(state, axis):
450+
def _update_leaf(leaf):
451+
if isinstance(leaf, nnx.VariableState):
452+
metadata = leaf.get_metadata()
453+
metadata[nnx.PARTITION_NAME] = metadata_axis_name
454+
metadata["param_scan_axis"] = axis
455+
# Insert the scan axis name into out_sharding so that
456+
# nnx.get_partition_spec returns specs matching the actual tensor rank.
457+
# Without this, the stacked params gain an extra scan dimension while their specs keep the original (unstacked) rank.
458+
if "out_sharding" in metadata and metadata["out_sharding"]:
459+
sharding = list(metadata["out_sharding"])
460+
sharding.insert(axis, metadata_axis_name)
461+
metadata["out_sharding"] = tuple(sharding)
462+
return leaf.replace(**metadata)
463+
return leaf
464+
return jax.tree.map(_update_leaf, state, is_leaf=lambda x: isinstance(x, nnx.VariableState))
465+
466+
stacked_params = _add_scan_metadata(stacked_params, scan_axis)
467+
stacked_rest = _add_scan_metadata(stacked_rest, 0)
468+
469+
return nnx.merge(layer_graphdef, stacked_params, stacked_rest)
421470

422471
def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs):
423472
"""Helper to cleanly apply jax.checkpoint to a single unscanned layer or block."""
@@ -439,9 +488,7 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
439488
"""Runs the layer stack using jax.lax.scan (wrapped in jax.checkpoint for remat)."""
440489
policy = self.get_remat_policy()
441490
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
442-
graphdef, params, state = nnx.split(
443-
layers, nnx.Param, ...
444-
) # state: the mutable state we carry (KV cache, RNGs, etc.)
491+
graphdef, params, state = nnx.split(layers, nnx.Param, ...)
445492

446493
scan_axis = self.config.param_scan_axis
447494
if scan_axis != 0:
@@ -451,6 +498,13 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
451498
sig = inspect.signature(layer_cls.__call__)
452499
valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters}
453500

501+
def _extract_matching_state(template, full):
502+
if isinstance(template, nnx.State):
503+
return nnx.State({k: _extract_matching_state(v, full[k]) for k, v in template.items()})
504+
elif isinstance(template, dict):
505+
return {k: _extract_matching_state(v, full[k]) for k, v in template.items()}
506+
return full
507+
454508
def layer_fn(carry, scanned_vars):
455509
current_params, current_state = scanned_vars
456510

@@ -460,20 +514,28 @@ def layer_fn(carry, scanned_vars):
460514
layer = nnx.merge(graphdef, current_params, current_state)
461515
layer_out = layer(carry, *args, **valid_kwargs)
462516
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out
463-
new_current_state = nnx.state(layer)
464-
517+
518+
new_full_state = nnx.state(layer)
519+
new_current_state = _extract_matching_state(current_state, new_full_state)
520+
521+
# Return only the non-param state: emitting params as scan outputs would stack a second copy of the weights.
465522
return new_carry, new_current_state
466523

467524
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
468525

469-
final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state))
526+
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
470527

471528
if scan_axis != 0:
472-
scanned_params, scanned_other = scanned_state.split(nnx.Param, ...)
473-
scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)
474-
scanned_state = nnx.State.merge(scanned_params, scanned_other)
475-
476-
return final_carry, nnx.merge(graphdef, scanned_state)
529+
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)
530+
531+
scanned_state = nnx.State.merge(params, scanned_other)
532+
# Update the existing module in-place rather than creating a new one.
533+
# Creating a new module via nnx.merge and reassigning (self.layers = new_module)
534+
# would replace a child node in the NNX graph, which is detected as a graph
535+
# structure mutation when the parent module is inside a JAX transformation
536+
# (e.g., nnx.jit in PeftTrainer). In-place update preserves object identity.
537+
nnx.update(layers, scanned_state)
538+
return final_carry, layers
477539

478540
def get_decoder_layers(self):
479541
"""Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -1159,7 +1221,7 @@ def decoder_as_linen(
11591221
model_mode: str,
11601222
quant: None | Quant = None,
11611223
):
1162-
"""Creates a Decoder module."""
1224+
"""Creates a Decoder module"""
11631225
module = nnx_wrappers.to_linen(
11641226
NNXDecoder,
11651227
config=config,
@@ -1171,4 +1233,4 @@ def decoder_as_linen(
11711233
abstract_init=False,
11721234
metadata_fn=initializers.variable_to_logically_partitioned,
11731235
)
1174-
return module
1236+
return module

src/maxtext/layers/quantizations.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from aqt.jax.v2 import tiled_dot_general
2727
from aqt.jax.v2 import calibration
2828

29+
from maxtext.layers import nnx_wrappers
2930
import qwix
3031
from qwix._src.core import dot_general_qt
3132

@@ -285,7 +286,7 @@ class Fp8Quantization(Quantization):
285286

286287
def dot_general_cls(self, mesh_axes: Tuple[str, ...] = ()):
287288
"""Returns dot_general configured with aqt params."""
288-
return nn.Fp8DirectDotGeneralOp
289+
return nnx_wrappers.ToNNX(nn.Fp8DirectDotGeneralOp)
289290

290291
def einsum(self, dtype: DType = jnp.float32):
291292
return _Fp8EinsumWrapper(dtype=dtype)

src/maxtext/models/models.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from maxtext.layers.decoders import Decoder
3434
from maxtext.layers.embeddings import Embed, embed_as_linen
3535
from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen
36-
from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen
36+
from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen
3737
from maxtext.layers.quantizations import AqtQuantization as Quant
3838
from maxtext.multimodal import processor as mm_processor
3939
from maxtext.utils import max_utils
@@ -376,25 +376,12 @@ def __init__(
376376
# For MTP, we use the DecoderLayer blueprint to ensure architectural consistency.
377377
# By convention, this is the last layer in the list.
378378
mtp_layer = layer_types[-1]
379-
mtp_block_linen = multi_token_prediction_block_as_linen(
379+
self.mtp_block = MultiTokenPredictionBlock(
380380
config=self.config,
381381
mesh=self.mesh,
382382
transformer_layer_module=mtp_layer,
383383
decoder=self.decoder,
384384
rngs=rngs,
385-
name="mtp_block",
386-
)
387-
self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs)
388-
389-
self.mtp_block.lazy_init(
390-
shared_embedding=self.token_embedder,
391-
main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype),
392-
input_ids=jnp.ones((1, 1), dtype=jnp.int32),
393-
target_ids=jnp.ones((1, 1), dtype=jnp.int32),
394-
target_mask=jnp.ones((1, 1), dtype=jnp.int32),
395-
position_ids=jnp.ones((1, 1), dtype=jnp.int32),
396-
decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32),
397-
deterministic=True,
398385
)
399386

400387
def no_op(self, *args, **kwargs):

0 commit comments

Comments
 (0)