Commit bdc6ec6

Charles Li committed
Support gradient_accumulation and align to latest train.py
1 parent 0a9b2fa commit bdc6ec6

2 files changed

Lines changed: 271 additions & 33 deletions

File tree

src/maxtext/trainers/pre_train/nnx_train.py

Lines changed: 150 additions & 33 deletions
@@ -94,13 +94,14 @@
     maybe_record_goodput,
     record_goodput,
 )
-from maxtext.common.metric_logger import MetricLogger
+from maxtext.common.metric_logger import MetricLogger, record_activation_metrics
 from maxtext.configs import pyconfig
 from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator
 from maxtext.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
 from maxtext.optimizers import optimizers
 from maxtext.utils import exceptions, max_logging, max_utils, maxtext_utils, model_creation_utils, sharding
 from maxtext.utils.globals import EPS
+from maxtext.utils.gradient_accumulation import nnx_gradient_accumulation_loss_and_grad
 from maxtext.utils.rampup_batch import create_rampup_manager

 _diag_modules = _cloud_diag()
@@ -128,7 +129,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
   Returns:
     (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
   """
-  rng1, aqt_rng = jax.random.split(dropout_rng)
+  # rng1, aqt_rng = jax.random.split(dropout_rng)

   # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
   # decimate proportion of data when per_device_batch_size<1
@@ -189,6 +190,24 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
     mtp_loss = calculate_mtp_loss(intermediate_outputs, config)
     loss += mtp_loss

+  # get indexer loss
+  indexer_loss = 0.0
+  if config.use_sparse_indexer and config.indexer_loss_scaling_factor > 0.0:
+    indexer_losses = []
+    # Extract 'indexer_loss' from model intermediates.
+    # We check for paths ending in ('self_attention', 'indexer_loss').
+    # This handles varying paths caused by different layer names.
+    for path, val in jax.tree_util.tree_leaves_with_path(intermediate_outputs):
+      path_keys = tuple(k.key for k in path if hasattr(k, "key"))
+      if path_keys[-2:] == ("self_attention", "indexer_loss"):
+        indexer_losses.append(jnp.ravel(val))
+
+    if indexer_losses:
+      indexer_loss = jnp.mean(jnp.concatenate(indexer_losses))
+      loss += indexer_loss
+    else:
+      max_logging.debug("No indexer loss found.")
+
   # get MoE load balance loss
   moe_lb_loss = 0.0
   if config.num_experts > 1:
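
For context, the suffix matching added above relies on jax.tree_util.tree_leaves_with_path, whose dict path entries expose a .key attribute. A toy, self-contained illustration (not part of this commit; the nested layout below is made up and does not mirror MaxText's real intermediates structure):

import jax
import jax.numpy as jnp

intermediates = {
    "decoder": {
        "layers_0": {"self_attention": {"indexer_loss": jnp.array([0.1, 0.2])}},
        "layers_1": {"self_attention": {"indexer_loss": jnp.array([0.3])}},
        "logits": jnp.zeros((2, 4)),
    }
}

collected = []
for path, val in jax.tree_util.tree_leaves_with_path(intermediates):
  # DictKey path entries expose .key; other key types are skipped via hasattr.
  path_keys = tuple(k.key for k in path if hasattr(k, "key"))
  if path_keys[-2:] == ("self_attention", "indexer_loss"):
    collected.append(jnp.ravel(val))

print(jnp.mean(jnp.concatenate(collected)))  # ~0.2, the mean over all matched leaves
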
@@ -228,29 +247,12 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
       "z_loss": total_z_loss,
       "total_weights": total_weights,
       "moe_lb_loss": moe_lb_loss,
+      "indexer_loss": indexer_loss,
       "moe_bias_updates": moe_bias_updates,
       "mtp_loss": mtp_loss,
   }
   return loss, aux

-  # Zero out padding positions
-  target_mask = batch["targets_segmentation"] != 0
-  xent = xent * target_mask
-  z_loss = z_loss * target_mask
-
-  total_loss = jnp.sum(xent)
-  total_weights = jnp.sum(target_mask)
-  total_z_loss = jnp.sum(z_loss) / (total_weights + EPS)
-
-  loss = total_loss / (total_weights + EPS)
-
-  aux = {
-      "total_loss": total_loss,
-      "z_loss": total_z_loss,
-      "total_weights": total_weights,
-  }
-  return loss, aux
-

 # ---------------------------------------------------------------------------
 # Train / eval steps (purely functional, JIT-able)
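
The block deleted above was an unreachable duplicate sitting after `return loss, aux`; the padding-masked normalization it described, loss = sum(xent * mask) / (sum(mask) + EPS), is still what loss_fn computes earlier in the function. A toy numeric check (illustrative only; the EPS value here is an assumption, not MaxText's constant):

import jax.numpy as jnp

EPS = 1e-8
xent = jnp.array([2.0, 1.0, 4.0])           # per-token cross-entropy
mask = jnp.array([1.0, 1.0, 0.0])           # third position is padding
total_loss = jnp.sum(xent * mask)           # 3.0: the padded token contributes nothing
total_weights = jnp.sum(mask)               # 2.0
loss = total_loss / (total_weights + EPS)   # ~1.5: mean over real tokens only
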
@@ -283,41 +285,139 @@ def train_step(
   """
   model: nnx.Module = nnx.merge(model_graphdef, model_state)
   optimizer: nnx.Optimizer = nnx.merge(opt_graphdef, opt_state)
+  if config.use_dpo:
+    # Need impl on NNX
+    pass
+    # state, reference_params = _split_dpo_state(state)
+    # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
+    # extra_dpo_args = [reference_params]
+    # loss_fn = dpo_loss_fn

   # Compute loss and gradients w.r.t. model parameters.
   # nnx.value_and_grad differentiates only through nnx.Param variables,
   # keeping non-differentiable state (RNGs, cache, etc.) frozen.
-  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
-  (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)
+  if config.gradient_accumulation_steps > 1:
+    loss, aux, raw_grads = nnx_gradient_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng)
+  else:
+    if config.optimizer_memory_host_offload:
+      # Need impl on NNX
+      pass
+      # if config.use_dpo:
+      #   reference_params = jax.device_put(
+      #       reference_params,
+      #       max_utils.with_memory_kind(reference_params_sharding, "device"),
+      #   )
+      #   extra_dpo_args = [reference_params]
+    if config.shard_optimizer_over_data:
+      # Need impl on NNX
+      pass
+      # params = jax.tree.map(
+      #     functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
+      #     params,
+      #     params_shardings,
+      # )
+    grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)
+    (loss, aux), raw_grads = grad_fn(model, config, data, dropout_rng, is_train=True)

   # Cast gradients to configured dtype before clipping / accumulation
   raw_grads = jax.tree.map(
       lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x,
       raw_grads,
   )
+  intermediate_outputs = aux["intermediate_outputs"]
+  total_weights = aux["total_weights"]
+  moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
+  z_loss = aux["z_loss"]
+  moe_bias_updates = aux["moe_bias_updates"]
+  mtp_loss = aux["mtp_loss"]

   # Gradient clipping (implemented directly to avoid Linen TrainState dependency)
   if config.gradient_clipping_threshold > 0:
     clip_tx = optax.clip_by_global_norm(config.gradient_clipping_threshold)
     grads, _ = clip_tx.update(raw_grads, clip_tx.init(raw_grads), None)
   else:
     grads = raw_grads
+  if config.optimizer_memory_host_offload:
+    # Need impl on NNX
+    pass
+    # state = state.replace(
+    #     opt_state=jax.device_put(
+    #         state.opt_state,
+    #         jax.tree_util.tree_map(
+    #             lambda x: x.with_memory_kind(kind="device"),
+    #             state_mesh_shardings.opt_state,
+    #         ),
+    #     )
+    # )
+  # Move all parameters to device before optimizer update
+  if config.parameter_memory_host_offload:
+    max_logging.log("\nMoving all parameters to device before optimizer update")
+    # Need impl on NNX
+    # def move(path, value):
+    #   max_logging.log(f"train.py: Moving f{path} to device")
+    #   return value.with_memory_kind(kind="device")
+
+    # state = state.replace(
+    #     params=jax.device_put(
+    #         state.params,
+    #         jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
+    #     )
+    # )

   # NNX 0.11+: update takes (model, grads) explicitly.
   optimizer.update(model, grads)

   new_model_state = nnx.state(model)
   new_opt_state = nnx.state(optimizer)

+  # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
+  if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None:
+    # Need impl on NNX
+    pass
+    # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
+    # Flax 'sow' returns a tuple, so we take the first element [0].
+    # Updates the shape to be aligned with state.
+    # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
+    # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)
+
   scalar_metrics = {
       "learning/loss": loss,
-      "learning/z_loss": aux["z_loss"],
-      "learning/total_weights": aux["total_weights"],
-      "learning/grad_norm": max_utils.l2norm_pytree(grads),
-      "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads),
-      "learning/param_norm": max_utils.l2norm_pytree(nnx.state(model, nnx.Param)),
+      "learning/z_loss": z_loss,
+      "learning/moe_lb_loss": moe_lb_loss,
+      "learning/indexer_loss": indexer_loss,
+      "learning/mtp_loss": mtp_loss,
+      "learning/total_weights": total_weights,
   }
-  metrics = {"scalar": scalar_metrics, "scalars": {}}
+  if config.use_qk_clip:
+    # Apply QK-Clip
+    # Need impl on NNX
+    pass
+    # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
+
+    # Report max_logits metric
+    # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
+    # if global_max_logit is not None:
+    #   scalar_metrics["learning/max_logits"] = global_max_logit
+
+  if not config.optimizer_memory_host_offload:
+    scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads)
+    scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads)
+    scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(nnx.state(model, nnx.Param))
+  if config.use_dpo:
+    scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"]
+  metrics = {
+      "scalar": scalar_metrics,
+      "scalars": {},
+  }
+
+  if config.record_internal_nn_metrics:
+    record_activation_metrics(metrics, intermediate_outputs, config)
+
+  if config.use_dpo:
+    # Need impl on NNX
+    pass
+    # new_state = _merge_dpo_state(new_state, reference_params)
   return (new_model_state, new_opt_state), metrics

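The new gradient_accumulation_steps > 1 branch delegates to nnx_gradient_accumulation_loss_and_grad, whose implementation lives in the other changed file and is not shown in this diff. As a rough mental model only, a minimal helper with the same call shape might look like the sketch below; all names here and the reshape-based micro-batching are assumptions, and real implementations typically use jax.lax.scan instead of a Python loop and combine aux across micro-batches:

import jax
import jax.numpy as jnp
from flax import nnx


def sketch_accumulation_loss_and_grad(loss_fn, model, config, data, dropout_rng):
  steps = config.gradient_accumulation_steps
  grad_fn = nnx.value_and_grad(loss_fn, argnums=0, has_aux=True)

  # Split the global batch into `steps` micro-batches along the leading axis
  # (assumes the batch dimension is divisible by `steps`).
  micro_batches = jax.tree.map(lambda x: x.reshape((steps, -1) + x.shape[1:]), data)

  loss_sum, first_aux, grads_sum = None, None, None
  for i in range(steps):  # real code would typically use jax.lax.scan here
    micro = jax.tree.map(lambda x: x[i], micro_batches)
    step_rng = jax.random.fold_in(dropout_rng, i)
    (loss, aux), grads = grad_fn(model, config, micro, step_rng, is_train=True)
    if grads_sum is None:
      loss_sum, first_aux, grads_sum = loss, aux, grads
    else:
      loss_sum = loss_sum + loss
      grads_sum = jax.tree.map(jnp.add, grads_sum, grads)

  # Average so the update magnitude matches a single full-batch step; a real
  # helper would also combine aux metrics across micro-batches.
  return loss_sum / steps, first_aux, jax.tree.map(lambda g: g / steps, grads_sum)
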

@@ -351,6 +451,7 @@ def eval_step(
   z_loss = aux["z_loss"]
   total_weights = aux["total_weights"]
   moe_lb_loss = aux["moe_lb_loss"]
+  indexer_loss = aux["indexer_loss"]
   mtp_loss = aux["mtp_loss"]
   metrics = {
       "scalar": {
@@ -359,6 +460,7 @@
           "evaluation/total_loss": total_loss,
           "evaluation/total_weights": total_weights,
           "evaluation/moe_lb_loss": moe_lb_loss,
+          "evaluation/indexer_loss": indexer_loss,
           "evaluation/mtp_loss": mtp_loss,
           "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate,
       },
@@ -416,8 +518,8 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
   _, opt_state = nnx.split(optimizer)

   @functools.partial(jax.jit, out_shardings=(model_shardings, opt_shardings))
-  def shard_states(ms, os):
-    return ms, os
+  def shard_states(mshard, oshard):
+    return mshard, oshard

   with mesh:
     model_state, opt_state = shard_states(model_state, opt_state)
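
shard_states is an identity function; the resharding work is done entirely by jax.jit's out_shardings, which places the returned states onto the precomputed mesh layouts. A standalone illustration of that pattern (the mesh and sharding below are made up for the example, not taken from this file):

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))


@functools.partial(jax.jit, out_shardings=sharding)
def reshard(x):
  return x  # identity; the output placement comes from out_shardings


x = jnp.arange(jax.device_count() * 4.0)
y = reshard(x)  # y now lives sharded across the "data" mesh axis
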
@@ -609,7 +711,9 @@ def train_loop(config, recorder, state=None):
   shaped_batch = maxtext_utils.get_shaped_batch(config)
   init_rng = jax.random.PRNGKey(config.init_weights_seed)
   example_rng = jax.jit(jax.random.fold_in)(init_rng, 0)
-  if config.compiled_trainstep_file == "":
+  # Need to implement the function below on NNX
+  # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng))
+  if config.compiled_trainstep_file == "":  # compile only when there is no pre-compiled file loaded
     compiled = p_train_step.lower(model_state, opt_state, shaped_batch, example_rng).compile()
     compiled_stats = compiled.memory_analysis()
     max_utils.print_compiled_memory_stats(compiled_stats)
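
The lower(...).compile() call above ahead-of-time compiles the jitted train step against shaped abstract inputs so memory statistics can be inspected before the first real step. A minimal standalone version of the same pattern (toy function and shapes chosen only for illustration):

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(x):
  return x * 2.0


shaped_input = jax.ShapeDtypeStruct((8, 128), jnp.float32)
compiled = toy_step.lower(shaped_input).compile()
stats = compiled.memory_analysis()  # backend memory stats; may be None on some backends
print(stats)
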
@@ -625,14 +729,14 @@ def train_loop(config, recorder, state=None):
   _job_completed_gracefully = False
   try:
     last_step_completion = datetime.datetime.now()
+    max_logging.info(f"Entering train loop from start_step={start_step}")

     for step in np.arange(start_step, config.steps):
       prof.maybe_activate_profiler(step, opt_state)

       with jax.profiler.StepTraceAnnotation("train", step_num=step):
         example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
         nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
-
         with maybe_record_goodput(recorder, GoodputEvent.STEP, step):
           with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
             (model_state, opt_state), metrics = p_train_step(model_state, opt_state, example_batch, nextrng)
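
nextrng is re-derived each step with jax.random.fold_in(init_rng, step), which keeps dropout randomness deterministic per step and reproducible across restarts. A small illustration (not from the commit):

import jax

init_rng = jax.random.PRNGKey(0)
step_rngs = [jax.random.fold_in(init_rng, step) for step in range(3)]
# Re-deriving step 1's key yields the same key, so per-step noise is reproducible.
assert (step_rngs[1] == jax.random.fold_in(init_rng, 1)).all()
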
@@ -650,15 +754,18 @@ def train_loop(config, recorder, state=None):
             and (step + 1) % config.eval_interval == 0
         ):
           assert eval_data_iterator
+          # Explicitly reset the eval iterator and counters before starting the eval loop
           eval_data_iterator.reset()
           metric_logger.reset_eval_metrics()
+
           eval_step_count = 0
           for eval_batch in eval_data_iterator:
             if config.eval_steps > 0 and eval_step_count >= config.eval_steps:
               break
             with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules):
               eval_metrics = p_eval_step(model_state, eval_batch, nextrng)
             metric_logger.record_eval_metrics(step, metrics=eval_metrics)
+            max_logging.log(f"Completed eval step {eval_step_count}")
             eval_step_count += 1

           metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count)
@@ -679,6 +786,7 @@ def train_loop(config, recorder, state=None):
           checkpoint_manager, model_state, opt_state, config, data_iterator, step=int(config.steps - 1)
       )
     if checkpoint_manager is not None:
+      # in case the last checkpoint_period checkpoint is still in progress
      checkpoint_manager.wait_until_finished()

    _job_completed_gracefully = True
@@ -728,8 +836,10 @@ def initialize(argv: Sequence[str]):
   if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
     vertex_tensorboard_manager.configure_vertex_tensorboard(config)

+  # Create the Goodput recorder
   recorder = create_goodput_recorder(config)

+  # Stack traces configurations
   debug_config = debug_configuration.DebugConfig(
       stack_trace_config=stack_trace_configuration.StackTraceConfig(
           collect_stack_trace=config.collect_stack_trace,
@@ -742,13 +852,20 @@


 def run(config, recorder, diagnostic_config):
-  """Run the NNX training job."""
+  """Run the NNX training job.
+
+  In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
+  """
+  # Use nullcontext when diagnostics are stubbed or in decoupled mode
   diagnostics_context = (
       contextlib.nullcontext()
       if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag"
       else diagnostic.diagnose(diagnostic_config)
   )

+  if is_decoupled() or getattr(diagnostic, "__class__", None).__name__ == "_StubDiag":
+    max_logging.log("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper.")
+
   with (
       diagnostics_context,
       max_utils.maybe_get_transformer_engine_context(config),
