Skip to content

Commit 0d02ff3

Browse files
committed
NNX migration: NNX utils
- Add utils to manipulate the NNX shardings with abstract state of a model - also add unit tests for the utils - Extract mesh creation function to maxtext_utils.get_mesh_from_config() - also add unit tests for this func Note: flax v0.12 has DeprecationWarning in multiple places: - DeprecationWarning: '.value' access is now deprecated. Use variable.get_value() or variable[...] (for [Array]). - DeprecationWarning: 'VariableState' was removed, this is just an alias to 'Variable'. Plase use 'Variable' directly instead. But since the code needs to work with post-training, which currently requires flax v0.11, we didn't change code for these warnings.
1 parent efa44ad commit 0d02ff3

5 files changed

Lines changed: 386 additions & 23 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import functools
1919
import pickle
20+
from typing import Sequence
2021

2122
from flax import linen as nn
2223
from flax.linen import partitioning as nn_partitioning
@@ -26,6 +27,7 @@
2627

2728
from jax.experimental import mesh_utils
2829
from jax.experimental.serialize_executable import deserialize_and_load
30+
from jax.sharding import AxisType, Mesh
2931

3032
import jax
3133
import jax.numpy as jnp
@@ -39,8 +41,9 @@
3941
from MaxText import max_logging
4042
from MaxText import max_utils
4143
from MaxText import multimodal_utils
44+
from MaxText import pyconfig
4245
from MaxText import sharding
43-
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
46+
from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode
4447
from MaxText.inference.page_manager import PageState
4548

4649
OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient"
@@ -1178,3 +1181,27 @@ def print_state_mesh_shardings_params(state, state_sharding, mesh):
11781181
shape = jax.typeof(leaf_val)
11791182
pspec = sharding.remove_size_one_mesh_axis(leaf_sharding.spec, mesh)
11801183
max_logging.log(f"{path_str:.<80} {shape} {tuple(pspec)}")
1184+
1185+
1186+
def get_mesh_from_config(
    config: pyconfig.HyperParameters,
    devices: Sequence[jax.Device] | None = None,
) -> Mesh:
  """
  Get the device mesh from the configuration.

  Args:
    config: the configuration
    devices: the devices to build the mesh from; defaults to all available devices

  Returns:
    the device mesh
  """
  devices_array = create_device_mesh(config, devices)

  # Every mesh axis gets the same axis type: Explicit sharding propagates
  # shardings through jax types; Auto lets the compiler decide.
  axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
  axis_types = (axis_type,) * len(config.mesh_axes)

  return Mesh(devices_array, config.mesh_axes, axis_types=axis_types)

src/MaxText/maxtext_utils_nnx.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
""" Utils for MaxText NNX. """
15+
16+
from functools import partial
17+
from typing import Any, Callable
18+
19+
from flax import nnx
20+
import jax
21+
from jax.sharding import Mesh, NamedSharding
22+
23+
from MaxText import max_logging
24+
from MaxText import pyconfig
25+
26+
27+
def create_nnx_rngs(
    config: pyconfig.HyperParameters, is_training: bool = True, rng_key: jax.Array | None = None
) -> nnx.Rngs:
  """
  Create NNX Rngs.

  Args:
    config: the configuration
    is_training: if the Rngs are for training
    rng_key: the Rng key

  Returns:
    The NNX Rngs
  """
  base_key = jax.random.PRNGKey(config.init_weights_seed) if rng_key is None else rng_key

  if not is_training:
    # disable dropout RNG and aqt for inference
    return nnx.Rngs(params=base_key)

  # Derive independent streams for params, dropout and aqt from the base key.
  params_key, dropout_key, aqt_key = (jax.random.fold_in(base_key, i) for i in range(3))
  return nnx.Rngs(params=params_key, dropout=dropout_key, aqt=aqt_key)
49+
50+
51+
def get_named_sharding_nnx(abstract_state: Any) -> Any:
  """Get named sharding from NNX abstract state.

  Args:
    abstract_state: NNX model abstract state created from nnx.get_abstract_model.

  Returns:
    named sharding structure
  """

  def _is_abstract_leaf(node: Any) -> bool:
    # Treat the abstract array descriptors themselves as tree leaves.
    return isinstance(node, jax.ShapeDtypeStruct)

  def _leaf_sharding(leaf: jax.ShapeDtypeStruct) -> Any:
    return leaf.sharding

  # Don't use nnx.get_named_sharding() because it constructs new shardings.
  # Instead, read the sharding already attached to each existing
  # jax.ShapeDtypeStruct(shape, dtype, sharding) leaf of the abstract state.
  return jax.tree.map(_leaf_sharding, abstract_state, is_leaf=_is_abstract_leaf)
68+
69+
70+
def set_named_sharding_nnx(abstract_state: Any, named_sharding: Any) -> Any:
  """Set named sharding to NNX abstract state.

  Args:
    abstract_state: NNX model abstract state created from nnx.get_abstract_model().
    named_sharding: named sharding. It must have the same tree structure with abstract_state.

  Returns:
    updated abstract_state
  """

  def _with_sharding(leaf: jax.ShapeDtypeStruct, new_sharding: Any) -> jax.ShapeDtypeStruct:
    # Rebuild each leaf so its shape/dtype are preserved and only the sharding changes.
    return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=new_sharding)

  return jax.tree.map(_with_sharding, abstract_state, named_sharding)
81+
82+
83+
def move_memory_to_host(path: tuple[str, ...], x: NamedSharding) -> NamedSharding:
  """
  Change the memory_kind of the NamedSharding to "pinned_host". This function can be
  called by jax.tree_util.tree_map_with_path on a NNX state structure.

  Args:
    path: the tree path tuple
    x: the NamedSharding corresponding to the path

  Returns:
    the NamedSharding with memory_kind set to "pinned_host"
  """
  # Log with the correct module name: this helper lives in maxtext_utils_nnx.py,
  # not max_utils.py (the original message misattributed the source file).
  max_logging.log(f"maxtext_utils_nnx.py: Moving {path} to host")
  # Create the new sharding with the target memory kind
  return x.with_memory_kind(kind="pinned_host")
98+
99+
100+
def create_nnx_sharded_model(
    abstract_model: nnx.Module,
    init_fn: Callable,
    mesh: Mesh | None = None,
    named_sharding: Any | None = None,
) -> nnx.Module:
  """
  Create the model with the given sharding.

  Initializes the model via ``init_fn`` inside a jitted function whose
  ``out_shardings`` match the target sharding, so parameters are materialized
  sharded from the start.

  Args:
    abstract_model: the abstract model; when ``mesh`` is None it must expose a
      ``.mesh`` attribute
    init_fn: the model init function; called with no arguments and must return
      an initialized model instance
    mesh: the device mesh
    named_sharding: the given sharding; when None it is read from the abstract
      model's state leaves

  Returns:
    The initialized sharded model
  """
  # Separate static structure (graphdef) from the abstract state so the
  # initialized state can be merged back into the same structure at the end.
  graphdef, abstract_state = nnx.split(abstract_model)
  if named_sharding is None:
    # The state leaf is of type jax.ShapeDtypeStruct(shape, dtype, sharding)
    # we get the sharding directly from it.
    named_sharding = get_named_sharding_nnx(abstract_state)

  if mesh is None:
    mesh = abstract_model.mesh

  # JIT a function that creates the model state with proper sharding from the start.
  # By providing out_shardings, we instruct JAX to produce sharded output directly,
  # avoiding a large intermediate allocation on a single device.
  @partial(jax.jit, out_shardings=named_sharding)
  def create_sharded_state():
    model = init_fn()
    # Constrain the in-flight state too, so intermediates inside the trace
    # follow the target sharding rather than only the jit outputs.
    return jax.lax.with_sharding_constraint(nnx.state(model), named_sharding)

  # Create the model with sharded parameters.
  # set_mesh makes the mesh current for the jitted initialization call.
  with jax.set_mesh(mesh):
    sharded_state = create_sharded_state()
  return nnx.merge(graphdef, sharded_state)

src/MaxText/model_creation_utils.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@
2121
from flax import nnx
2222
import flax.linen as nn
2323
import jax
24-
from jax.sharding import Mesh, AxisType
24+
from jax.sharding import Mesh
2525
from MaxText import maxtext_utils
26+
from MaxText import maxtext_utils_nnx
2627
from MaxText import pyconfig
2728
from MaxText.layers import quantizations
28-
from MaxText.common_types import MODEL_MODE_TRAIN, ShardMode
29+
from MaxText.common_types import MODEL_MODE_TRAIN
2930
from MaxText.layers import models
3031
from orbax import checkpoint as ocp
3132
from functools import partial
@@ -39,6 +40,7 @@ def from_config(
3940
mesh: Mesh | None = None,
4041
*,
4142
model_mode: str = MODEL_MODE_TRAIN,
43+
rngs: None = None,
4244
) -> nn.Module:
4345
...
4446

@@ -79,15 +81,7 @@ def from_config(
7981
model = from_config(config)
8082
"""
8183
if mesh is None:
82-
devices_array = maxtext_utils.create_device_mesh(config, devices)
83-
84-
if config.shard_mode == ShardMode.EXPLICIT:
85-
axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
86-
else:
87-
axis_types = tuple([AxisType.Auto] * len(config.mesh_axes))
88-
89-
mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types)
90-
84+
mesh = maxtext_utils.get_mesh_from_config(config, devices)
9185
model = create_model(config, mesh, model_mode=model_mode, rngs=rngs)
9286

9387
# Return only the model
@@ -113,16 +107,10 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
113107

114108
def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
115109
"""Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
110+
is_training = model_mode == MODEL_MODE_TRAIN
116111

117112
def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
118-
if rng_key is None:
119-
rng_key = jax.random.PRNGKey(config.init_weights_seed)
120-
121-
if model_mode == MODEL_MODE_TRAIN:
122-
rngs = nnx.Rngs(params=rng_key, dropout=1)
123-
else:
124-
rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference
125-
113+
rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key)
126114
return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
127115

128116
_create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key)
@@ -135,6 +123,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
135123
if mesh is None:
136124
mesh = abstract_model.mesh
137125

126+
# Note for pure_nnx:
127+
# Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and
128+
# we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen
129+
# LogicallyPartitioned structure.
130+
# In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned
131+
# structure in the abstract state and we can get the sharded state with the following code:
132+
# graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh)
133+
# abstract_model = nnx.merge(graphdef, state)
134+
# model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh)
135+
# sharded_state = nnx.state(model)
136+
138137
# JIT a function that creates the model state with proper sharding from the start.
139138
# By providing out_shardings, we instruct JAX to produce sharded output directly,
140139
# avoiding a large intermediate allocation on a single device.

0 commit comments

Comments
 (0)