
Commit 669dc01

[WIP] NNX: fix model and test compatibility issues
- Replace nn.Dropout with linears.Dropout in the gpt_oss and olmo3 decoder layers
- Add a num_activations logical axis rule to base.yml
- Fix integration and unit tests for NNX compatibility

I will relocate these files accordingly once the work is done.
1 parent 3694725 commit 669dc01

19 files changed

Lines changed: 90 additions & 33 deletions
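
The model-side change follows the NNX convention of building submodules eagerly in __init__ instead of instantiating them inline during __call__, as Linen allows. A minimal before/after sketch, using flax.nnx.Dropout directly on the assumption that maxtext's linears.Dropout exposes the same interface:

from flax import nnx

class DecoderLayer(nnx.Module):

  def __init__(self, dropout_rate: float, *, rngs: nnx.Rngs):
    # NNX style: the dropout module and its RNG state are created once,
    # when the layer is built, and reused on every call.
    self.dropout = nnx.Dropout(rate=dropout_rate, broadcast_dims=(-2,), rngs=rngs)

  def __call__(self, x, deterministic: bool):
    # Linen style (what this commit removes) would create the module
    # inline on each call:
    #   nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(x, deterministic=deterministic)
    return self.dropout(x, deterministic=deterministic)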

src/maxtext/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -497,6 +497,7 @@ logical_axis_rules: [
   ['paged_kv_head_dim_size', []],
   ['dense_layers', []],
   ['moe_layers', []],
+  ['num_activations', []],
   ['engram_dim', ['tensor']],
   ['mhc', []],
   ['diloco', 'diloco'],
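
Each logical_axis_rules entry pairs a logical axis name with the mesh axes it may shard over; an empty list leaves that dimension replicated, which is what the new num_activations rule does. A hedged sketch of how such rules resolve, using flax's logical partitioning helper with None standing in for the YAML's [] (MaxText is assumed to lower [] to an unsharded mapping):

from flax import linen as nn

rules = (("num_activations", None), ("engram_dim", "tensor"))
spec = nn.logical_to_mesh_axes(("num_activations", "engram_dim"), rules)
# -> PartitionSpec(None, 'tensor'): 'num_activations' stays replicated,
#    while 'engram_dim' shards across the 'tensor' mesh axis.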

src/maxtext/models/gpt_oss.py

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
 from maxtext.common.common_types import AttentionType, Config
 from maxtext.layers import attentions
 from maxtext.layers import initializers
+from maxtext.layers import linears
 from maxtext.layers import moe
 from maxtext.layers import nnx_wrappers
 from maxtext.layers import quantizations
@@ -130,6 +131,8 @@ def __init__(
         rngs=rngs,
     )
 
+    self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
+
   def __call__(
       self,
       inputs,
@@ -181,7 +184,7 @@ def __call__(
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
     layer_output = mlp_lnx + intermediate_inputs
-    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
+    layer_output = self.dropout(layer_output, deterministic=deterministic)
 
     layer_output = nn.with_logical_constraint(
         layer_output,
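
The refactor keeps broadcast_dims=(-2,), which shares one dropout mask along the sequence axis. A small sketch of the effect, again via flax.nnx.Dropout under the same interface assumption:

import jax.numpy as jnp
from flax import nnx

drop = nnx.Dropout(rate=0.5, broadcast_dims=(-2,), rngs=nnx.Rngs(0))
x = jnp.ones((2, 16, 8))  # (batch, sequence, embed)
y = drop(x, deterministic=False)
# The internal mask has shape (2, 1, 8): each (batch, embed) coordinate
# is kept or dropped uniformly across all 16 sequence positions.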

src/maxtext/models/olmo3.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@
 from maxtext.common.common_types import AttentionType, Config
 from maxtext.layers import attentions
 from maxtext.layers import initializers
+from maxtext.layers import linears
 from maxtext.layers import nnx_wrappers
 from maxtext.layers import quantizations
 from maxtext.layers.attentions import Attention
@@ -140,6 +141,8 @@ def __init__(
         rngs=rngs,
     )
 
+    self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs)
+
   def __call__(
       self,
       inputs,
@@ -193,7 +196,7 @@ def __call__(
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
     layer_output = mlp_lnx + intermediate_inputs
-    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
+    layer_output = self.dropout(layer_output, deterministic=deterministic)
 
     layer_output = nn.with_logical_constraint(
         layer_output,

src/maxtext/trainers/post_train/sft/train_sft.py

Lines changed: 4 additions & 2 deletions
@@ -47,8 +47,6 @@
 
 from orbax import checkpoint as ocp
 
-from tunix.sft import metrics_logger, peft_trainer, profiler
-
 from maxtext.configs import pyconfig
 from maxtext.trainers.pre_train.train import loss_fn
 from maxtext.common.goodput import (
@@ -74,6 +72,8 @@ def get_tunix_config(mt_config):
   Returns:
     A Tunix `TrainingConfig` object.
   """
+  from tunix.sft import metrics_logger, peft_trainer, profiler  # pylint: disable=g-import-not-at-top
+
   # Checkpointing configurations
   checkpointing_options = ocp.CheckpointManagerOptions(
       save_interval_steps=mt_config.checkpoint_period,
@@ -140,6 +140,8 @@ def loss_func(model, inputs, inputs_position, inputs_segmentation, targets, targ
 
 def setup_trainer_state(mt_config, goodput_recorder=None):
   """Set up prerequisites for training loop."""
+  from tunix.sft import peft_trainer  # pylint: disable=g-import-not-at-top
+
   tunix_config = get_tunix_config(mt_config)
 
   with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
tests/integration/aot_identical_test.py

Lines changed: 1 addition & 0 deletions
@@ -179,6 +179,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
         "enable_checkpointing=False",
         "dump_jaxpr=True",
         "dump_jaxpr_delete_local_after=False",
+        "skip_first_n_steps_for_profiler=0",
     ]
     if extra_args:
       shared_args.extend(extra_args)

tests/integration/checkpointing_test.py

Lines changed: 7 additions & 6 deletions
@@ -93,6 +93,7 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
       f"dataset_type={dataset_type}",
       "async_checkpointing=False",
       f"attention={attention_type}",
+      "profiler=''",
   ]
   + model_params
   + pathways_command
@@ -135,19 +136,19 @@ def run_checkpointing(hardware, attention_type):
   # Determine dataset path/pattern depending on decoupled mode.
   gcsfuse_pattern = "/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*"
   local_decoupled_root = os.path.join(
-      MAXTEXT_PKG_DIR, "..", "tests", "assets", "local_datasets", "c4_en_dataset_minimal", "c4", "en", "3.0.1"
+      MAXTEXT_PKG_DIR, "..", "..", "tests", "assets", "local_datasets", "c4_en_dataset_minimal", "c4", "en", "3.0.1"
   )
   local_pattern = os.path.join(local_decoupled_root, "c4-train.array_record*")
   selected_pattern = gcsfuse_pattern
   dataset_path = "/tmp/gcsfuse"
 
-  if is_decoupled():
+  if not glob.glob(gcsfuse_pattern):
     # Prefer local minimal dataset if gcsfuse data absent
-    if not glob.glob(gcsfuse_pattern) and glob.glob(local_pattern):
+    if glob.glob(local_pattern):
       selected_pattern = local_pattern
-      dataset_path = os.path.join(MAXTEXT_PKG_DIR, "..", "tests", "assets", "local_datasets")
-    elif not glob.glob(gcsfuse_pattern) and not glob.glob(local_pattern):
-      pytest.skip("No grain ArrayRecord shards found for checkpointing test in decoupled mode.")
+      dataset_path = os.path.join(MAXTEXT_PKG_DIR, "..", "..", "tests", "assets", "local_datasets")
+    else:
+      pytest.skip("No grain ArrayRecord shards found for checkpointing test.")
 
   grain_command = [
       "grain_worker_count=0",

tests/integration/decode_tests.py

Lines changed: 6 additions & 0 deletions
@@ -49,6 +49,8 @@ class DecodeTests(unittest.TestCase):
         "max_target_length=128",
         "per_device_batch_size=1",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
+        "profiler=''",
+        "pure_nnx=False",
     ],
     "int8": [  # tests decode with int8 quantization
         None,
@@ -64,6 +66,8 @@ class DecodeTests(unittest.TestCase):
         "quantization=int8",
         "quantize_kvcache=True",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
+        "profiler=''",
+        "pure_nnx=False",
     ],
     "pdb_lt_1": [  # tests decode with per_device_batch_size < 1
         None,
@@ -77,6 +81,8 @@ class DecodeTests(unittest.TestCase):
         "max_target_length=128",
         "per_device_batch_size=.25",
         rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
+        "profiler=''",
+        "pure_nnx=False",
     ],
     "decode_sampling": [
         None,

tests/integration/generate_param_only_checkpoint_test.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,8 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
       f"attention={attention_type}",
       "max_target_length=128",
       "per_device_batch_size=1",
+      "profiler=''",
+      "pure_nnx=False",
   ] + model_config
 
   pathways_command = []
@@ -72,6 +74,7 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
           dataset_type="tfds",
           dataset_path=dataset_path,
       )
+      + ["pure_nnx=False"]
   )
   state_path = f"{base_output_directory}/runner_{run_date}/checkpoints/0/items"

tests/integration/gradient_accumulation_test.py

Lines changed: 5 additions & 6 deletions
@@ -28,9 +28,9 @@
 from maxtext.common.gcloud_stub import is_decoupled
 from maxtext.trainers.pre_train.train import main as train_main
 from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT
-from maxtext.trainers.post_train.sft.train_sft_deprecated import main as sft_main
+from maxtext.trainers.post_train.sft.train_sft import main as sft_main
 
-from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory
+from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory, get_post_train_test_config_path
 
 
 def generate_random_string(length=10):
@@ -151,9 +151,8 @@ def test_sft_grad_accumulate_same_loss(self):
     sft_main(
         [
             None,
-            get_test_config_path(),
-            "base_output_directory=gs://runner-maxtext-logs",
-            "dataset_path=gs://maxtext-dataset",
+            get_post_train_test_config_path("sft"),
+            f"base_output_directory={self.base_output_directory}",
             "gradient_clipping_threshold=0",  # Ensures we are testing raw scales of gradients (clipping off).
             "enable_checkpointing=False",
             "enable_goodput_recording=False",
@@ -162,6 +161,6 @@ def test_sft_grad_accumulate_same_loss(self):
             rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
             "steps=3",
             "gradient_accumulation_steps=2",
-            "use_sft=True",
+            "dataset_type=synthetic",
         ]
     )

tests/integration/smoke/inference_microbenchmark_smoke_test.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,9 @@ def test(self):
         "weight_dtype=bfloat16",
         "attention=dot_product",
         "skip_jax_distributed_system=True",
+        "profiler=''",
+        "pure_nnx=False",
+        "enable_nnx=False",
         ]
     )
     run_benchmarks(config)
