2727from flax import nnx
2828from flax import linen as nn
2929from MaxText .layers import nnx_wrappers
30+ from MaxText import maxtext_utils
31+
3032
3133from MaxText .common_types import Config , MODEL_MODE_TRAIN , EP_AS_CONTEXT , ShardMode
3234from MaxText .sharding import (
@@ -400,14 +402,35 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p
400402 return pipeline_weights
401403
def get_weight_sharding(self):
  """Build the PartitionSpec tree for the pipeline weights, prepending the 'stage' axis.

  The pipeline's layers are vmapped over a leading 'stage' dimension, so every
  parameter that lives under ``self.layers`` carries an extra axis 0 that must be
  sharded over the 'stage' mesh axis. This walks the NNX object graph (rather than
  ``nnx.state``) so the actual Variable objects — which hold the per-layer
  ``sharding`` metadata — are visible, not just their array values.

  Returns:
    A dict ``{"params": <nested PartitionSpec tree>}`` mirroring the parameter
    structure, where specs under ``layers`` have ``"stage"`` prepended.
  """
  flat_specs = {}

  # Iterate over the graph to access the actual Variable objects (which hold
  # metadata) rather than just the values.
  # NOTE(review): assumes nnx.iter_graph yields (path_tuple, node) pairs with
  # string attribute names at the top level — confirm against the flax version pinned.
  for path, var in nnx.iter_graph(self):
    if not isinstance(var, nnx.Param):
      continue

    # Inner sharding spec defined by the layer (e.g. ('embed', 'vocab')).
    # Missing/None metadata means the parameter is fully replicated.
    inner_spec = getattr(var, 'sharding', None)
    if inner_spec is None:
      inner_spec = PartitionSpec()  # empty spec == replicated

    if path[0] == 'layers':
      # Parameters inside the vmapped stack get the leading 'stage' axis sharded.
      flat_specs[path] = PartitionSpec("stage", *inner_spec)
    else:
      # Non-layer parameters (unlikely in this Pipeline design) keep their spec as-is.
      flat_specs[path] = inner_spec

  # Reconstruct the nested structure matching the parameter tree.
  nested_specs = nnx.State(flat_specs).to_pure_dict()

  return {"params": nested_specs}
411434
412435 def get_functional_stage_fn (self ):
413436 """Returns pure (weights, inputs...) -> (output, new_state)"""
@@ -456,7 +479,10 @@ def run_one_iteration(
456479 # Vmap over stages (axis 0)
457480 # output: (stages_out, updated_weights)
458481 vmapped_stage_fn = jax .vmap (
459- stage_fn_pure , in_axes = (0 , 0 , 0 , 0 , None , None ), out_axes = (0 , 0 )
482+ stage_fn_pure ,
483+ in_axes = (0 , 0 , 0 , 0 , None , None ),
484+ out_axes = (0 , 0 ),
485+ spmd_axis_name = self .spmd_axis_name
460486 )
461487
462488 stages_output , updated_stage_weights = vmapped_stage_fn (
@@ -518,7 +544,7 @@ def all_gather_over_fsdp(self, variables, logical_partition_spec):
518544 variables ,
519545 physical_partition_spec_no_fsdp ,
520546 )
521-
547+
def __call__(
    self,
    inputs: jnp.ndarray,
    segment_ids: jnp.ndarray,
    positions: jnp.ndarray,
    deterministic: bool,
    model_mode=MODEL_MODE_TRAIN,
    logical_partition_spec=None,
) -> jnp.ndarray:
  """The main method that maps the series of decoder layer inputs to final layer outputs.

  Args:
    inputs: Activations of shape [global_batch, seq_len, emb_dim]; reshaped here
      into [num_microbatches, microbatch_size, seq_len, emb_dim].
    segment_ids: Optional per-token segment ids (None to skip); reshaped per microbatch.
    positions: Optional per-token positions (None to skip); reshaped per microbatch.
    deterministic: Disables dropout when True; forwarded to each stage.
    model_mode: One of the MODEL_MODE_* constants (default MODEL_MODE_TRAIN).
    logical_partition_spec: Logical sharding spec for the weights; used for the
      optional one-time FSDP all-gather and the per-iteration constraints.

  Returns:
    Final-layer activations of shape
    [micro_batch_size_to_train_on, max_target_length, emb_dim].
  """
  with self.mesh:
    # 1. Reshape inputs to [microbatches, microbatch_size, seq_len, embed_dim].
    inputs = inputs.reshape(
        (
            self.config.num_pipeline_microbatches,
            self.pipeline_microbatch_size,
            self.config.max_target_length,
            self.config.emb_dim,
        ),
        out_sharding=self.input_sharding,
    )

    # 2. Positions and segment ids are replicated (all-gathered) before being
    # split into microbatches, since every stage needs the full views.
    ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))

    if positions is not None:
      positions = self._maybe_shard_with_name(positions, ag_sharding)
      positions = positions.reshape(
          (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
      )

    if segment_ids is not None:
      segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding)
      segment_ids = segment_ids.reshape(
          (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
      )

    # 3. Initialize the pipeline state buffers (state_io, shift, circular storage).
    loop_state = self.init_states(inputs)

    # Total iterations = real work + pipeline-fill bubble.
    bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
    real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
    total_iterations = real_iterations + bubble_iterations

    # 4. Capture the weights once. They are constant for the duration of the
    # pipeline loop (forward pass); this matches Linen's 'variable_broadcast'
    # semantics and prevents XLA from materializing per-iteration weight copies.
    variables = nnx.state(self.layers)

    if self.config.pipeline_fsdp_ag_once:
      # All-gather over FSDP a single time up front instead of per iteration.
      all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
    else:
      all_pipeline_weights = variables

    logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)

    # 5. One pipeline step. Weights are closed over (broadcast), so only the
    # loop state is carried through scan; the updated-weights output of
    # run_one_iteration is deliberately discarded to keep the loop memory-bound.
    def step_fn(loop_state, _):
      new_loop_state, _ = self.run_one_iteration(
          loop_state,
          all_pipeline_weights,
          positions,
          segment_ids,
          deterministic,
          model_mode,
          logical_partition_spec=logical_partition_spec,
      )
      return new_loop_state, None

    # 6. Optional rematerialization (gradient checkpointing) around each iteration.
    if self.config.set_remat_policy_on_pipeline_iterations:
      prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
      step_fn = jax.checkpoint(step_fn, policy=self.get_pipeline_remat_policy(), prevent_cse=prevent_cse)

    # 7. Execute the loop.
    if self.config.scan_pipeline_iterations:
      # jax.lax.scan for compile-time efficiency; weights broadcast via closure.
      scan_xs = jnp.arange(total_iterations)
      loop_state, _ = jax.lax.scan(step_fn, loop_state, scan_xs)
    else:
      # Unrolled Python loop (debugging / specific configs).
      for _ in range(total_iterations):
        loop_state, _ = step_fn(loop_state, None)

    # 8. The final output lives in the state_io buffer, potentially permuted
    # across the microbatch-per-stage dimension.
    final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])

    final_output = jnp.reshape(
        final_output,
        (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
        out_sharding=self.output_sharding,
    )

    return final_output
0 commit comments