Skip to content

Commit 618de58

Browse files
committed
fix: replace nnx.stateaxes
1 parent 34063e4 commit 618de58

2 files changed

Lines changed: 23 additions & 11 deletions

File tree

src/maxtext/layers/nnx_pipeline.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,16 @@ def build_batched_rngs(shape):
8484

8585
return nnx.Rngs(**kwargs)
8686

87+
# Define a wrapper that explicitly splits the instantiated layer
88+
def create_layer_fn(rngs):
89+
layer = stage_factory(rngs)
90+
return nnx.split(layer, nnx.Param, ...)
91+
8792
# Vmap over stages natively adds 'layers' metadata to the logical partition spec!
8893
vmap_stages = nnx.vmap(
89-
stage_factory,
94+
create_layer_fn,
9095
in_axes=0,
91-
out_axes=nnx.StateAxes({nnx.Param: 0, ...: 0}),
96+
out_axes=(None, 0, 0), # None for static graphdef, 0 for params, 0 for rest
9297
axis_name=self.spmd_axis_name,
9398
transform_metadata={nnx.PARTITION_NAME: "layers"},
9499
)
@@ -97,14 +102,16 @@ def build_batched_rngs(shape):
97102
vmap_repeats = nnx.vmap(
98103
vmap_stages,
99104
in_axes=0,
100-
out_axes=nnx.StateAxes({nnx.Param: 0, ...: 0}),
105+
out_axes=(None, 0, 0), # Graphdef remains unbatched (None) through the second vmap
101106
transform_metadata={nnx.PARTITION_NAME: "circular_repeats"},
102107
)
103108
batched_rngs = build_batched_rngs((self.config.num_pipeline_repeats, self.num_stages))
104-
self.layers = vmap_repeats(batched_rngs)
109+
graphdef, params, rest = vmap_repeats(batched_rngs)
110+
self.layers = nnx.merge(graphdef, params, rest)
105111
else:
106112
batched_rngs = build_batched_rngs((self.num_stages,))
107-
self.layers = vmap_stages(batched_rngs)
113+
graphdef, params, rest = vmap_stages(batched_rngs)
114+
self.layers = nnx.merge(graphdef, params, rest)
108115

109116
# -------------------------------------------------------------------------
110117
# Sharding Configs

src/maxtext/layers/pipeline.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,11 +1017,14 @@ def build_batched_rngs(shape):
10171017

10181018
return nnx.Rngs(**kwargs)
10191019

1020-
# Vmap over stages natively adds 'layers' metadata to the logical partition spec!
1020+
def create_layer_fn(rngs):
1021+
layer = stage_factory(rngs)
1022+
return nnx.split(layer, nnx.Param, ...)
1023+
10211024
vmap_stages = nnx.vmap(
1022-
stage_factory,
1025+
create_layer_fn,
10231026
in_axes=0,
1024-
out_axes=nnx.StateAxes({nnx.Param: 0, ...: 0}),
1027+
out_axes=(None, 0, 0),
10251028
axis_name=self.spmd_axis_name,
10261029
transform_metadata={nnx.PARTITION_NAME: "layers"},
10271030
)
@@ -1030,14 +1033,16 @@ def build_batched_rngs(shape):
10301033
vmap_repeats = nnx.vmap(
10311034
vmap_stages,
10321035
in_axes=0,
1033-
out_axes=nnx.StateAxes({nnx.Param: 0, ...: 0}),
1036+
out_axes=(None, 0, 0),
10341037
transform_metadata={nnx.PARTITION_NAME: "circular_repeats"},
10351038
)
10361039
batched_rngs = build_batched_rngs((self.config.num_pipeline_repeats, self.num_stages))
1037-
self.layers = vmap_repeats(batched_rngs)
1040+
graphdef, params, rest = vmap_repeats(batched_rngs)
1041+
self.layers = nnx.merge(graphdef, params, rest)
10381042
else:
10391043
batched_rngs = build_batched_rngs((self.num_stages,))
1040-
self.layers = vmap_stages(batched_rngs)
1044+
graphdef, params, rest = vmap_stages(batched_rngs)
1045+
self.layers = nnx.merge(graphdef, params, rest)
10411046

10421047
def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None):
10431048
"""get current repeat from stages"""

0 commit comments

Comments
 (0)