diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index c96ec08c8d..cca1e3efc5 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -23,6 +23,7 @@ import jax import jax.numpy as jnp + from flax import linen as nn from flax import nnx from flax.nnx import wrappers as nnx_wrappers @@ -63,6 +64,8 @@ from maxtext.multimodal import utils as mm_utils from maxtext.utils import max_logging, max_utils, maxtext_utils, sharding from maxtext.utils.sharding import create_sharding +from maxtext.layers.pipeline import create_nnx_pipeline + # ------------------------------------------------------------------------------ # The network: Decoder Definitions @@ -220,7 +223,7 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): """Process deepstack visual embeddings by adding them to hidden states at visual token positions. Args: - hidden_states: [batch, seq_len, hidden_dim] decoder hidden states + hidden_states:[batch, seq_len, hidden_dim] decoder hidden states bidirectional_mask: [batch, seq_len] boolean mask marking visual token positions visual_embeds: [batch, num_visual_tokens, hidden_dim] visual features from encoder layer @@ -235,12 +238,90 @@ def deepstack_process(hidden_states, bidirectional_mask, visual_embeds): # Gather visual tokens: for each position, get the corresponding visual token batch_idx = jnp.arange(hidden_states.shape[0])[:, jnp.newaxis] # [batch, 1] visual_embeds_scattered = visual_embeds[batch_idx, visual_token_idx, :] # [batch, seq_len, hidden] - # Only add where mask is True: hidden_states += visual_embeds * mask hidden_states = hidden_states + visual_embeds_scattered * mask_expanded return hidden_states +class NNXSequentialPipelineStage(nnx.Module): + """Sequential unscanned series of decoder layers formatted for a single pipeline stage.""" + + def __init__( + self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: 
nnx.Rngs + ): + self.config = config + self.scan_layers = config.scan_layers + self.num_layers = num_layers + # Dynamically assign layers with explicit string names to ensure correct PyTree paths (layers_0) + for i in range(num_layers): + layer = layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rngs) + setattr(self, f"layers_{i}", layer) + + def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs): + for i in range(self.num_layers): + layer = getattr(self, f"layers_{i}") + out = layer(inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs) + inputs = out[0] if isinstance(out, tuple) else out + if self.scan_layers: + return inputs, None + return inputs + + +class NNXScannedPipelineStage(nnx.Module): + """Scanned block of decoder layers formatted for a single pipeline stage.""" + + def __init__( + self, layer_cls, num_layers: int, config: Config, mesh: Mesh, quant: Quant, model_mode: str, *, rngs: nnx.Rngs + ): + self.config = config + + def create_layer_fn(rng): + return layer_cls(config=config, mesh=mesh, quant=quant, model_mode=model_mode, rngs=rng) + + # Workaround for Deepseek MTP test failure. + # TODO: Handle this properly. + try: + forked_rngs = rngs.fork(split=num_layers) + except: # pylint: disable=bare-except + forked_rngs = rngs + + out_axes = nnx.StateAxes({nnx.Param: config.param_scan_axis, ...: 0}) + self.scanned_layers = nnx.vmap( + create_layer_fn, + in_axes=0, + out_axes=out_axes, + axis_name="layers_per_stage", + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + )(forked_rngs) + + def __call__(self, inputs, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs): + graphdef, params, state = nnx.split(self.scanned_layers, nnx.Param, ...) 
+ + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + def layer_fn(carry, scanned_vars): + current_params, current_state = scanned_vars + layer = nnx.merge(graphdef, current_params, current_state) + layer_out = layer(carry, decoder_segment_ids, decoder_positions, deterministic, model_mode, **kwargs) + new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out + return new_carry, nnx.state(layer) + + final_carry, scanned_state = jax.lax.scan(layer_fn, inputs, (params, state)) + + if scan_axis != 0: + scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) + scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) + scanned_state = nnx.State.merge(scanned_params, scanned_other) + + self.scanned_layers = nnx.merge(graphdef, scanned_state) + + if self.config.scan_layers: + return final_carry, None + return final_carry + + class NNXDecoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture, using NNX.""" @@ -301,78 +382,143 @@ def __init__( self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 - if self.config.scan_layers: - if self.is_deepseek: - assert len(decoder_block_classes) == 2 - dense_cls, moe_cls = decoder_block_classes - - num_dense = config.first_num_dense_layers - self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) + if config.using_pipeline_parallelism: - num_moe = config.num_decoder_layers - config.first_num_dense_layers + def stage_factory(rngs): + return self._get_pipeline_stage_module(decoder_block_classes, rngs) - self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) - elif self.is_gemma3: - attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) - scan_length = config.num_decoder_layers // attention_pattern_length - 
num_remaining_layers = config.num_decoder_layers % attention_pattern_length - layer_kwargs = {"num_of_layers": attention_pattern_length} - - rem_layer_kwargs = {"num_of_layers": num_remaining_layers} - - RemattedGemma3Block = gemma3.Gemma3ScannableBlock - - if scan_length > 0: - self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) - self.layers_remainder = RemattedGemma3Block( - config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs - ) # pytype: disable=wrong-keyword-args - else: - layer_cls = decoder_block_classes[0] - num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if config.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - - self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) - else: - self.layers = nnx.List([]) + self.pipeline_module = create_nnx_pipeline( + config=config, + stage_factory=stage_factory, + mesh=mesh, + remat_policy=self.get_remat_policy(), + rngs=rngs, + ) if self.is_deepseek: + assert len(decoder_block_classes) == 2 dense_cls, moe_cls = decoder_block_classes - for i in range(config.first_num_dense_layers): - self._create_and_register_layer(dense_cls, rngs, "dense_layer", i) - for i in range(config.num_decoder_layers - config.first_num_dense_layers): - self._create_and_register_layer(moe_cls, rngs, "moe_layer", i) + if config.scan_layers: + self.dense_layers = self._create_scanned_layers( + dense_cls, length=config.first_num_dense_layers, metadata_axis_name="dense_layers", rngs=rngs + ) + num_moe_outside = (config.num_decoder_layers - config.first_num_dense_layers) - config.pipeline_parallel_layers + if num_moe_outside > 0: + self.moe_layers_outside_pipeline = self._create_scanned_layers( 
+ moe_cls, length=num_moe_outside, metadata_axis_name="moe_layers", rngs=rngs + ) + else: + self.num_dense_layers = config.first_num_dense_layers + for i in range(self.num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layers", i) + + self.num_moe_outside_pipeline = ( + config.num_decoder_layers - config.first_num_dense_layers + ) - config.pipeline_parallel_layers + if self.num_moe_outside_pipeline > 0: + for i in range(self.num_moe_outside_pipeline): + self._create_and_register_layer(moe_cls, rngs, "moe_layers_outside_pipeline", i) + else: - layer_cls = decoder_block_classes[0] + remaining_layers = config.num_decoder_layers - config.pipeline_parallel_layers + if remaining_layers > 0: + base_cls = decoder_block_classes[0] + if config.scan_layers: + self.layers_outside_pipeline = self._create_scanned_layers( + base_cls, length=remaining_layers, metadata_axis_name="layers", rngs=rngs + ) + else: + self.num_layers_outside_pipeline = remaining_layers + for i in range(self.num_layers_outside_pipeline): + self._create_and_register_layer(base_cls, rngs, "layers_outside_pipeline", i) - for lyr in range(config.num_decoder_layers): + else: + # Setup for Standard Non-Pipeline Execution + if self.config.scan_layers: + if self.is_deepseek: + assert len(decoder_block_classes) == 2 + dense_cls, moe_cls = decoder_block_classes + self.dense_layers = self._create_scanned_layers( + dense_cls, length=config.first_num_dense_layers, metadata_axis_name="dense_layers", rngs=rngs + ) + num_moe = config.num_decoder_layers - config.first_num_dense_layers + self.moe_layers = self._create_scanned_layers( + moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs + ) + elif self.is_gemma3: + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = config.num_decoder_layers // attention_pattern_length + num_remaining_layers = config.num_decoder_layers % attention_pattern_length + layer_kwargs = {"num_of_layers": 
attention_pattern_length} + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + RemattedGemma3Block = gemma3.Gemma3ScannableBlock + if scan_length > 0: + self.layers = self._create_scanned_layers( + RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + self.layers_remainder = RemattedGemma3Block( + config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) + else: + layer_cls = decoder_block_classes[0] + num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) layer_kwargs = {} - if config.decoder_block == DecoderBlockType.GEMMA3: - layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} - elif config.decoder_block == DecoderBlockType.LLAMA4: + if config.decoder_block == DecoderBlockType.LLAMA4: layer_kwargs = { - "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), - "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: - layer_kwargs = {"layer_idx": lyr} - elif config.decoder_block == DecoderBlockType.GPT_OSS: - layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - elif config.decoder_block == DecoderBlockType.OLMO3: - layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + self.layers = self._create_scanned_layers( + layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + else: + if self.is_deepseek: + dense_cls, moe_cls = decoder_block_classes + self.num_dense_layers = config.first_num_dense_layers + for i in range(self.num_dense_layers): + self._create_and_register_layer(dense_cls, rngs, "dense_layers", i) + self.num_moe_layers = config.num_decoder_layers - 
config.first_num_dense_layers + for i in range(self.num_moe_layers): + self._create_and_register_layer(moe_cls, rngs, "moe_layers", i) + else: + layer_cls = decoder_block_classes[0] + self.num_layers = config.num_decoder_layers + for lyr in range(self.num_layers): + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), + } + elif config.decoder_block == DecoderBlockType.QWEN3_NEXT: + layer_kwargs = {"layer_idx": lyr} + elif config.decoder_block == DecoderBlockType.GPT_OSS: + layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.OLMO3: + layer_kwargs = {"attention_type": olmo3.get_attention_type(layer_id=lyr)} + self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + + def _get_pipeline_stage_module(self, decoder_blocks, rngs): + """Retrieves the wrapper module formatted for single pipeline stage execution.""" + cfg = self.config + base_stage_cls = decoder_blocks[1] if self.is_deepseek else decoder_blocks[0] - self._create_and_register_layer(layer_cls, rngs, "layers", lyr, **layer_kwargs) + if cfg.num_layers_per_pipeline_stage == 1: + return self._create_single_layer(base_stage_cls, rngs) + elif cfg.scan_layers_per_stage: + return NNXScannedPipelineStage( + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs + ) + return NNXSequentialPipelineStage( + base_stage_cls, cfg.num_layers_per_pipeline_stage, cfg, self.mesh, self.quant, self.model_mode, rngs=rngs + ) def _create_and_register_layer(self, layer_cls, rngs, base_name, i, **layer_kwargs): attr_name = f"{base_name}_{i}" 
layer = self._create_single_layer(layer_cls, rngs, **layer_kwargs) setattr(self, attr_name, layer) - self.layers.append(layer) def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): """Helper to create a single layer (Linen or NNX).""" @@ -386,38 +532,35 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): ) return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) - def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): + def _create_scanned_layers( + self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs + ): """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" def create_layer_fn(rng): - layer = decoder_layer_class( + return decoder_layer_class( config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs ) - return layer - # Workaround for Deepseek MTP test failure. # TODO: Handle this properly. try: forked_rngs = rngs.fork(split=length) - except: # pylint: disable=bare-except - pass + forked_rngs = rngs out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) layers_vmapped = nnx.vmap( create_layer_fn, in_axes=0, out_axes=out_axes, - axis_name="layers", - transform_metadata={nnx.PARTITION_NAME: "layers"}, + axis_name=metadata_axis_name, + transform_metadata={nnx.PARTITION_NAME: metadata_axis_name}, )(forked_rngs) - return layers_vmapped def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" - graphdef, state = nnx.split(layer) def pure_layer_fn(state_in, y_in): @@ -428,7 +571,6 @@ def pure_layer_fn(state_in, y_in): checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) out, new_state = checkpointed_fn(state, y) nnx.update(layer, new_state) - return out def _apply_layers_sequentially(self, 
layers, x_in, *args, length: int, **kwargs): @@ -448,41 +590,181 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs) sig = inspect.signature(layer_cls.__call__) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - layer_cls = layers.__class__ # Access the underlying class - sig = inspect.signature(layer_cls.__call__) - # Filter kwargs to only include keys that exist in the layer's signature - valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer current_params, current_state = scanned_vars - if self.config.parameter_memory_host_offload: current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params) - # Merge using the SLICED state layer = nnx.merge(graphdef, current_params, current_state) # Run the layer (Filter kwargs if using the solution from previous turn) layer_out = layer(carry, *args, **valid_kwargs) - new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out - - # Extract the updated state to return it - # _, new_current_state = nnx.split(layer, nnx.Param, ...) - new_current_state = nnx.state(layer) - return new_carry, new_current_state + return new_carry, nnx.state(layer) layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) if scan_axis != 0: - scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) - scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) - scanned_state = nnx.State.merge(scanned_params, scanned_other) + # Only move the axis back on the params, NOT the mutables! 
+ params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) + + final_state = nnx.State.merge(params, scanned_state) + nnx.update(layers, final_state) + return final_carry, layers + + def _apply_interleaved_scanned_layers( + self, layers, y, layer_args, layer_kwargs, start_idx, end_idx, engram_indices, decoder_input_tokens + ): + """Applies a mix of scanned standard layers and unscanned Engram layers efficiently using native NNX state slicing.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + graphdef, params, mutables = nnx.split(layers, nnx.Param, ...) + + scan_axis = self.config.param_scan_axis + if scan_axis != 0: + max_logging.log(f"nnx_decoders: Moving param scan_axis from {scan_axis} to 0 for interleaved scan.") + params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) + + def get_chunk(pytree, start, end): + return jax.tree.map(lambda x: x[start:end], pytree) + + updated_mutables_chunks = [] + current_idx = start_idx + + while current_idx < end_idx: + if current_idx in engram_indices: + # Single engram layer execution + eng_params = get_chunk(params, current_idx, current_idx + 1) + eng_mutables = get_chunk(mutables, current_idx, current_idx + 1) + + # Squeeze the vmapped 'layers' dimension out for isolated execution + eng_params = jax.tree.map(lambda x: jnp.squeeze(x, axis=0), eng_params) + eng_mutables = jax.tree.map(lambda x: jnp.squeeze(x, axis=0), eng_mutables) + + if self.config.parameter_memory_host_offload: + eng_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), eng_params) - return final_carry, nnx.merge(graphdef, scanned_state) + layer = nnx.merge(graphdef, eng_params, eng_mutables) + kwargs_with_tokens = {**layer_kwargs, "decoder_input_tokens": decoder_input_tokens, "layer_idx": current_idx} + + sig = inspect.signature(layer.__call__) + valid_kwargs = {k: v for k, v in kwargs_with_tokens.items() if k in sig.parameters or "kwargs" in 
sig.parameters} + + layer_out = layer(y, *layer_args, **valid_kwargs) + y = layer_out[0] if isinstance(layer_out, tuple) else layer_out + + _, new_eng_mutables = nnx.split(layer, nnx.Param, ...) + new_eng_mutables = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), new_eng_mutables) + updated_mutables_chunks.append(new_eng_mutables) + current_idx += 1 + else: + # Scan a continuous chunk of non-engram layers + next_engrams = [l for l in engram_indices if l > current_idx] + if next_engrams: + min_next_engram = min(next_engrams) + next_boundary = min(end_idx, min_next_engram) + else: + next_boundary = end_idx + + chunk_params = get_chunk(params, current_idx, next_boundary) + chunk_mutables = get_chunk(mutables, current_idx, next_boundary) + + sig = inspect.signature(layers.__call__) + valid_kwargs = {k: v for k, v in layer_kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + + def layer_fn(carry, scanned_vars): + curr_p, curr_m = scanned_vars + if self.config.parameter_memory_host_offload: + curr_p = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), curr_p) + l = nnx.merge(graphdef, curr_p, curr_m) + l_out = l(carry, *layer_args, **valid_kwargs) + n_carry = l_out[0] if isinstance(l_out, tuple) else l_out + _, n_mut = nnx.split(l, nnx.Param, ...) + return n_carry, n_mut + + layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) + y, new_chunk_mutables = jax.lax.scan(layer_fn, y, (chunk_params, chunk_mutables)) + updated_mutables_chunks.append(new_chunk_mutables) + current_idx = next_boundary + + if updated_mutables_chunks: + final_mutables = jax.tree.map(lambda *chunks: jnp.concatenate(chunks, axis=0), *updated_mutables_chunks) + else: + final_mutables = mutables + + if scan_axis != 0: + max_logging.log(f"nnx_decoders: Moving param scan_axis 0 back to {scan_axis} for interleaved scan.") + # Only move the axis back on params! 
+ params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) + + final_state = nnx.State.merge(params, final_mutables) + nnx.update(layers, final_state) + return y, layers + + def _run_unscanned_layers_loop( + self, + base_name, + num_layers, + y, + layer_args, + layer_kwargs, + kv_caches=None, + deepstack_visual_embeds=None, + bidirectional_mask=None, + layer_idx_offset=0, + decoder_input_tokens=None, + ): + """DRY Helper for looping unscanned lists of layers while correctly handling remat, state, engrams, and KV cache.""" + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) + + def pure_layer_fn(graphdef, state_in, y_in, kv_in, dynamic_kwargs): + merged_layer = nnx.merge(graphdef, state_in) + out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **dynamic_kwargs) + return out_y, out_kv, nnx.state(merged_layer) + + checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) + + for lyr in range(num_layers): + attr_name = f"{base_name}_{lyr}" + layer = getattr(self, attr_name) + graphdef, state = nnx.split(layer) + global_lyr = layer_idx_offset + lyr + + # Prepare dynamic KV Cache unwrapping + kv_cache = None + if kv_caches is not None and self.config.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_cache = kv_caches[global_lyr] + elif kv_caches is not None and self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (global_lyr + 1) % self.config.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][global_lyr], kv_caches["value_cache"][global_lyr]) + + # Prepare dynamic Kwargs (Engrams, Layer ID) + current_kwargs = dict(layer_kwargs) + if self.config.engram_layers: + current_kwargs["decoder_input_tokens"] = decoder_input_tokens + if self.config.decoder_block == DecoderBlockType.DEEPSEEK: + current_kwargs["layer_idx"] = global_lyr + + y, returned_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache, current_kwargs) + # 
Re-merge the state back to the explicit attribute to prevent cross-boundary TraceContextErrors + setattr(self, attr_name, nnx.merge(graphdef, new_state)) + + # Write updated KV Cache back properly + if kv_caches is not None and returned_cache is not None: + if self.config.decoder_block != DecoderBlockType.QWEN3_NEXT: + kv_caches[global_lyr] = returned_cache + elif (global_lyr + 1) % self.config.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][global_lyr] = returned_cache[0] + kv_caches["value_cache"][global_lyr] = returned_cache[1] + + if deepstack_visual_embeds is not None and global_lyr < len(deepstack_visual_embeds): + visual_embeds = deepstack_visual_embeds[global_lyr] + if bidirectional_mask is not None and visual_embeds is not None: + y = deepstack_process(y, bidirectional_mask, visual_embeds) + + return y def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" @@ -518,7 +800,6 @@ def get_deepseek(): if cfg.decoder_block not in layer_map: raise ValueError(f"Incorrect decoder_block name {cfg.decoder_block.value=}") - return layer_map[cfg.decoder_block] def minimal_policy(self, with_context=False, with_quantization=False): @@ -573,37 +854,18 @@ def get_remat_policy(self): policy = self.minimal_policy(with_context=True, with_quantization=True) elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "context", - "out_proj", + "query_proj", "value_proj", "key_proj", "qkv_proj", "context", "out_proj" ) elif cfg.remat_policy == "save_dot_except_mlpwi": policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwo", + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj", "mlpwo" ) elif cfg.remat_policy == "save_dot_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( - 
"query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj" ) elif cfg.remat_policy == "save_qkv_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - ) + policy = jax.checkpoint_policies.save_only_these_names("query_proj", "value_proj", "key_proj", "qkv_proj") elif cfg.remat_policy == "qkv_proj_offloaded": policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], @@ -612,7 +874,6 @@ def get_remat_policy(self): offload_dst="pinned_host", ) elif cfg.remat_policy == "minimal_offloaded": - # offload all except context policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], names_which_can_be_offloaded=[ @@ -640,11 +901,10 @@ def get_remat_policy(self): policy = jax.checkpoint_policies.save_only_these_names("out_proj") else: assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None return policy def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): - """get normalization layer (return type inherits from nn.Module)""" + """Helper to retrieve the correct normalization layer class based on config, partially applied with common arguments.""" if self.config.decoder_block in ( DecoderBlockType.DEFAULT, DecoderBlockType.LLAMA2, @@ -687,28 +947,25 @@ def _apply_embedding( audio_embeddings=None, audio_masks=None, ): - """Applies token and positional embeddings to the input tokens.""" + """Applies token embedding, adds positional embedding, and merges multimodal embeddings if provided.""" cfg = self.config - y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - # Merge the image embeddings with the text embeddings for multimodal models if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in [ + if cfg.model_name in { "gemma3-4b", "gemma3-12b", 
"gemma3-27b", "llama4-17b-16e", "llama4-17b-128e", "qwen3-omni-30b-a3b", - ]: + }: y = mm_utils.merge_mm_embeddings( text_embeddings=y, multimodal_embeddings=image_embeddings, mask=bidirectional_mask, token_masks=image_masks, ) - # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed else: raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") @@ -736,7 +993,6 @@ def _apply_embedding( def apply_output_head(self, shared_embedding, y, deterministic, model_mode): """Applies final normalization and projects hidden states to logits.""" - cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) @@ -779,115 +1035,6 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): return logits - def _build_linen_params(self, moe_stack: nnx.Module) -> dict: - """ - Bridges NNX to Linen by creating a dictionary that mimics the exact variable - structure expected by `deepseek_batchsplit.fetch_weights`. 
- """ - - return { - "pre_self_attention_layer_norm": { - "scale": moe_stack.pre_self_attention_layer_norm.scale, - }, - "post_self_attention_layer_norm": { - "scale": moe_stack.post_self_attention_layer_norm.scale, - }, - "self_attention": { - "wq_a": {"kernel": moe_stack.self_attention.wq_a.kernel}, - "wq_b": {"kernel": moe_stack.self_attention.wq_b.kernel}, - "q_norm": {"scale": moe_stack.self_attention.q_norm.scale}, - "wkv_a": {"kernel": moe_stack.self_attention.wkv_a.kernel}, - "wkv_b": {"kernel": moe_stack.self_attention.wkv_b.kernel}, - "kv_norm": {"scale": moe_stack.self_attention.kv_norm.scale}, - "out": {"kernel": moe_stack.self_attention.out.kernel}, - }, - "DeepSeekMoeBlock_0": { - "MoeBlock_0": { - "gate": { - "kernel": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel, - "bias": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias, - }, - "wi_0": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_0, - "wi_1": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_1, - "wo": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wo, - }, - "shared_experts": { - "wi_0": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel}, - "wi_1": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel}, - "wo": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wo.kernel}, - }, - }, - } - - def _find_next_boundary(self, current_idx, end_idx, engram_indices): - """Finds the next index boundary, either the next Engram layer index or the overall end index.""" - next_engrams = [l for l in engram_indices if l > current_idx] - if next_engrams: - return min(end_idx, *next_engrams) - return end_idx - - def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs): - """Applies a single, unscanned Engram layer by dynamically slicing the NNX state.""" - graphdef, state = nnx.split(layer_stack) - - # Slice the parameters for the current index (assuming scan axis is 0) - sliced_state = jax.tree.map(lambda x: x[current_idx], state) - single_layer = 
nnx.merge(graphdef, sliced_state) - - # Run the single layer - out = single_layer( - y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {}) - ) - y = out[0] if isinstance(out, tuple) else out - - # Re-merge the updated state back into the specific slice of the stack - new_single_state = nnx.state(single_layer) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0), - state, - new_single_state, - ) - nnx.update(layer_stack, updated_state) - - return y - - def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args, **kwargs): - """Applies a contiguous chunk of layers using scan over a state slice.""" - scan_length = next_boundary - current_idx - if scan_length > 0: - graphdef, state = nnx.split(layer_stack) - - # Slice the chunk state - chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state) - chunk_stack = nnx.merge(graphdef, chunk_state) - - # Apply sequentially - y, chunk_stack = self._apply_layers_sequentially( - chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {}) - ) - - # Update the original stack state - new_chunk_state = nnx.state(chunk_stack) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state - ) - nnx.update(layer_stack, updated_state) - - return y - - def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs): - """Applies a mix of scanned standard layers and unscanned Engram layers.""" - current_idx = start_idx - while current_idx < end_idx: - if current_idx in engram_indices: - y = self._apply_single_engram_layer(y, current_idx, layer_stack, *args, **kwargs) - current_idx += 1 - else: - next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices) - y = 
self._apply_scanned_chunk(y, current_idx, next_boundary, layer_stack, *args, **kwargs) - current_idx = next_boundary - return y - def __call__( self, shared_embedding: Any, @@ -911,8 +1058,6 @@ def __call__( cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] - policy = self.get_remat_policy() - # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( shared_embedding, @@ -940,129 +1085,252 @@ def __call__( if attention_metadata is not None: layer_kwargs["attention_metadata"] = attention_metadata + elif cfg.decoder_block == DecoderBlockType.DEEPSEEK and cfg.scan_layers: + layer_kwargs = {"previous_chunk": previous_chunk, "page_state": page_state, "slot": slot} - if cfg.scan_layers: - if self.is_deepseek: - layer_kwargs = { - "previous_chunk": previous_chunk, - "page_state": page_state, - "slot": slot, - } - - if cfg.engram_layers: - common_kwargs = { - "layer_kwargs": layer_kwargs, - "decoder_input_tokens": decoder_input_tokens, - } - - y = self._apply_interleaved_scanned_layers( - y, self.dense_layers, 0, cfg.first_num_dense_layers, cfg.engram_layers, *layer_args, **common_kwargs - ) - - y = self._apply_interleaved_scanned_layers( - y, - self.moe_layer, - 0, - (cfg.num_decoder_layers - cfg.first_num_dense_layers), - [e - cfg.first_num_dense_layers for e in cfg.engram_layers], - *layer_args, - **common_kwargs, - ) - else: - y, self.dense_layers = self._apply_layers_sequentially( - self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs - ) - - num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers - - if cfg.use_batch_split_schedule: - policy = self.get_remat_policy() + # ------------------------------------------------------------------------- + # Execution Routing (Pipeline vs Direct) + # ------------------------------------------------------------------------- + if cfg.using_pipeline_parallelism: + logical_partition_spec = self.pipeline_module.get_weight_sharding() if 
cfg.pipeline_fsdp_ag_once else None - mock_params = self._build_linen_params(self.moe_layer) - - y = deepseek_batchsplit.scan_batch_split_layers( - y, - mock_params, - decoder_positions, - decoder_segment_ids, - model_mode=model_mode, - mesh=self.mesh, - quant=self.quant, - cfg=cfg, - policy=policy, - ) + if self.is_deepseek: + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.scan_layers: + if cfg.engram_layers: + y, self.dense_layers = self._apply_interleaved_scanned_layers( + self.dense_layers, + y, + layer_args, + layer_kwargs, + start_idx=0, + end_idx=cfg.first_num_dense_layers, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + if hasattr(self, "moe_layers_outside_pipeline"): + num_moe_outside = (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + y, self.moe_layers_outside_pipeline = self._apply_interleaved_scanned_layers( + self.moe_layers_outside_pipeline, + y, + layer_args, + layer_kwargs, + start_idx=cfg.first_num_dense_layers, + end_idx=cfg.first_num_dense_layers + num_moe_outside, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y, self.dense_layers = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + if hasattr(self, "moe_layers_outside_pipeline"): + num_moe_outside = (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + y, self.moe_layers_outside_pipeline = self._apply_layers_sequentially( + self.moe_layers_outside_pipeline, y, *layer_args, length=num_moe_outside, **layer_kwargs + ) else: - y, self.moe_layer = self._apply_layers_sequentially( - self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs + y = self._run_unscanned_layers_loop( + base_name="dense_layers", + 
num_layers=self.num_dense_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, ) - elif self.is_gemma3: - y = self._apply_gemma3_scanned_blocks( + if hasattr(self, "num_moe_outside_pipeline") and self.num_moe_outside_pipeline > 0: + y = self._run_unscanned_layers_loop( + base_name="moe_layers_outside_pipeline", + num_layers=self.num_moe_outside_pipeline, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.first_num_dense_layers, + decoder_input_tokens=decoder_input_tokens, + ) + + y = self.pipeline_module( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, + logical_partition_spec=logical_partition_spec, ) - else: - scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) - else: - prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) - - # Hoisted function to preserve XLA cache ID - def pure_layer_fn(graphdef, state_in, y_in, kv_in): - - if cfg.parameter_memory_host_offload: - state_in = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), state_in) - - merged_layer = nnx.merge(graphdef, state_in) - out_y, out_kv = merged_layer(y_in, *layer_args, kv_cache=kv_in, **layer_kwargs) - return out_y, out_kv, nnx.state(merged_layer) - - checkpointed_fn = jax.checkpoint(pure_layer_fn, policy=policy, prevent_cse=prevent_cse) - - for lyr, layer in enumerate(self.layers): - graphdef, state = nnx.split(layer) - kv_cache = kv_caches[lyr] if kv_caches is not None else None - - input_tokens = 
decoder_input_tokens if cfg.engram_layers else None - if input_tokens is not None: - layer_kwargs["decoder_input_tokens"] = input_tokens - y, kv_cache, new_state = checkpointed_fn(graphdef, state, y, kv_cache) - nnx.update(layer, new_state) + else: + # Standard Pipeline Run + y = self.pipeline_module( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + logical_partition_spec=logical_partition_spec, + ) - if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + # Remaining standard layers + if hasattr(self, "num_layers_outside_pipeline") or hasattr(self, "layers_outside_pipeline"): + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.scan_layers: + y, self.layers_outside_pipeline = self._apply_layers_sequentially( + self.layers_outside_pipeline, + y, + *layer_args, + length=len(self.layers_outside_pipeline.scanned_layers), + **layer_kwargs, + ) + else: + y = self._run_unscanned_layers_loop( + base_name="layers_outside_pipeline", + num_layers=self.num_layers_outside_pipeline, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.pipeline_parallel_layers, + decoder_input_tokens=decoder_input_tokens, + ) - if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): - visual_embeds = deepstack_visual_embeds[lyr] - if bidirectional_mask is not None and visual_embeds is not None: - y = deepstack_process(y, bidirectional_mask, visual_embeds) + else: + # Non-Pipeline Run + if cfg.scan_layers: + if self.is_deepseek: + if cfg.engram_layers: + y, self.dense_layers = self._apply_interleaved_scanned_layers( + self.dense_layers, + y, + layer_args, + layer_kwargs, + start_idx=0, + end_idx=cfg.first_num_dense_layers, + 
engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + y, self.moe_layers = self._apply_interleaved_scanned_layers( + self.moe_layers, + y, + layer_args, + layer_kwargs, + start_idx=cfg.first_num_dense_layers, + end_idx=cfg.num_decoder_layers, + engram_indices=cfg.engram_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y, self.dense_layers = self._apply_layers_sequentially( + self.dense_layers, y, *layer_args, length=cfg.first_num_dense_layers, **layer_kwargs + ) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + + # Use raw deepseek_batchsplit logic for MoE scanned layers to minimize VRAM overhead + layer_is_initializing = self.quant is not None and len(nnx.state(self.moe_layers, "aqt")) == 0 + if cfg.use_batch_split_schedule and not layer_is_initializing: + raw_weights = nnx.to_pure_dict(nnx.state(self.moe_layers, nnx.Param)) + y = deepseek_batchsplit.scan_batch_split_layers( + y, + raw_weights, + decoder_positions, + decoder_segment_ids, + model_mode=model_mode, + mesh=self.mesh, + quant=self.quant, + cfg=cfg, + policy=self.get_remat_policy(), + ) + else: + y, self.moe_layers = self._apply_layers_sequentially( + self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs + ) + + elif self.is_gemma3: + y = self._apply_gemma3_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) + else: + y, self.layers = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=cfg.num_decoder_layers, **layer_kwargs + ) + else: + if self.is_deepseek: + y = self._run_unscanned_layers_loop( + base_name="dense_layers", + num_layers=self.num_dense_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + 
layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, + ) + y = self._run_unscanned_layers_loop( + base_name="moe_layers", + num_layers=self.num_moe_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=cfg.first_num_dense_layers, + decoder_input_tokens=decoder_input_tokens, + ) + else: + y = self._run_unscanned_layers_loop( + base_name="layers", + num_layers=self.num_layers, + y=y, + layer_args=layer_args, + layer_kwargs=layer_kwargs, + kv_caches=kv_caches, + deepstack_visual_embeds=deepstack_visual_embeds, + bidirectional_mask=bidirectional_mask, + layer_idx_offset=0, + decoder_input_tokens=decoder_input_tokens, + ) assert isinstance(y, jax.Array) - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. if cfg.mhc_expansion_rate > 1: # (batch, length, mhc_expansion_rate, emb_dim) --> (batch, length, emb_dim) + hidden_state = mhc_reduce(y) else: hidden_state = y # When invoking from vLLM with RPA attention, logit computation is deferred to a later stage. 
if cfg.attention == "vllm_rpa": + if not cfg.logits_via_embedding and hasattr(self, "logits_dense"): + if self.quant is not None and len(nnx.state(self.logits_dense, "aqt")) == 0: + _ = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) logits = None - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) - else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) @@ -1109,10 +1377,9 @@ def pure_gemma_fn(graphdef, state_in, y_in): return out_y, nnx.state(merged_layer) checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) - graphdef, state = nnx.split(self.layers_remainder) y, new_state = checkpointed_gemma_fn(graphdef, state, y) - nnx.update(self.layers_remainder, new_state) + self.layers_remainder = nnx.merge(graphdef, new_state) return y diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 60cb7d2ac2..321c655eb3 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -26,7 +26,7 @@ from flax.core import meta from flax import linen as nn -from flax.linen.spmd import LogicallyPartitioned +from flax import nnx from maxtext.common.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode from maxtext.utils.sharding import ( @@ -37,25 +37,22 @@ logical_to_mesh, ) from maxtext.utils import pipeline_utils +from maxtext.utils import max_logging -class PipelineBase(nn.Module): - """Base module that implements shared pipelining logic across stages.""" +class PipelineSharedMixin: + """Contains pure JAX and mathematical utilities shared identically by 
Linen and NNX.""" - config: Config - layers: nn.Module - mesh: Mesh - remat_policy: Any = None - - def setup(self): + def _setup_pipeline_attributes(self): """Initializes the configuration, calculating num_stages, delay, axes, and partition specs.""" self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches - microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages - self.microbatches_per_stage = microbatches_per_stage + self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages self.use_circ_storage = self.need_circ_storage() + self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None + if self.config.expert_shard_attention_option == EP_AS_CONTEXT: self.batch_axis_name = "activation_batch_no_exp" self.seq_len_axis_name = "activation_length" @@ -63,8 +60,6 @@ def setup(self): self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" - self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None - self.stages_in_logical = ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed") self.stages_in_spec = logical_to_mesh_axes(self.stages_in_logical, self.mesh, rules=self.config.logical_axis_rules) self.stages_in_sharding = ( @@ -176,8 +171,7 @@ def select_state_or_input(first_stage_in, shift): # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output) stages_in = select_state_or_input(first_stage_in, shift) - stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical) - return stages_in + return self._maybe_shard_with_logical(stages_in, self.stages_in_logical) def get_microbatch_and_repeat_ids(self, 
loop_iteration): """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and @@ -192,140 +186,10 @@ def get_pipeline_remat_policy(self): """Returns the pipeline remat policy for this pipeline.""" if self.config.remat_policy == "custom": return self.remat_policy - save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") if self.remat_policy is not None: - remat_policy = jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) - else: - remat_policy = save_input_policy - return remat_policy - - def get_weight_sharding(self, *init_args): - """get weight sharding function for this pipeline.""" - key = jax.random.PRNGKey(0) - keys = {"params": key, "dropout": key, "aqt": key} - weights = self.init(keys, *init_args) - - def get_partition_spec(pytree): - def _is_leaf(x): - return isinstance(x, nn.spmd.LogicallyPartitioned) - - def get_partition_spec_leaf(leaf): - return leaf.get_partition_spec() - - return jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) - - partition_spec_with_extra_layer = get_partition_spec(weights) - logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} - return logical_partition_spec - - def get_vmap_func_for_init(self): - """This vmap func is used to initialize the weights only on init.""" - - def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): - return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, None, None), - spmd_axis_name=self.spmd_axis_name, - variable_axes={"params": 0, "_overwrite_with_gradient": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - 
"is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return vmap_func - - def get_main_vmap_func_for_iterations(self): - """ - Returns main stage function vmapped by number of stages. - This becomes a vmap over a single layer instance if body_instance is a single layer, - else a set of layers if body_instance is a set of layers. - """ - - def func_to_vmap( - body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode - ): - weights = meta.remove_axis( - weights, - 0, - { - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) - - vmap_func = nn.vmap( - func_to_vmap, - in_axes=(0, 0, 0, 0, None, None), - spmd_axis_name=self.spmd_axis_name, - variable_axes={"params": 0}, - split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "layers", - "sub_weight_split_dims_mapping": (None), - "is_initializing": self.is_initializing(), - "x_times": self.num_stages, - }, - ) - return vmap_func - - def _run_weight_initialization( - self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode - ): - """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" - vmap_func = self.get_vmap_func_for_init() - - if self.config.num_pipeline_repeats > 1: - vmap_func = nn.vmap( - vmap_func, - in_axes=(0, segment_idx, position_idx, None, None), - variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0}, - split_rngs={"params": True, "dropout": self.config.enable_dropout}, - metadata_params={ - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": True, - 
"x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - }, - ) - example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) - example_segmentation = ( - jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) - if example_segmentation is not None - else None - ) - example_position = ( - jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) - if example_position is not None - else None - ) - - example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) - stage_outputs = vmap_func( - self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode - ) - if self.config.scan_layers: - stage_outputs = stage_outputs[0] - if self.config.num_pipeline_repeats > 1: - stage_outputs = stage_outputs[0] - broadcasted_stage_outpus = jax.lax.broadcast( - stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] - ) - - return jnp.reshape( - broadcasted_stage_outpus, - [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], - out_sharding=self.output_sharding, - ) + return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) + return save_input_policy @staticmethod def _remove_fsdp_from_physical_partition_spec(pps): @@ -352,10 +216,6 @@ def _remove_fsdp_from_physical_partition_spec(pps): return P(*new_spec) return pps - -class Pipeline(PipelineBase): - """Original Pipeline implementation.""" - def init_states(self, inputs): """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] @@ -388,6 +248,7 @@ def init_states(self, inputs): state_io = jnp.reshape( inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding ) + # We shard 
the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical) @@ -409,7 +270,7 @@ def init_states(self, inputs): circ_storage = None circ_storage_mover = None - init_loop_state = { + return { "state_io": state_io, "shift": shift, "circ_storage": circ_storage, @@ -417,7 +278,6 @@ def init_states(self, inputs): "loop_iteration": 0, "prev_outputs": prev_outputs, } - return init_loop_state def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False): """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at @@ -471,10 +331,9 @@ def _gather_one(x, repeat_id): stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) - stage_weights = self.shard_dim_by_stages( + return self.shard_dim_by_stages( stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True ) - return stage_weights def vmap_gather(self, xs, ids, ids_dim): """Use vmap to implement a stage-wise sharded gather. @@ -491,9 +350,11 @@ def vmap_gather(self, xs, ids, ids_dim): The per-stage gathered values. The shape is xs.shape but with ids_dim size replaced with [num_stages]. 
""" + xs = jnp.asarray(xs) + ndim = xs.ndim def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) + idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) replicated_sharding = NamedSharding(self.mesh, P()) return x.at[idx].get(out_sharding=replicated_sharding) @@ -501,17 +362,11 @@ def _gather_one(x, i): outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) - def get_new_loop_state(self, output, loop_state): - """ - Update the various buffers given the output of the most recent iteration - * state_io: rotates left/up by 1 (the whole created in the last slot is filled with the most recent pipeline output) - * Pushing inputs up from top of state_io into first stage of shift - * Pulling outputs up from last stage of shift into bottom of state_io - * shift: rotate output (or prev_outputs if using delay) right/down by 1 - we imagine the pipeline moves to - right/down - * circ_storage: pushes circ_storage_mover (the output of the previous iteration) into rotating index of circ_storage - * circ_storage_mover: assigned to rotated output and pushed into circ_storage on the next iteration - * prev_outputs: is set to the current output + def advance_circular_buffers(self, output, loop_state): + """Rotates pipeline activations to the next physical device stage. + + Uses `jax.lax.ppermute` to perform cross-device ring communication, shifting + the forward activations (`state_io` and `shift`) from stage $i$ to stage $i+1$. 
""" old_state_io = loop_state["state_io"] old_circ_storage = loop_state["circ_storage"] @@ -521,11 +376,9 @@ def get_new_loop_state(self, output, loop_state): @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _rotate_right(arr): - # we use +1 for right shifting stage_size = jax.lax.axis_size("stage") perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return arr + return jax.lax.ppermute(arr, axis_name="stage", perm=perm) @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) def _shift_right(arr): @@ -557,8 +410,7 @@ def _update_shift(output_in): # circ_storage_mover still points to the output of PREVIOUS iteration, which should aid in allowing overlapped # compute/async transfers def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): - rotated = _rotate_right(circ_storage_mover_in) - rotated = jnp.expand_dims(rotated, 1) + rotated = jnp.expand_dims(_rotate_right(circ_storage_mover_in), 1) # We rotate the pushing index into circ storage, and ensure that microbatch 0 lands in index 0 offset = ( loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 @@ -601,7 +453,7 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) - new_loop_state = { + return { "state_io": new_state, "shift": new_shift, "circ_storage": new_circ_storage, @@ -609,7 +461,6 @@ def _update_state_io(state_in, stream_slice, output, stream_buf_idx): "loop_iteration": loop_iteration + 1, "prev_outputs": new_prev_outputs, } - return new_loop_state def permute_output_micro_per_stage_dim(self, output): """ @@ -625,8 +476,18 @@ def permute_output_micro_per_stage_dim(self, output): # state_io - it will land on a different index of state_io depending on the number of iterations. 
microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage permutation = (np.arange(self.microbatches_per_stage) + microbatch_0_idx) % self.microbatches_per_stage - output = output[:, permutation] - return output + return output[:, permutation] + + def realign_output_microbatches(self, output): + """Reorders the output tensor to reverse the circular shifts applied during execution. + + Because the pipeline operates circularly, the output microbatches are shifted + out of order by the time the final stage is completed. This rolls them back + into their original sequential layout. + """ + microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage + output = jnp.roll(output, shift=-microbatch_0_idx, axis=1) + return self._maybe_shard_with_logical(output, self.state_io_logical) def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): """ @@ -639,105 +500,277 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_p return self.get_current_repeat_from_stages( pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec ) - else: - return pipeline_weights + return pipeline_weights - def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): - """Fetches the weights for the current repeat from the stages.""" - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, - # only one circular entry per stage. 
- weights = meta.remove_axis(weights, 0, circular_metadata_params) - weights = self._remove_logically_partition(weights) + def all_gather_over_fsdp(self, variables, logical_partition_spec): + if logical_partition_spec is None: + return variables + + def _gather_leaf(var, spec): + if spec is None: + return var + physical = logical_to_mesh_axes(spec, self.mesh, rules=self.config.logical_axis_rules) + no_fsdp = self._remove_fsdp_from_physical_partition_spec(physical) + sharding = NamedSharding(self.mesh, no_fsdp) + if isinstance(var, nnx.Variable): + var.value = self._maybe_shard_with_name(var.value, sharding) + return var + return self._maybe_shard_with_name(var, sharding) + + # nnx.Variable and PartitionSpec are JAX pytree nodes — treat them as leaves + # so the two trees align at the dict level. None must also be a leaf to avoid + # being treated as an empty container (0 children) vs the Variable's 1 child. + is_leaf = lambda x: isinstance(x, (nnx.Variable, P)) or x is None + return jax.tree.map(_gather_leaf, variables, logical_partition_spec, is_leaf=is_leaf) + + def get_logical_spec_repeats_removed(self, full_logical): + """Returns a new logical spec with 'circular_repeats' removed.""" + if full_logical is None or self.config.num_pipeline_repeats == 1: + return full_logical - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) + def _remove_from_spec(spec): + if not isinstance(spec, P): + return spec + if spec and (spec[0] == "circular_repeats" or spec[0] is None): + return jax.sharding.PartitionSpec(*spec[1:]) + return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights + return 
jax.tree.map(_remove_from_spec, full_logical, is_leaf=lambda x: isinstance(x, P)) - def run_one_iteration( - self, - loop_state, - pipeline_weights, - positions, - segment_ids, - deterministic, - model_mode, - decoder_layer_instance, - logical_partition_spec=None, - ): - """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, - and update the loop state. - Args: - loop_state: Dictionary containing the current state of the pipeline (state_io, shift, etc.) - positions: Positional encodings. - segment_ids: Segment IDs for packed sequences. - deterministic: Boolean indicating if execution should be deterministic (e.g. for dropout). - model_mode: Current model mode (train/predict). - logical_partition_spec: Logical partition specification for weights. - """ - state_io = loop_state["state_io"] - shift = loop_state["shift"] - circ_storage = loop_state["circ_storage"] - loop_iteration = loop_state["loop_iteration"] +class PipelineBaseLinen(nn.Module, PipelineSharedMixin): + """Base module that implements shared pipelining logic across stages for Linen.""" - microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) - physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + config: Config + layers: nn.Module + mesh: Mesh + remat_policy: Any = None - stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) - stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") - stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None - stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + def setup(self): + self._setup_pipeline_attributes() - vmap_func = self.get_main_vmap_func_for_iterations() + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + key = 
jax.random.PRNGKey(0) + keys = {"params": key, "dropout": key, "aqt": key} + weights = self.init(keys, *init_args) - if self.config.num_pipeline_repeats > 1: - _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + def get_partition_spec(pytree): + def _is_leaf(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) - def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): - circular_metadata_params = { - nn.PARTITION_NAME: "circular_repeats", - "sub_weight_split_dims_mapping": (None,), - "is_initializing": self.is_initializing(), - "x_times": self.config.num_pipeline_repeats, - "optimizer_dims_mapping": None, - } - weights = meta.remove_axis(weights, 0, circular_metadata_params) - weights = self._remove_logically_partition(weights) + def get_partition_spec_leaf(leaf): + return leaf.get_partition_spec() - def gather_weights_for_stages_in(w, spec=None): - return self.vmap_parallel_gather( - w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec - ) + return jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf) - if physical_partition_spec is None: - weights = jax.tree.map(gather_weights_for_stages_in, weights) - else: - weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) - return weights + partition_spec_with_extra_layer = get_partition_spec(weights) + return {"params": partition_spec_with_extra_layer["params"]["layers"]} - prepare_vars_for_main_vmap_partial = functools.partial( - prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec - ) - vmap_func = nn.map_variables( - vmap_func, - mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], - mutable=True, - trans_in_fn=prepare_vars_for_main_vmap_partial, + def get_vmap_func_for_init(self): + """This vmap func is used to initialize the weights only on init.""" + + def func_to_vmap(body_instance, stages_inputs, 
stages_segment_ids, stages_positions, deterministic, model_mode): + return body_instance(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + return nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0, "_overwrite_with_gradient": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + + def get_main_vmap_func_for_iterations(self): + """ + Returns main stage function vmapped by number of stages. + This becomes a vmap over a single layer instance if body_instance is a single layer, + else a set of layers if body_instance is a set of layers. + """ + + def func_to_vmap( + body_instance, weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode + ): + weights = meta.remove_axis( + weights, + 0, + { + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + return body_instance.apply(weights, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + + return nn.vmap( + func_to_vmap, + in_axes=(0, 0, 0, 0, None, None), + spmd_axis_name=self.spmd_axis_name, + variable_axes={"params": 0}, + split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "layers", + "sub_weight_split_dims_mapping": (None), + "is_initializing": self.is_initializing(), + "x_times": self.num_stages, + }, + ) + + def _run_weight_initialization( + self, example_inputs, example_segmentation, example_position, segment_idx, position_idx, deterministic, model_mode + ): + """Runs the initialization sequence mapping layers appropriately based on pipeline settings.""" + 
vmap_func = self.get_vmap_func_for_init() + + if self.config.num_pipeline_repeats > 1: + vmap_func = nn.vmap( + vmap_func, + in_axes=(0, segment_idx, position_idx, None, None), + variable_axes={"params": 0, "_overwrite_with_gradient": 0, "non_trainable": 0, "hyper_params": 0}, + split_rngs={"params": True, "dropout": self.config.enable_dropout}, + metadata_params={ + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": True, + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + }, + ) + example_inputs = jax.lax.broadcast(example_inputs, [self.config.num_pipeline_repeats]) + example_segmentation = ( + jax.lax.broadcast(example_segmentation, [self.config.num_pipeline_repeats]) + if example_segmentation is not None + else None + ) + example_position = ( + jax.lax.broadcast(example_position, [self.config.num_pipeline_repeats]) + if example_position is not None + else None + ) + + example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) + stage_outputs = vmap_func( + self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode + ) + if self.config.scan_layers: + stage_outputs = stage_outputs[0] + if self.config.num_pipeline_repeats > 1: + stage_outputs = stage_outputs[0] + broadcasted_stage_outpus = jax.lax.broadcast( + stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] + ) + + return jnp.reshape( + broadcasted_stage_outpus, + [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + out_sharding=self.output_sharding, + ) + + +class Pipeline(PipelineBaseLinen): + """Original Pipeline implementation.""" + + def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): + """Fetches the weights for the current repeat from the stages.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + 
circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + + # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, + # only one circular entry per stage. + weights = meta.remove_axis(weights, 0, circular_metadata_params) + weights = pipeline_utils.remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + return jax.tree.map(gather_weights_for_stages_in, weights) + return jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + + def run_one_iteration( + self, + loop_state, + pipeline_weights, + positions, + segment_ids, + deterministic, + model_mode, + decoder_layer_instance, + logical_partition_spec=None, + ): + """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, + and update the loop state. + + Args: + loop_state: Dictionary containing the current state of the pipeline (state_io, shift, etc.) + positions: Positional encodings. + segment_ids: Segment IDs for packed sequences. + deterministic: Boolean indicating if execution should be deterministic (e.g. for dropout). + model_mode: Current model mode (train/predict). + logical_partition_spec: Logical partition specification for weights. 
+ """ + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func_for_iterations() + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): + circular_metadata_params = { + nn.PARTITION_NAME: "circular_repeats", + "sub_weight_split_dims_mapping": (None,), + "is_initializing": self.is_initializing(), + "x_times": self.config.num_pipeline_repeats, + "optimizer_dims_mapping": None, + } + weights = meta.remove_axis(weights, 0, circular_metadata_params) + weights = pipeline_utils.remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + return jax.tree.map(gather_weights_for_stages_in, weights) + return jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) + + prepare_vars_for_main_vmap_partial = functools.partial( + prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec + ) + vmap_func = nn.map_variables( + vmap_func, + mapped_collections=["params", 
"_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], + mutable=True, + trans_in_fn=prepare_vars_for_main_vmap_partial, ) stage_weights = self.get_current_stage_weights( @@ -755,42 +788,7 @@ def gather_weights_for_stages_in(w, spec=None): if self.config.scan_layers: stages_output = stages_output[0] - new_state = self.get_new_loop_state(stages_output, loop_state) - return new_state - - @staticmethod - def get_logical_spec_repeats_removed(full_logical): - """Returns a new logical spec with 'circular_repeats' removed.""" - if full_logical is None: - return None - - def _remove_from_spec(spec): - return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) - - return jax.tree.map(_remove_from_spec, full_logical) - - @staticmethod - def _remove_logically_partition(weights): - """Removes LogicallyPartitioned wrappers from the variables.""" - - def _remove_logically_partition_leaf(v): - return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v - - return jax.tree.map(_remove_logically_partition_leaf, weights, is_leaf=lambda v: isinstance(v, LogicallyPartitioned)) - - def all_gather_over_fsdp(self, variables, logical_partition_spec): - """Gathers FSDP partitioned variables to reconstruct them fully.""" - physical_partition_spec = logical_to_mesh( - logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) - physical_partition_spec_no_fsdp = jax.tree.map( - self._remove_fsdp_from_physical_partition_spec, physical_partition_spec - ) - return jax.tree.map( - lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)), - variables, - physical_partition_spec_no_fsdp, - ) + return self.advance_circular_buffers(stages_output, loop_state) @nn.compact def __call__( @@ -815,12 +813,12 @@ def __call__( ), out_sharding=self.input_sharding, ) + example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) ag_sharding = jax.sharding.NamedSharding(self.mesh, 
jax.sharding.PartitionSpec(None, None)) if positions is not None: - positions = self._maybe_shard_with_name(positions, ag_sharding) - positions = positions.reshape( + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) @@ -830,8 +828,7 @@ def __call__( position_idx = None if segment_ids is not None: - segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding) - segment_ids = segment_ids.reshape( + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) @@ -861,7 +858,7 @@ def __call__( ) if self.config.pipeline_fsdp_ag_once: - variables = self._remove_logically_partition(self.layers.variables) + variables = pipeline_utils.remove_logically_partition(self.layers.variables) all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec) else: all_pipeline_weights = self.layers.variables @@ -902,6 +899,7 @@ def run_iteration_scannable(model, loop_state, xs): variable_carry.append("non_trainable") else: variable_broadcast.append("non_trainable") + run_all_iterations_scanned = nn.scan( run_iteration_scannable, variable_axes={"summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0}, @@ -920,15 +918,14 @@ def run_iteration_scannable(model, loop_state, xs): # the input final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] - final_output = jnp.reshape( + return jnp.reshape( final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), out_sharding=self.output_sharding, ) - 
return final_output -class CircularPipeline(PipelineBase): +class CircularPipeline(PipelineBaseLinen): """Implements an circular pipeline schedule with asynchronous weight prefetching. Circular pipelining reduces the pipeline "bubble" by interleaving multiple pipeline @@ -945,26 +942,7 @@ def init_states(self, inputs): (`state_io` and `shift`) and allocates the empty Buffer Sliding Window (BSW) that will hold the gathered FSDP weights. """ - shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) - - if self.config.pipeline_delay_activation_forwarding: - prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - prev_outputs = self._maybe_shard_with_logical(prev_outputs, self.stages_in_logical) - else: - prev_outputs = None - - state_io = jnp.reshape( - inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding - ) - state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical) - - if self.use_circ_storage: - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding) - circ_storage_mover = shift - else: - circ_storage = None - circ_storage_mover = None + init_loop_state = super().init_states(inputs) def _init_empty_bsw_buffers(variables): # BSW requires two buffers (current and next) for the sliding window @@ -979,143 +957,27 @@ def _init_empty_bsw_buffers(variables): variables = pipeline_utils.remove_logically_partition(self.layers.variables) bsw = _init_empty_bsw_buffers(variables) - init_loop_state = { - "state_io": state_io, - "shift": shift, - "circ_storage": circ_storage, - "circ_storage_mover": circ_storage_mover, - "loop_iteration": 0, - "prev_outputs": prev_outputs, - } return init_loop_state, bsw - def gather_weights_across_stages_vmap(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): - 
"""Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.""" + def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): + """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.""" + xs = jnp.asarray(xs) # Safe casting for non-JAX arrays + ndim = xs.ndim - def _gather_single_repeat(x, repeat_id): - return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) - - gathered_weights_stage_dim = 0 - stage_weights = jax.vmap( - _gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim - )(weights, repeat_ids) - return stage_weights - - def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): - """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.""" - - def _gather_one(x, i): - idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + def _gather_one(x, i): + idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) + replicated_sharding = NamedSharding(self.mesh, P()) + return x.at[idx].get(out_sharding=replicated_sharding) return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - def advance_circular_buffers(self, output, loop_state): - """Rotates pipeline activations to the next physical device stage. - - Uses `jax.lax.ppermute` to perform cross-device ring communication, shifting - the forward activations (`state_io` and `shift`) from stage $i$ to stage $i+1$. 
- """ - old_state_io = loop_state["state_io"] - old_circ_storage = loop_state["circ_storage"] - old_circ_storage_mover = loop_state["circ_storage_mover"] - loop_iteration = loop_state["loop_iteration"] - - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) - def _rotate_right(arr): - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - return jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - @jax.shard_map(mesh=self.mesh, in_specs=self.stages_in_spec, out_specs=self.stages_in_spec, check_vma=True) - def _shift_right(arr): - stage_idx = jax.lax.axis_index("stage") - stage_size = jax.lax.axis_size("stage") - perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] - arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) - return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) - - def _update_shift(output_in): - if self.config.num_pipeline_repeats == 1 or self.use_circ_storage: - return _shift_right(output_in) - else: - return _rotate_right(output_in) - - new_shift = _update_shift(output) - new_prev_outputs = None - - if self.use_circ_storage: - - def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): - rotated = _rotate_right(circ_storage_mover_in) - rotated = jnp.expand_dims(rotated, 1) - offset = ( - loop_iteration - self.iterations_to_complete_first_microbatch_one_repeat() - 1 - ) % self.config.num_pipeline_microbatches - return jax.lax.dynamic_update_slice_in_dim(circ_storage_in, rotated, offset, axis=1) - - new_circ_storage = _rotate_right_and_update(old_circ_storage_mover, old_circ_storage) - new_circ_storage_mover = output - else: - new_circ_storage = None - new_circ_storage_mover = None - - stream_buf_idx = loop_iteration % self.microbatches_per_stage - stream_slice = old_state_io[:, stream_buf_idx] - - def _rotate_left(arr, stage_size): - perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] - return 
jax.lax.ppermute(arr, axis_name="stage", perm=perm) - - def _shift_left(arr, stage_size, output): - stage_idx = jax.lax.axis_index("stage") - arr = _rotate_left(arr, stage_size) - return jnp.where(stage_idx == stage_size - 1, output, arr) - - @jax.shard_map( - mesh=self.mesh, - in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), - out_specs=self.state_io_spec, - check_vma=True, - ) - def _update_state_io(state_in, stream_slice, output, stream_buf_idx): - stage_size = jax.lax.axis_size("stage") - stream_slice = _shift_left(stream_slice, stage_size, output) - stream_slice = jnp.expand_dims(stream_slice, 1) - return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - - new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) - new_loop_state = { - "state_io": new_state, - "shift": new_shift, - "circ_storage": new_circ_storage, - "circ_storage_mover": new_circ_storage_mover, - "loop_iteration": loop_iteration + 1, - "prev_outputs": new_prev_outputs, - } - return new_loop_state - - def realign_output_microbatches(self, output): - """Reorders the output tensor to reverse the circular shifts applied during execution. - - Because the pipeline operates circularly, the output microbatches are shifted - out of order by the time the final stage is completed. This rolls them back - into their original sequential layout. - """ - microbatch_0_idx = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage - output = jnp.roll(output, shift=-microbatch_0_idx, axis=1) - output = self._maybe_shard_with_logical(output, self.state_io_logical) - return output - def fetch_active_stage_weights(self, bsw, loop_iteration, physical_partition_spec=None, is_initializing=None): """The module fetches the actively prefetched weights from the Buffer Sliding Window to avoid mid-iteration FSDP all-gathers. 
""" - pipeline_weights = self.get_current_weights_from_bsw( + return self.get_current_weights_from_bsw( bsw, loop_iteration, physical_partition_spec=physical_partition_spec, is_initializing=is_initializing ) - return pipeline_weights def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec, is_initializing=None): """Pulls the fully gathered parameters for the current repeat from the BSW dual-buffer.""" @@ -1139,8 +1001,19 @@ def select_weights_from_bsw(bsw, repeat_id): "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } - weights = meta.remove_axis(weights, 0, circular_metadata_params) - return weights + return meta.remove_axis(weights, 0, circular_metadata_params) + + def gather_weights_across_stages_vmap(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.""" + + def _gather_single_repeat(x, repeat_id): + return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) + + gathered_weights_stage_dim = 0 + stage_weights = jax.vmap( + _gather_single_repeat, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim + )(weights, repeat_ids) + return stage_weights def from_all_variables_to_repeat_weights(self, weights, loop_iteration): """Gathers weights corresponding to the repeat IDs for current iteration.""" @@ -1161,8 +1034,7 @@ def gather_weights_for_stages_in(w): "x_times": self.config.num_pipeline_repeats, "optimizer_dims_mapping": None, } - repeat_weights = meta.remove_axis(weights, 0, circular_metadata_params) - return repeat_weights + return meta.remove_axis(weights, 0, circular_metadata_params) def from_repeat_weights_to_bsw( self, @@ -1175,18 +1047,13 @@ def from_repeat_weights_to_bsw( axes_to_remove = ["fsdp", "fsdp_transpose"] bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove) - def 
_from_repeat_weights_to_bsw_shardmap( - repeat_weights, - physical_partition_spec, - axes_to_gather, - ): + def _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather): repeat_weights_pps = jax.tree.map(lambda p: P(*p[1:]), physical_partition_spec) # Dynamically gather the index pytrees for all specified axes axis_indices_dict = { axis: pipeline_utils.get_mesh_axis_dim_indices(physical_partition_spec, axis) for axis in axes_to_gather } - axis_names = list(axis_indices_dict.keys()) axis_pytrees = list(axis_indices_dict.values()) @@ -1205,7 +1072,6 @@ def should_skip_gather(axis_name, path_keys): check_vma=False, ) def _shard_map_gather_weights(sharded_weights, indices_pytrees_list): - # Renamed to clarify we are gathering a single tensor iteratively along requested axes def _gather_tensor_along_axes(path, x, *indices): path_keys = [getattr(p, "key", str(p)) for p in path] @@ -1220,9 +1086,7 @@ def _gather_tensor_along_axes(path, x, *indices): return _shard_map_gather_weights(repeat_weights, axis_pytrees) - def _from_repeat_weights_to_bsw_hint( - repeat_weights, - ): + def _from_repeat_weights_to_bsw_hint(repeat_weights): def _apply_sharding_hint(weight, pspec): sharding_name = NamedSharding(self.mesh, pspec) return maybe_shard_with_name( @@ -1290,8 +1154,7 @@ def run_one_iteration( if self.config.scan_layers: stages_output = stages_output[0] - new_state = self.advance_circular_buffers(stages_output, loop_state) - return new_state + return self.advance_circular_buffers(stages_output, loop_state) @nn.compact def __call__( @@ -1317,8 +1180,7 @@ def __call__( ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) if positions is not None: - positions = self._maybe_shard_with_name(positions, ag_sharding) - positions = positions.reshape( + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, 
self.config.max_target_length) ) example_position = jax.lax.broadcast(positions[0], [self.num_stages]) @@ -1328,8 +1190,7 @@ def __call__( position_idx = None if segment_ids is not None: - segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding) - segment_ids = segment_ids.reshape( + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) example_segmentation = jax.lax.broadcast(segment_ids[0], [self.num_stages]) @@ -1354,19 +1215,17 @@ def __call__( logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec) def run_iteration_scannable(model, loop_state, bsw, weights): - return ( - model.run_one_iteration( - loop_state, - bsw, - weights, - positions, - segment_ids, - deterministic, - model_mode, - logical_partition_spec=logical_partition_spec, - ), - None, + new_loop_state = model.run_one_iteration( + loop_state, + bsw, + weights, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec=logical_partition_spec, ) + return (new_loop_state, bsw, weights), None if self.config.set_remat_policy_on_pipeline_iterations: run_iteration_scannable = nn.remat( @@ -1374,7 +1233,6 @@ def run_iteration_scannable(model, loop_state, bsw, weights): prevent_cse=not self.config.scan_pipeline_iterations, policy=self.get_pipeline_remat_policy(), ) - # base scannable function used twice for real and bubble runs base_scannable = functools.partial( pipeline_utils.create_rematerialized_pipeline_stage, @@ -1405,18 +1263,636 @@ def run_iteration_scannable(model, loop_state, bsw, weights): (loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights) final_output = self.realign_output_microbatches(loop_state["state_io"]) - final_output = jnp.reshape( + return jnp.reshape( final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, 
self.config.emb_dim), out_sharding=self.output_sharding, ) - return final_output -def create_pipeline(config: Config, layers: nn.Module, mesh: Mesh, remat_policy: Any = None) -> PipelineBase: +def create_pipeline( + config: Config, layers: nn.Module, mesh: Mesh, remat_policy: Any = None +) -> Pipeline | CircularPipeline: """Factory function to instantiate the correct Pipeline module based on config.""" - if config.pipeline_fsdp_ag_per_repeat: return CircularPipeline(config=config, layers=layers, mesh=mesh, remat_policy=remat_policy) - return Pipeline(config=config, layers=layers, mesh=mesh, remat_policy=remat_policy) + + +class NNXPipelineBase(nnx.Module, PipelineSharedMixin): + """Base module that implements shared pipelining logic across stages for NNX.""" + + def __init__( + self, + config: Config, + stage_factory: Any, + mesh: Mesh, + remat_policy: Any = None, + *, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.remat_policy = remat_policy + self._setup_pipeline_attributes() + + def build_batched_rngs(shape): + max_logging.log(f"Building batched RNGs with shape {shape}...") + kwargs = {} + rng_state = nnx.state(rngs, nnx.RngState) + leaves, _ = jax.tree_util.tree_flatten_with_path(rng_state) + for path, key in leaves: + stream_name = getattr(path[0], "key", str(path[0])) + if not jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key): + key = jax.random.key(key) + num_splits = int(np.prod(shape)) + flat_keys = jax.random.split(key, num_splits) + kwargs[stream_name] = flat_keys.reshape(shape + key.shape) + return nnx.Rngs(**kwargs) + + def create_stage_fn(r): + stage = stage_factory(r) + # Split into (GraphDef, Param State, Rest of State) + return nnx.split(stage, nnx.Param, ...) 
+ + vmap_stages = nnx.vmap( + create_stage_fn, + in_axes=0, + out_axes=(None, 0, 0), + axis_name=self.spmd_axis_name, + transform_metadata={nnx.PARTITION_NAME: "layers"}, + ) + + if self.config.num_pipeline_repeats > 1: + vmap_repeats = nnx.vmap( + vmap_stages, + in_axes=0, + out_axes=(None, 0, 0), + transform_metadata={nnx.PARTITION_NAME: "circular_repeats"}, + ) + batched_rngs = build_batched_rngs((self.config.num_pipeline_repeats, self.num_stages)) + graphdef, params, rest = vmap_repeats(batched_rngs) + else: + batched_rngs = build_batched_rngs((self.num_stages,)) + graphdef, params, rest = vmap_stages(batched_rngs) + + # Merge the batched states back into the module + self.layers = nnx.merge(graphdef, params, rest) + + def get_weight_sharding(self, *init_args): + """get weight sharding function for this pipeline.""" + state = nnx.state(self.layers) + + def get_spec(x): + if not isinstance(x, nnx.VariableState): + return None + if isinstance(x.value, nn.spmd.LogicallyPartitioned): + return x.value.partitions + metadata = x.get_metadata() + sharding = metadata.get("sharding") + if sharding and hasattr(sharding, "spec"): + return sharding.spec + return None + + return jax.tree.map(get_spec, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) + + def get_main_vmap_func_for_iterations(self): + def func_to_vmap(graph, state, stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode): + module = nnx.merge(graph, state) + out = module(stages_inputs, stages_segment_ids, stages_positions, deterministic, model_mode) + return out, nnx.state(module) + + return nnx.vmap( + func_to_vmap, + in_axes=(None, 0, 0, 0, 0, None, None), + out_axes=(0, 0), + axis_name=self.spmd_axis_name, + ) + + +class NNXPipeline(NNXPipelineBase): + """Original Pipeline implementation adapted for NNX.""" + + def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): + if self.config.num_pipeline_repeats > 1: + return 
self.get_current_repeat_from_stages( + pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + ) + return pipeline_weights + + def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): + """Fetches the weights for the current repeat from the stages.""" + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def gather_weights_for_stages_in(w, spec=None): + if w is None: + return None + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + return jax.tree.map(gather_weights_for_stages_in, weights) + + # Flatten specs to a list aligned with weights' leaf traversal order. + # Single-tree map on weights lets jax.tree.map naturally recurse into + # nnx.Variable nodes and create NEW Variables from results (no mutation), + # avoiding TraceContextError inside jax.lax.scan. + is_spec_leaf = lambda x: isinstance(x, P) or x is None + spec_leaves = jax.tree_util.tree_leaves(physical_partition_spec, is_leaf=is_spec_leaf) + spec_iter = iter(spec_leaves) + return jax.tree.map(lambda w: gather_weights_for_stages_in(w, next(spec_iter)), weights) + + def run_one_iteration( + self, + loop_state, + pipeline_weights_graph, + pipeline_weights_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec=None, + ): + """Executes the logic for a single microbatch iteration, including routing inputs and weights, and advancing buffers.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + + stages_inputs = self.get_iteration_inputs(loop_iteration, 
state_io, circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + stages_positions = self.vmap_gather(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = self.vmap_gather(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + + vmap_func = self.get_main_vmap_func_for_iterations() + + stage_weights_state = self.get_current_stage_weights( + pipeline_weights_state, loop_iteration, physical_partition_spec=physical_partition_spec + ) + + stages_output, updated_stage_weights_state = vmap_func( + pipeline_weights_graph, + stage_weights_state, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ) + + if self.config.scan_layers: + stages_output = stages_output[0] + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def _scatter_update(fw, uw, spec=None): + if fw is None or uw is None: + return fw + + def _update_one_stage(f_s, u_s, r_id): + return jax.lax.dynamic_update_slice_in_dim(f_s, jnp.expand_dims(u_s, 0), r_id, axis=0) + + r_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + updated_fw = jax.vmap(_update_one_stage, in_axes=(1, 0, 0), out_axes=1)(fw, uw, r_ids) + return self.shard_dim_by_stages(updated_fw, 1, physical_partition_spec=spec, is_stage_weight=False) + + pipeline_weights_state = jax.tree.map(_scatter_update, pipeline_weights_state, updated_stage_weights_state) + else: + pipeline_weights_state = updated_stage_weights_state + + new_state = self.advance_circular_buffers(stages_output, loop_state) + return new_state, pipeline_weights_state + + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: jnp.ndarray, + positions: jnp.ndarray, + deterministic: bool, + model_mode=MODEL_MODE_TRAIN, + logical_partition_spec=None, + ) -> jnp.ndarray: + inputs = inputs.reshape( + ( + self.config.num_pipeline_microbatches, + 
self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ), + out_sharding=self.input_sharding, + ) + ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) + if positions is not None: + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + if segment_ids is not None: + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + loop_state = self.init_states(inputs) + + bubble_iterations = self.forwarding_delay * (self.num_stages - 1) + real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + total_iterations = real_iterations + bubble_iterations + + logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec) + + layers_graph, layers_state = nnx.split(self.layers) + + def is_lp(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) + + def unbox_val(x): + return x.value if is_lp(x) else x + + layers_state = jax.tree.map(unbox_val, layers_state, is_leaf=is_lp) + + if self.config.pipeline_fsdp_ag_once: + layers_state = self.all_gather_over_fsdp(layers_state, logical_partition_spec) + + def is_static_param(path, v): + return isinstance(v, nnx.Param) or type(v).__name__ == "_overwrite_with_gradient" + + _, layers_params, layers_metrics, layers_mutables = nnx.split(layers_state, is_static_param, nnx.Intermediate, ...) 
+ + def scan_body(carry, _): + current_loop_state, current_layer_mutables = carry + current_layer_state = nnx.State.merge(layers_params, layers_metrics, current_layer_mutables) + + new_loop_state, new_layer_state = self.run_one_iteration( + current_loop_state, + layers_graph, + current_layer_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec, + ) + + _, _, new_layer_metrics, new_layer_mutables = nnx.split(new_layer_state, is_static_param, nnx.Intermediate, ...) + return (new_loop_state, new_layer_mutables), new_layer_metrics + + if self.config.set_remat_policy_on_pipeline_iterations: + scan_body = jax.checkpoint( + scan_body, policy=self.get_pipeline_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations + ) + + if self.config.scan_pipeline_iterations: + (loop_state, final_layer_mutables), stacked_metrics = jax.lax.scan( + scan_body, (loop_state, layers_mutables), None, length=total_iterations + ) + else: + current_carry = (loop_state, layers_mutables) + metrics_history = [] + for _ in range(total_iterations): + current_carry, step_metrics = scan_body(current_carry, None) + metrics_history.append(step_metrics) + loop_state, final_layer_mutables = current_carry + stacked_metrics = jax.tree.map(lambda *xs: jnp.stack(xs), *metrics_history) if metrics_history else layers_metrics + + final_layer_state = nnx.State.merge(layers_params, stacked_metrics, final_layer_mutables) + self.layers = nnx.merge(layers_graph, final_layer_state) + + final_output = self.permute_output_micro_per_stage_dim(loop_state["state_io"]) + return jnp.reshape( + final_output, + (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), + out_sharding=self.output_sharding, + ) + + +class NNXCircularPipeline(NNXPipelineBase): + """NNX Implementation of a circular pipeline schedule with asynchronous weight prefetching. 
+ + Inherits directly from NNXPipelineBase to leverage its native nnx.vmap setup for + pipeline variables (stages and circular_repeats). Uses pure jax.lax.scan and + jax.checkpoint for execution to avoid banned NNX execution transforms. + """ + + def init_states(self, inputs): + """Initializes pipeline execution state and empty BSW buffers.""" + loop_state = super().init_states(inputs) + + weights = nnx.state(self.layers) + + def get_single_repeat_shape(x): + if x is None: + return None + return jnp.zeros_like(x[0]) if self.config.num_pipeline_repeats > 1 else jnp.zeros_like(x) + + bsw = ( + jax.tree.map(get_single_repeat_shape, weights), + jax.tree.map(get_single_repeat_shape, weights), + ) + + return loop_state, bsw + + def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): + """Slices out the specific sequence inputs (e.g., positions, segments) for the current microbatch.""" + if xs is None: + return None + + xs = jnp.asarray(xs) + ndim = xs.ndim + + def _gather_one(x, i): + idx = tuple(i if d == ids_dim else slice(None) for d in range(ndim)) + replicated_sharding = NamedSharding(self.mesh, P()) + return x.at[idx].get(out_sharding=replicated_sharding) + + return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) + + def gather_weights_across_stages_vmap(self, weights_state, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights): + """Uses jax.vmap to dynamically slice and gather weights for specific pipeline repeats.""" + + def _gather_repeat_leaf(w_leaf, rep_id): + if w_leaf is None: + return None + return jnp.squeeze( + jax.lax.dynamic_slice_in_dim(w_leaf, rep_id, 1, axis=repeat_dim_in_weights), axis=repeat_dim_in_weights + ) + + vmap_gather = jax.vmap(_gather_repeat_leaf, in_axes=(stages_dim_in_weights, 0), out_axes=0) + return jax.tree.map(lambda w: vmap_gather(w, repeat_ids) if w is not None else None, weights_state) + + def from_all_variables_to_repeat_weights(self, weights_state, loop_iteration): + """Slices out the specific 
repeat's weights from the full weights state.""" + if self.config.num_pipeline_repeats == 1: + return weights_state + + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + return self.gather_weights_across_stages_vmap( + weights_state, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 + ) + + def from_repeat_weights_to_bsw(self, repeat_weights, physical_partition_spec): + """Executes FSDP-like all-gathers to fully materialize a block of weights for BSW.""" + axes_to_remove = ["fsdp", "fsdp_transpose"] + if physical_partition_spec is not None: + bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove) + else: + bsw_pps = None + + def _apply_sharding_hint(weight, pspec): + if pspec is None or weight is None: + return weight + sharding_name = NamedSharding(self.mesh, pspec) + return maybe_shard_with_name( + weight, + sharding_name, + shard_mode=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, + extra_stack_level=0, + ) + + if bsw_pps is None: + return repeat_weights + + # Flatten specs to a list aligned with repeat_weights' leaf traversal order. + # Single-tree map avoids nnx.Variable mutation (TraceContextError inside scan). 
+ is_spec_leaf = lambda x: isinstance(x, P) or x is None + spec_leaves = jax.tree_util.tree_leaves(bsw_pps, is_leaf=is_spec_leaf) + spec_iter = iter(spec_leaves) + return jax.tree.map(lambda w: _apply_sharding_hint(w, next(spec_iter)), repeat_weights) + + def weight_prefetching(self, weights_state, physical_partition_spec, loop_iteration): + """Triggers asynchronous FSDP-like all-gathers for current and next pipeline steps.""" + cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights_state, loop_iteration) + nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights_state, loop_iteration + 1) + bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec) + bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec) + return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") + + def fetch_active_stage_weights(self, bsw, loop_iteration, physical_partition_spec=None): + """Fetches the actively prefetched weights + from the Buffer Sliding Window to avoid mid-iteration FSDP all-gathers. 
+ """ + return self.get_current_weights_from_bsw(bsw, loop_iteration, physical_partition_spec) + + def get_current_weights_from_bsw(self, bsw, loop_iteration, physical_partition_spec): + """Pulls the fully gathered parameters for the current repeat from the BSW dual-buffer.""" + bsw_pps = jax.tree.map(self._remove_fsdp_from_physical_partition_spec, physical_partition_spec) + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + stage0_repeat_id = jnp.maximum(loop_iteration, 0) // self.config.num_pipeline_microbatches + + if bsw_pps is not None: + + @jax.shard_map(mesh=self.mesh, in_specs=((bsw_pps, bsw_pps), P("stage")), out_specs=bsw_pps, check_vma=True) + def select_weights_from_bsw(bsw_inner, repeat_id): + return jax.tree.map( + lambda x, y: jax.lax.select(repeat_id[0] == stage0_repeat_id, y, x) if x is not None else None, + bsw_inner[0], + bsw_inner[1], + ) + + weights = select_weights_from_bsw(bsw, repeat_ids) + else: + + def select_weights_from_bsw(bsw_inner, repeat_id): + return jax.tree.map( + lambda x, y: jax.lax.select(repeat_id == stage0_repeat_id, y, x) if x is not None else None, + bsw_inner[0], + bsw_inner[1], + ) + + weights = jax.vmap(select_weights_from_bsw, in_axes=((0, 0), 0), out_axes=0)(bsw, repeat_ids) + + return weights + + def run_one_iteration( + self, + loop_state, + bsw, + pipeline_weights_graph, + pipeline_weights_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec, + ): + """Executes the forward/backward logic for a single microbatch inside the circular pipeline.""" + state_io = loop_state["state_io"] + shift = loop_state["shift"] + circ_storage = loop_state["circ_storage"] + loop_iteration = loop_state["loop_iteration"] + + microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, 
circ_storage, shift) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + + stages_positions = self.gather_microbatch_inputs_vmap(positions, microbatch_ids, 0) if positions is not None else None + stages_segment_ids = ( + self.gather_microbatch_inputs_vmap(segment_ids, microbatch_ids, 0) if segment_ids is not None else None + ) + + vmap_func = self.get_main_vmap_func_for_iterations() + stage_weights_state = self.fetch_active_stage_weights( + bsw, + loop_iteration, + physical_partition_spec=physical_partition_spec, + ) + + stages_output, updated_stage_weights_state = vmap_func( + pipeline_weights_graph, + stage_weights_state, + stages_inputs, + stages_segment_ids, + stages_positions, + deterministic, + model_mode, + ) + + if self.config.scan_layers: + stages_output = stages_output[0] + + if self.config.num_pipeline_repeats > 1: + _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) + + def _scatter_update(fw, uw, spec=None): + if fw is None or uw is None: + return fw + + def _update_one_stage(f_s, u_s, r_id): + return jax.lax.dynamic_update_slice_in_dim(f_s, jnp.expand_dims(u_s, 0), r_id, axis=0) + + r_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + updated_fw = jax.vmap(_update_one_stage, in_axes=(1, 0, 0), out_axes=1)(fw, uw, r_ids) + return self.shard_dim_by_stages(updated_fw, 1, physical_partition_spec=spec, is_stage_weight=False) + + pipeline_weights_state = jax.tree.map(_scatter_update, pipeline_weights_state, updated_stage_weights_state) + else: + pipeline_weights_state = updated_stage_weights_state + + new_state = self.advance_circular_buffers(stages_output, loop_state) + return new_state, pipeline_weights_state + + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: jnp.ndarray, + positions: jnp.ndarray, + deterministic: bool, + model_mode=MODEL_MODE_TRAIN, + logical_partition_spec=None, + ) -> jnp.ndarray: + inputs = inputs.reshape( + ( + 
self.config.num_pipeline_microbatches, + self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ), + out_sharding=self.input_sharding, + ) + + ag_sharding = NamedSharding(self.mesh, P(None, None)) + if positions is not None: + positions = self._maybe_shard_with_name(positions, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + if segment_ids is not None: + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + loop_state, bsw = self.init_states(inputs) + + physical_partition_spec = logical_to_mesh( + logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules + ) + + bubble_iterations = self.forwarding_delay * (self.num_stages - 1) + real_iterations = self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats + total_iterations = real_iterations + bubble_iterations + + layers_graph, layers_state = nnx.split(self.layers) + + def is_lp(x): + return isinstance(x, nn.spmd.LogicallyPartitioned) + + def unbox_val(x): + return x.value if is_lp(x) else x + + layers_state = jax.tree.map(unbox_val, layers_state, is_leaf=is_lp) + + def is_static_param(path, v): + return isinstance(v, nnx.Param) or type(v).__name__ == "_overwrite_with_gradient" + + _, layers_params, layers_metrics, layers_mutables = nnx.split(layers_state, is_static_param, nnx.Intermediate, ...) + + def scan_body(carry, _): + current_loop_state, _, current_layer_mutables = carry + current_layer_state = nnx.State.merge(layers_params, layers_metrics, current_layer_mutables) + + # 1. Async FSDP Prefetch into Buffer Sliding Window + next_bsw = self.weight_prefetching( + current_layer_state, physical_partition_spec, current_loop_state["loop_iteration"] + ) + + # 2. 
Run Forward & State Shift + new_loop_state, new_layer_state = self.run_one_iteration( + current_loop_state, + next_bsw, + layers_graph, + current_layer_state, + positions, + segment_ids, + deterministic, + model_mode, + logical_partition_spec, + ) + + _, _, new_layer_metrics, new_layer_mutables = nnx.split(new_layer_state, is_static_param, nnx.Intermediate, ...) + return (new_loop_state, next_bsw, new_layer_mutables), new_layer_metrics + + if self.config.set_remat_policy_on_pipeline_iterations: + scan_body = jax.checkpoint( + scan_body, policy=self.get_pipeline_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations + ) + + # Memory Efficient Execution via pure JAX scan + if self.config.scan_pipeline_iterations: + (loop_state, bsw, final_layer_mutables), stacked_metrics = jax.lax.scan( + scan_body, (loop_state, bsw, layers_mutables), None, length=total_iterations + ) + else: + current_carry = (loop_state, bsw, layers_mutables) + metrics_history = [] + for _ in range(total_iterations): + current_carry, step_metrics = scan_body(current_carry, None) + metrics_history.append(step_metrics) + loop_state, bsw, final_layer_mutables = current_carry + stacked_metrics = jax.tree.map(lambda *xs: jnp.stack(xs), *metrics_history) if metrics_history else layers_metrics + + final_layer_state = nnx.State.merge(layers_params, stacked_metrics, final_layer_mutables) + nnx.update(self.layers, final_layer_state) + + final_output = self.realign_output_microbatches(loop_state["state_io"]) + return jnp.reshape( + final_output, + (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), + out_sharding=self.output_sharding, + ) + + +def create_nnx_pipeline( + config: Config, stage_factory: Any, mesh: Mesh, remat_policy: Any = None, *, rngs: nnx.Rngs +) -> NNXPipeline | NNXCircularPipeline: + """Factory function to instantiate the NNX Pipeline module.""" + if config.pipeline_fsdp_ag_per_repeat: + return NNXCircularPipeline( + config=config, 
stage_factory=stage_factory, mesh=mesh, remat_policy=remat_policy, rngs=rngs + ) + return NNXPipeline(config=config, stage_factory=stage_factory, mesh=mesh, remat_policy=remat_policy, rngs=rngs)