9494 maybe_record_goodput ,
9595 record_goodput ,
9696)
97- from maxtext .common .metric_logger import MetricLogger
97+ from maxtext .common .metric_logger import MetricLogger , record_activation_metrics
9898from maxtext .configs import pyconfig
9999from maxtext .input_pipeline .input_pipeline_interface import create_data_iterator
100100from maxtext .layers .multi_token_prediction import calculate_mtp_acceptance_rate , calculate_mtp_loss
101101from maxtext .optimizers import optimizers
102102from maxtext .utils import exceptions , max_logging , max_utils , maxtext_utils , model_creation_utils , sharding
103103from maxtext .utils .globals import EPS
104+ from maxtext .utils .gradient_accumulation import nnx_gradient_accumulation_loss_and_grad
104105from maxtext .utils .rampup_batch import create_rampup_manager
105106
106107_diag_modules = _cloud_diag ()
@@ -128,7 +129,7 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
128129 Returns:
129130 (loss, aux) where loss is a scalar and aux is a dict of auxiliary metrics.
130131 """
131- rng1 , aqt_rng = jax .random .split (dropout_rng )
132+ # rng1, aqt_rng = jax.random.split(dropout_rng)
132133
133134 # Trim to micro-batch size (handles per_device_batch_size < 1 cases)
134135 # decimate proportion of data when per_device_batch_size<1
@@ -189,6 +190,24 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
189190 mtp_loss = calculate_mtp_loss (intermediate_outputs , config )
190191 loss += mtp_loss
191192
193+ # get indexer loss
194+ indexer_loss = 0.0
195+ if config .use_sparse_indexer and config .indexer_loss_scaling_factor > 0.0 :
196+ indexer_losses = []
197+ # Extract 'indexer_loss' from model intermediates.
198+ # We check for paths ending in ('self_attention', 'indexer_loss').
199+ # This handles varying paths caused by different layer names.
200+ for path , val in jax .tree_util .tree_leaves_with_path (intermediate_outputs ):
201+ path_keys = tuple (k .key for k in path if hasattr (k , "key" ))
202+ if path_keys [- 2 :] == ("self_attention" , "indexer_loss" ):
203+ indexer_losses .append (jnp .ravel (val ))
204+
205+ if indexer_losses :
206+ indexer_loss = jnp .mean (jnp .concatenate (indexer_losses ))
207+ loss += indexer_loss
208+ else :
209+ max_logging .debug ("No indexer loss found." )
210+
192211 # get MoE load balance loss
193212 moe_lb_loss = 0.0
194213 if config .num_experts > 1 :
@@ -228,29 +247,12 @@ def loss_fn(model: nnx.Module, config, data: dict[str, jax.Array], dropout_rng:
228247 "z_loss" : total_z_loss ,
229248 "total_weights" : total_weights ,
230249 "moe_lb_loss" : moe_lb_loss ,
250+ "indexer_loss" : indexer_loss ,
231251 "moe_bias_updates" : moe_bias_updates ,
232252 "mtp_loss" : mtp_loss ,
233253 }
234254 return loss , aux
235255
236- # Zero out padding positions
237- target_mask = batch ["targets_segmentation" ] != 0
238- xent = xent * target_mask
239- z_loss = z_loss * target_mask
240-
241- total_loss = jnp .sum (xent )
242- total_weights = jnp .sum (target_mask )
243- total_z_loss = jnp .sum (z_loss ) / (total_weights + EPS )
244-
245- loss = total_loss / (total_weights + EPS )
246-
247- aux = {
248- "total_loss" : total_loss ,
249- "z_loss" : total_z_loss ,
250- "total_weights" : total_weights ,
251- }
252- return loss , aux
253-
254256
255257# ---------------------------------------------------------------------------
256258# Train / eval steps (purely functional, JIT-able)
@@ -283,41 +285,139 @@ def train_step(
283285 """
284286 model : nnx .Module = nnx .merge (model_graphdef , model_state )
285287 optimizer : nnx .Optimizer = nnx .merge (opt_graphdef , opt_state )
288+ if config .use_dpo :
289+ # Need impl on NNX
290+ pass
291+ # state, reference_params = _split_dpo_state(state)
292+ # state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings)
293+ # extra_dpo_args = [reference_params]
294+ # loss_fn = dpo_loss_fn
286295
287296 # Compute loss and gradients w.r.t. model parameters.
288297 # nnx.value_and_grad differentiates only through nnx.Param variables,
289298 # keeping non-differentiable state (RNGs, cache, etc.) frozen.
290- grad_fn = nnx .value_and_grad (loss_fn , argnums = 0 , has_aux = True )
291- (loss , aux ), raw_grads = grad_fn (model , config , data , dropout_rng , is_train = True )
299+ if config .gradient_accumulation_steps > 1 :
300+ loss , aux , raw_grads = nnx_gradient_accumulation_loss_and_grad (loss_fn , model , config , data , dropout_rng )
301+ else :
302+ if config .optimizer_memory_host_offload :
303+ # Need impl on NNX
304+ pass
305+ # if config.use_dpo:
306+ # reference_params = jax.device_put(
307+ # reference_params,
308+ # max_utils.with_memory_kind(reference_params_sharding, "device"),
309+ # )
310+ # extra_dpo_args = [reference_params]
311+ if config .shard_optimizer_over_data :
312+ # Need impl on NNX
313+ pass
314+ # params = jax.tree.map(
315+ # functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode),
316+ # params,
317+ # params_shardings,
318+ # )
319+ grad_fn = nnx .value_and_grad (loss_fn , argnums = 0 , has_aux = True )
320+ (loss , aux ), raw_grads = grad_fn (model , config , data , dropout_rng , is_train = True )
292321
293322 # Cast gradients to configured dtype before clipping / accumulation
294323 raw_grads = jax .tree .map (
295324 lambda x : x .astype (config .grad_dtype ) if x .dtype == jnp .float32 else x ,
296325 raw_grads ,
297326 )
327+ intermediate_outputs = aux ["intermediate_outputs" ]
328+ total_weights = aux ["total_weights" ]
329+ moe_lb_loss = aux ["moe_lb_loss" ]
330+ indexer_loss = aux ["indexer_loss" ]
331+ z_loss = aux ["z_loss" ]
332+ moe_bias_updates = aux ["moe_bias_updates" ]
333+ mtp_loss = aux ["mtp_loss" ]
298334
299335 # Gradient clipping (implemented directly to avoid Linen TrainState dependency)
300336 if config .gradient_clipping_threshold > 0 :
301337 clip_tx = optax .clip_by_global_norm (config .gradient_clipping_threshold )
302338 grads , _ = clip_tx .update (raw_grads , clip_tx .init (raw_grads ), None )
303339 else :
304340 grads = raw_grads
341+ if config .optimizer_memory_host_offload :
342+ # Need impl on NNX
343+ pass
344+ # state = state.replace(
345+ # opt_state=jax.device_put(
346+ # state.opt_state,
347+ # jax.tree_util.tree_map(
348+ # lambda x: x.with_memory_kind(kind="device"),
349+ # state_mesh_shardings.opt_state,
350+ # ),
351+ # )
352+ # )
353+ # Move all parameters to device before optimizer update
354+ if config .parameter_memory_host_offload :
355+ max_logging .log ("\n Moving all parameters to device before optimizer update" )
356+ # Need impl on NNX
357+ # def move(path, value):
358+ # max_logging.log(f"train.py: Moving {path} to device")
359+ # return value.with_memory_kind(kind="device")
360+
361+ # state = state.replace(
362+ # params=jax.device_put(
363+ # state.params,
364+ # jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params),
365+ # )
366+ # )
305367
306368 # NNX 0.11+: update takes (model, grads) explicitly.
307369 optimizer .update (model , grads )
308370
309371 new_model_state = nnx .state (model )
310372 new_opt_state = nnx .state (optimizer )
311373
374+ # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family
375+ if config .routed_bias and config .routed_bias_update_rate > 0.0 and moe_bias_updates is not None :
376+ # Need impl on NNX
377+ pass
378+ # target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias")
379+ # Flax 'sow' returns a tuple, so we take the first element [0].
380+ # Updates the shape to be aligned with state.
381+ # moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose()
382+ # new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates)
383+
312384 scalar_metrics = {
313385 "learning/loss" : loss ,
314- "learning/z_loss" : aux [ " z_loss" ] ,
315- "learning/total_weights " : aux [ "total_weights" ] ,
316- "learning/grad_norm " : max_utils . l2norm_pytree ( grads ) ,
317- "learning/raw_grad_norm " : max_utils . l2norm_pytree ( raw_grads ) ,
318- "learning/param_norm " : max_utils . l2norm_pytree ( nnx . state ( model , nnx . Param )) ,
386+ "learning/z_loss" : z_loss ,
387+ "learning/moe_lb_loss " : moe_lb_loss ,
388+ "learning/indexer_loss " : indexer_loss ,
389+ "learning/mtp_loss " : mtp_loss ,
390+ "learning/total_weights " : total_weights ,
319391 }
320- metrics = {"scalar" : scalar_metrics , "scalars" : {}}
392+ if config .use_qk_clip :
393+ # Apply QK-Clip
394+ # Need impl on NNX
395+ pass
396+ # new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config)
397+
398+ # Report max_logits metric
399+ # global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs)
400+ # if global_max_logit is not None:
401+ # scalar_metrics["learning/max_logits"] = global_max_logit
402+
403+ if not config .optimizer_memory_host_offload :
404+ scalar_metrics ["learning/grad_norm" ] = max_utils .l2norm_pytree (grads )
405+ scalar_metrics ["learning/raw_grad_norm" ] = max_utils .l2norm_pytree (raw_grads )
406+ scalar_metrics ["learning/param_norm" ] = max_utils .l2norm_pytree (nnx .state (model , nnx .Param ))
407+ if config .use_dpo :
408+ scalar_metrics ["learning/dpo_reward_accuracy" ] = aux ["reward_accuracy" ]
409+ metrics = {
410+ "scalar" : scalar_metrics ,
411+ "scalars" : {},
412+ }
413+
414+ if config .record_internal_nn_metrics :
415+ record_activation_metrics (metrics , intermediate_outputs , config )
416+
417+ if config .use_dpo :
418+ # Need impl on NNX
419+ pass
420+ # new_state = _merge_dpo_state(new_state, reference_params)
321421 return (new_model_state , new_opt_state ), metrics
322422
323423
@@ -351,6 +451,7 @@ def eval_step(
351451 z_loss = aux ["z_loss" ]
352452 total_weights = aux ["total_weights" ]
353453 moe_lb_loss = aux ["moe_lb_loss" ]
454+ indexer_loss = aux ["indexer_loss" ]
354455 mtp_loss = aux ["mtp_loss" ]
355456 metrics = {
356457 "scalar" : {
@@ -359,6 +460,7 @@ def eval_step(
359460 "evaluation/total_loss" : total_loss ,
360461 "evaluation/total_weights" : total_weights ,
361462 "evaluation/moe_lb_loss" : moe_lb_loss ,
463+ "evaluation/indexer_loss" : indexer_loss ,
362464 "evaluation/mtp_loss" : mtp_loss ,
363465 "evaluation/mtp_acceptance_rate_percent" : mtp_acceptance_rate ,
364466 },
@@ -416,8 +518,8 @@ def _create_and_shard_optimizer(model: nnx.Module, config, mesh: Mesh):
416518 _ , opt_state = nnx .split (optimizer )
417519
418520 @functools .partial (jax .jit , out_shardings = (model_shardings , opt_shardings ))
419- def shard_states (ms , os ):
420- return ms , os
521+ def shard_states (mshard , oshard ):
522+ return mshard , oshard
421523
422524 with mesh :
423525 model_state , opt_state = shard_states (model_state , opt_state )
@@ -609,7 +711,9 @@ def train_loop(config, recorder, state=None):
609711 shaped_batch = maxtext_utils .get_shaped_batch (config )
610712 init_rng = jax .random .PRNGKey (config .init_weights_seed )
611713 example_rng = jax .jit (jax .random .fold_in )(init_rng , 0 )
612- if config .compiled_trainstep_file == "" :
714+ # Need impl of the func below on NNX
715+ # maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (model_state, opt_state, shaped_batch, example_rng))
716+ if config .compiled_trainstep_file == "" : # compile only when there is no pre-compiled file loaded
613717 compiled = p_train_step .lower (model_state , opt_state , shaped_batch , example_rng ).compile ()
614718 compiled_stats = compiled .memory_analysis ()
615719 max_utils .print_compiled_memory_stats (compiled_stats )
@@ -625,14 +729,14 @@ def train_loop(config, recorder, state=None):
625729 _job_completed_gracefully = False
626730 try :
627731 last_step_completion = datetime .datetime .now ()
732+ max_logging .info (f"Entering train loop from start_step={ start_step } " )
628733
629734 for step in np .arange (start_step , config .steps ):
630735 prof .maybe_activate_profiler (step , opt_state )
631736
632737 with jax .profiler .StepTraceAnnotation ("train" , step_num = step ):
633738 example_batch = data_loader .load_next_batch (rampup_manager = rampup_manager )
634739 nextrng = jax .jit (jax .random .fold_in )(init_rng , step )
635-
636740 with maybe_record_goodput (recorder , GoodputEvent .STEP , step ):
637741 with jax .set_mesh (mesh ), nn_partitioning .axis_rules (config .logical_axis_rules ):
638742 (model_state , opt_state ), metrics = p_train_step (model_state , opt_state , example_batch , nextrng )
@@ -650,15 +754,18 @@ def train_loop(config, recorder, state=None):
650754 and (step + 1 ) % config .eval_interval == 0
651755 ):
652756 assert eval_data_iterator
757+ # Explicitly reset the eval iterator and counters before starting the eval loop
653758 eval_data_iterator .reset ()
654759 metric_logger .reset_eval_metrics ()
760+
655761 eval_step_count = 0
656762 for eval_batch in eval_data_iterator :
657763 if config .eval_steps > 0 and eval_step_count >= config .eval_steps :
658764 break
659765 with jax .set_mesh (mesh ), nn_partitioning .axis_rules (config .logical_axis_rules ):
660766 eval_metrics = p_eval_step (model_state , eval_batch , nextrng )
661767 metric_logger .record_eval_metrics (step , metrics = eval_metrics )
768+ max_logging .log (f"Completed eval step { eval_step_count } " )
662769 eval_step_count += 1
663770
664771 metric_logger .record_eval_metrics (step , eval_step_count = eval_step_count )
@@ -679,6 +786,7 @@ def train_loop(config, recorder, state=None):
679786 checkpoint_manager , model_state , opt_state , config , data_iterator , step = int (config .steps - 1 )
680787 )
681788 if checkpoint_manager is not None :
789+ # in case the last checkpoint_period checkpoint is still in progress
682790 checkpoint_manager .wait_until_finished ()
683791
684792 _job_completed_gracefully = True
@@ -728,8 +836,10 @@ def initialize(argv: Sequence[str]):
728836 if config .use_vertex_tensorboard or os .environ .get ("UPLOAD_DATA_TO_TENSORBOARD" ):
729837 vertex_tensorboard_manager .configure_vertex_tensorboard (config )
730838
839+ # Create the Goodput recorder
731840 recorder = create_goodput_recorder (config )
732841
842+ # Stack traces configurations
733843 debug_config = debug_configuration .DebugConfig (
734844 stack_trace_config = stack_trace_configuration .StackTraceConfig (
735845 collect_stack_trace = config .collect_stack_trace ,
@@ -742,13 +852,20 @@ def initialize(argv: Sequence[str]):
742852
743853
744854def run (config , recorder , diagnostic_config ):
745- """Run the NNX training job."""
855+ """Run the NNX training job.
856+
857+ In decoupled mode (DECOUPLE_GCLOUD=TRUE) cloud diagnostics may be stubbed; if so, skip wrapping.
858+ """
859+ # Use nullcontext when diagnostics are stubbed or in decoupled mode
746860 diagnostics_context = (
747861 contextlib .nullcontext ()
748862 if is_decoupled () or getattr (diagnostic , "__class__" , None ).__name__ == "_StubDiag"
749863 else diagnostic .diagnose (diagnostic_config )
750864 )
751865
866+ if is_decoupled () or getattr (diagnostic , "__class__" , None ).__name__ == "_StubDiag" :
867+ max_logging .log ("[DECOUPLED NO-OP] skipping cloud diagnostics wrapper." )
868+
752869 with (
753870 diagnostics_context ,
754871 max_utils .maybe_get_transformer_engine_context (config ),
0 commit comments