Skip to content

Commit 6875da8

Browse files
committed
fix: update calculation of num stages
1 parent f711938 commit 6875da8

2 files changed

Lines changed: 128 additions & 79 deletions

File tree

src/MaxText/layers/decoders.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,16 +762,24 @@ def get_pipeline_stage_module(self, decoder_block_classes):
762762
# Pre-fetch policy to pass to stage
763763
policy = self.get_remat_policy()
764764

765+
# Calculate layers dynamically
766+
num_stages = cfg.ici_pipeline_parallelism * cfg.dcn_pipeline_parallelism
767+
if num_stages <= 0: num_stages = 1 # Safety fallback
768+
769+
# Ensure total PP layers are evenly divisible, or handle remainder logic if necessary
770+
# (MaxText usually assumes even division or handles it via config)
771+
calculated_layers_per_stage = cfg.pipeline_parallel_layers // num_stages
772+
765773
def stage_factory(rngs_key):
766774
return PipelineStageBlock(
767775
config=cfg,
768776
mesh=self.mesh,
769777
quant=self.quant,
770778
model_mode=self.model_mode,
771-
num_layers=cfg.num_layers_per_pipeline_stage,
779+
# Use the calculated value instead of config.num_layers_per_pipeline_stage
780+
num_layers=calculated_layers_per_stage,
772781
layer_class=base_stage_cls,
773782
remat_policy=policy,
774-
scan_axis_name="layers_per_stage",
775783
rngs=rngs_key
776784
)
777785
return stage_factory

src/MaxText/layers/pipeline_test.py

Lines changed: 118 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from flax import nnx
2828
from flax import linen as nn
2929
from MaxText.layers import nnx_wrappers
30+
from MaxText import maxtext_utils
31+
3032

3133
from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode
3234
from MaxText.sharding import (
@@ -400,14 +402,35 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p
400402
return pipeline_weights
401403

402404
def get_weight_sharding(self):
403-
state = nnx.state(self.layers)
404-
def get_spec(leaf):
405-
if hasattr(leaf, "sharding") and isinstance(leaf.sharding, PartitionSpec):
406-
return leaf.sharding
407-
return PartitionSpec()
408-
409-
partition_spec_tree = jax.tree.map(get_spec, state)
410-
return {"params": partition_spec_tree}
405+
"""Returns the PartitionSpec tree for the model weights, prepending 'stage' axis."""
406+
flat_specs = {}
407+
408+
# Iterate over the graph to access the actual Variable objects (which hold metadata)
409+
# rather than just the values.
410+
for path, var in nnx.iter_graph(self):
411+
if isinstance(var, nnx.Param):
412+
# 1. Get the inner sharding spec defined by the layer (e.g. {'embed', 'vocab'})
413+
# If no sharding is defined, it defaults to None (fully replicated inner).
414+
inner_spec = getattr(var, 'sharding', None)
415+
416+
# 2. Normalize inner_spec to a tuple/PartitionSpec
417+
if inner_spec is None:
418+
inner_spec = PartitionSpec() # empty tuple
419+
420+
# 3. Prepend the "stage" axis.
421+
# We know 'self.layers' is vmapped over the 'stage' axis.
422+
# All parameters inside 'self.layers' must have this leading axis sharded.
423+
if path[0] == 'layers':
424+
new_spec = PartitionSpec("stage", *inner_spec)
425+
flat_specs[path] = new_spec
426+
else:
427+
# Handle non-layer parameters if any (unlikely in this Pipeline design)
428+
flat_specs[path] = inner_spec
429+
430+
# 4. Reconstruct the nested structure matching the parameters
431+
nested_specs = nnx.State(flat_specs).to_pure_dict()
432+
433+
return {"params": nested_specs}
411434

412435
def get_functional_stage_fn(self):
413436
"""Returns pure (weights, inputs...) -> (output, new_state)"""
@@ -456,7 +479,10 @@ def run_one_iteration(
456479
# Vmap over stages (axis 0)
457480
# output: (stages_out, updated_weights)
458481
vmapped_stage_fn = jax.vmap(
459-
stage_fn_pure, in_axes=(0, 0, 0, 0, None, None), out_axes=(0, 0)
482+
stage_fn_pure,
483+
in_axes=(0, 0, 0, 0, None, None),
484+
out_axes=(0, 0),
485+
spmd_axis_name=self.spmd_axis_name
460486
)
461487

462488
stages_output, updated_stage_weights = vmapped_stage_fn(
@@ -518,7 +544,7 @@ def all_gather_over_fsdp(self, variables, logical_partition_spec):
518544
variables,
519545
physical_partition_spec_no_fsdp,
520546
)
521-
547+
522548
def __call__(
523549
self,
524550
inputs: jnp.ndarray,
@@ -528,79 +554,94 @@ def __call__(
528554
model_mode=MODEL_MODE_TRAIN,
529555
logical_partition_spec=None,
530556
) -> jnp.ndarray:
531-
532-
inputs = inputs.reshape(
533-
(
534-
self.config.num_pipeline_microbatches,
535-
self.pipeline_microbatch_size,
536-
self.config.max_target_length,
537-
self.config.emb_dim,
538-
),
539-
out_sharding=self.input_sharding,
540-
)
541-
542-
ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))
543-
if positions is not None:
544-
positions = self._maybe_shard_with_name(positions, ag_sharding)
545-
positions = positions.reshape(
546-
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
547-
)
548-
549-
if segment_ids is not None:
550-
segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding)
551-
segment_ids = segment_ids.reshape(
552-
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
557+
"""The main method that maps the series of decoder layer inputs to final layer outputs."""
558+
with self.mesh:
559+
# 1. Reshape inputs to [microbatches, microbatch_size, seq_len, embed_dim]
560+
inputs = inputs.reshape(
561+
(
562+
self.config.num_pipeline_microbatches,
563+
self.pipeline_microbatch_size,
564+
self.config.max_target_length,
565+
self.config.emb_dim,
566+
),
567+
out_sharding=self.input_sharding,
553568
)
554569

555-
loop_state = self.init_states(inputs)
556-
557-
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
558-
real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
559-
total_iterations = real_iterations + bubble_iterations
560-
561-
variables = nnx.state(self.layers)
562-
563-
if self.config.pipeline_fsdp_ag_once:
564-
all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
565-
else:
566-
all_pipeline_weights = variables
567-
568-
logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)
570+
# 2. Handle Positions and Segment IDs (All Gather if needed)
571+
ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None))
572+
573+
if positions is not None:
574+
positions = self._maybe_shard_with_name(positions, ag_sharding)
575+
positions = positions.reshape(
576+
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
577+
)
569578

570-
def step_fn(carry, _):
571-
curr_loop_state, curr_weights = carry
572-
573-
new_loop_state, new_weights = self.run_one_iteration(
574-
curr_loop_state,
575-
curr_weights,
576-
positions,
577-
segment_ids,
578-
deterministic,
579-
model_mode,
580-
logical_partition_spec=logical_partition_spec,
579+
if segment_ids is not None:
580+
segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding)
581+
segment_ids = segment_ids.reshape(
582+
(self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
581583
)
582-
return (new_loop_state, new_weights), None
583584

584-
if self.config.set_remat_policy_on_pipeline_iterations:
585-
step_fn = jax.checkpoint(step_fn, policy=self.get_pipeline_remat_policy())
585+
# 3. Initialize Pipeline State Buffers
586+
loop_state = self.init_states(inputs)
586587

587-
if self.config.scan_pipeline_iterations:
588-
scan_xs = jnp.arange(total_iterations)
589-
(loop_state, final_weights), _ = jax.lax.scan(step_fn, (loop_state, all_pipeline_weights), scan_xs)
590-
else:
591-
curr_weights = all_pipeline_weights
592-
for _ in range(total_iterations):
593-
(loop_state, curr_weights), _ = step_fn((loop_state, curr_weights), None)
594-
final_weights = curr_weights
588+
bubble_iterations = self.forwarding_delay * (self.num_stages - 1)
589+
real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats
590+
total_iterations = real_iterations + bubble_iterations
595591

596-
nnx.update(self.layers, final_weights)
592+
# 4. Prepare Weights (Capture once)
593+
# We treat weights as constant for the duration of the pipeline loop (Forward Pass).
594+
# This matches Linen's 'variable_broadcast' semantics and prevents OOM.
595+
variables = nnx.state(self.layers)
597596

598-
final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])
599-
600-
final_output = jnp.reshape(
601-
final_output,
602-
(self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
603-
out_sharding=self.output_sharding,
604-
)
597+
if self.config.pipeline_fsdp_ag_once:
598+
all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec)
599+
else:
600+
all_pipeline_weights = variables
601+
602+
logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec)
603+
604+
# 5. Define the Step Function
605+
def step_fn(loop_state, _):
606+
# We close over 'all_pipeline_weights', treating them as constants.
607+
# This tells XLA not to allocate new buffers for weights at every step.
608+
new_loop_state, _ = self.run_one_iteration(
609+
loop_state,
610+
all_pipeline_weights,
611+
positions,
612+
segment_ids,
613+
deterministic,
614+
model_mode,
615+
logical_partition_spec=logical_partition_spec,
616+
)
617+
# We discard the second return value (updated_stage_weights/metrics) here
618+
# to ensure the scan loop stays efficient and memory-bound.
619+
return new_loop_state, None
620+
621+
# 6. Apply Rematerialization (Gradient Checkpointing)
622+
if self.config.set_remat_policy_on_pipeline_iterations:
623+
prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config)
624+
step_fn = jax.checkpoint(step_fn, policy=self.get_pipeline_remat_policy(), prevent_cse=prevent_cse)
625+
626+
# 7. Execute the Loop
627+
if self.config.scan_pipeline_iterations:
628+
# Use jax.lax.scan for compilation efficiency
629+
scan_xs = jnp.arange(total_iterations)
630+
# Pass ONLY loop_state as carry. Weights are implicitly broadcasted via closure.
631+
loop_state, _ = jax.lax.scan(step_fn, loop_state, scan_xs)
632+
else:
633+
# Standard loop (for debugging or specific configs)
634+
for _ in range(total_iterations):
635+
loop_state, _ = step_fn(loop_state, None)
636+
637+
# 8. Post-process Outputs
638+
# The final output is located in the state_io buffer, potentially permuted.
639+
final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"])
640+
641+
final_output = jnp.reshape(
642+
final_output,
643+
(self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim),
644+
out_sharding=self.output_sharding,
645+
)
605646

606-
return final_output
647+
return final_output

0 commit comments

Comments (0)