Skip to content

Commit 26a122a

Browse files
committed
fix: update
1 parent 8b7313a commit 26a122a

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,10 @@ def _create_scanned_layers(self, decoder_layer_class, length: int, metadata_axis
413413

414414
# Create a reference layer to capture the module graph structure (graphdef).
415415
# This layer's params are discarded — only the structure is kept.
416-
ref_rngs = nnx.Rngs(0)
416+
# Must use the first slice of the forked rngs (not a dummy Rngs(0)) so the
417+
# graphdef has the same number of RNG state leaves as the scan-created layers.
418+
first_rng_state = jax.tree.map(lambda x: x[0], rngs_state)
419+
ref_rngs = nnx.merge(rngs_graphdef, first_rng_state)
417420
ref_layer = decoder_layer_class(
418421
config=self.config, mesh=self.mesh, quant=self.quant,
419422
model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs

src/maxtext/layers/nnx_wrappers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,6 @@ def _module_kwargs():
435435
_fix_for_qwix_quantization(module)
436436
method_fn = _get_module_method(module, nnx_method)
437437
out = method_fn(module, *args, **kwargs)
438-
# Free the NNX module eagerly to avoid holding both NNX params and
439-
# Linen variable copies in memory simultaneously during init tracing.
440-
del module
441438
return out
442439

443440
# create the nnx module

0 commit comments

Comments
 (0)