diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 77751479ce..73f534d2ae 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -567,10 +567,12 @@ num_slices: -1 # Vocab Tiling Configs # Enables a memory-saving optimization by computing the cross-entropy loss in chunks. -# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis, -# reducing peak memory usage. This is highly recommended for models with large -# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable. +# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis, +# and `num_of_batch_tiling` parts along the batch-sequence axis, reducing peak memory usage. +# This is highly recommended for models with large vocabularies (e.g., Gemma). +# Set to a value greater than 1 to enable. num_vocab_tiling: 1 +num_of_batch_tiling: 1 # Tokenizer vocab_size: 32_000 # powers of 2 for sharding diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 4ede985e65..5527c502b6 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -194,9 +194,20 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) - ) -def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool): +def validate_vocab_tiling( + num_vocab_tiling: int, + num_of_batch_tiling: int, + per_device_batch_size: int, + max_target_length: int, + enable_nnx: bool, +): if (per_device_batch_size * max_target_length) % num_vocab_tiling != 0: raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") + if (per_device_batch_size * max_target_length) % num_of_batch_tiling != 0: + raise ValueError( + "Per device batch size times sequence length should be divisible by the" + " number of batch tiles." + ) if num_vocab_tiling > 1 and enable_nnx: # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration raise ValueError("We currently don't support vocab tiling on NNX module.") diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 395b4dba06..6985dab2fc 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -953,7 +953,17 @@ class Tokenizer(BaseModel): ) num_vocab_tiling: int = Field( 1, - description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.", + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the vocabulary axis. >1 to enable." + ), + ) + num_of_batch_tiling: int = Field( + 1, + description=( + "Enables memory-saving optimization by tiling cross-entropy loss" + " computation along the batch-sequence axis. >1 to enable." + ), ) @@ -2461,6 +2471,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0 ): raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.") + if ( + self.per_device_batch_size > 0 + and (self.per_device_batch_size * self.max_target_length) + % self.num_of_batch_tiling + != 0 + ): + raise ValueError( + "Per device batch size times sequence length should be divisible by" + " the number of batch tiles." + ) if self.num_vocab_tiling > 1 and self.enable_nnx: raise ValueError("We currently don't support vocab tiling on NNX module.") if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring": diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index ce69b7a396..34ff491160 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -678,9 +678,8 @@ def _apply_embedding( return y @nn.compact - def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): - """Applies final normalization and projects hidden states to logits.""" - + def normalize_hidden_states(self, y, deterministic, model_mode): + """Applies final normalization and dropout to hidden states.""" cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) @@ -696,6 +695,20 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi parameter_memory_host_offload=cfg.parameter_memory_host_offload, )(y, out_sharding=norm_out_sharding) y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + return y + + @nn.compact + def apply_output_head( + self, + shared_embedding: nn.Module | nnx.Module, + y, + deterministic, + model_mode, + ): + """Applies final normalization and projects hidden states to logits.""" + + cfg = self.config + y = self.normalize_hidden_states(y, deterministic, model_mode) if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index 0d1fcab700..2d32ad3af6 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -118,6 +118,14 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode): ) return logits + def normalize_hidden_states(self, hidden_states, deterministic, model_mode): + """Normalize hidden states (wrapping decoder.normalize_hidden_states).""" + return self.decoder.normalize_hidden_states( + y=hidden_states, + deterministic=deterministic, + model_mode=model_mode, + ) + def __call__( self, decoder_input_tokens: jnp.ndarray, diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py index ec68e9bc78..4cfe37341f 100644 --- a/src/maxtext/utils/vocabulary_tiling.py +++ b/src/maxtext/utils/vocabulary_tiling.py @@ -115,6 +115,153 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat (total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation) return total_loss, total_z_loss + def _b_v_chunked_cross_entropy_loss_fwd( + gathered_params, hidden_states, labels, segmentation + ): + batch_size, seq_len, emb_dim = hidden_states.shape + v_dim = config.vocab_size + + b_dim = batch_size * seq_len + b_block_sz = b_dim // config.num_of_batch_tiling + v_block_sz = v_dim // config.num_vocab_tiling + + if b_dim % b_block_sz != 0 or v_dim % v_block_sz != 0: + raise ValueError( + "Batch/sequence dimension and vocab dimension must be divisible by" + " their block sizes." + ) + + num_b_blocks = b_dim // b_block_sz + num_v_blocks = v_dim // v_block_sz + + flat_hidden = _reshape( + hidden_states, + (b_dim, emb_dim), + create_sharding( + model.mesh, + ("activation_embed_and_logits_batch_sequence", "activation_embed"), + ), + ) + flat_labels = _reshape( + labels, + (b_dim,), + create_sharding( + model.mesh, ("activation_embed_and_logits_batch_sequence",) + ), + ) + flat_segmentation = _reshape( + segmentation, + (b_dim,), + create_sharding( + model.mesh, ("activation_embed_and_logits_batch_sequence",) + ), + ) + + if config.logits_via_embedding: + w = gathered_params["params"]["shared_embedding"]["embedding"] + else: + w = gathered_params["params"]["decoder"]["logits_dense"]["kernel"] + + def b_loop_body(i, carry): + total_loss, total_z_loss = carry + b_start = i * b_block_sz + + def v_loop_body(j, v_carry): + lse_b_, b_loss_sum_neg_logits_ = v_carry + v_start = j * v_block_sz + labels_b = jax.lax.dynamic_slice(flat_labels, (b_start,), (b_block_sz,)) + x_b = jax.lax.dynamic_slice( + flat_hidden, (b_start, 0), (b_block_sz, emb_dim) + ) + + # Apply normalization to the batch block + x_b_norm = model.apply( + {"params": gathered_params["params"]}, + x_b, + deterministic=deterministic, + method="normalize_hidden_states", + ) + x_b_norm = _maybe_shard_with_name(x_b_norm, chunked_hidden_spec) + + # Extract w_j + if config.logits_via_embedding: + # Attend on embedding table. Table is (vocab_size, emb_dim) + # Transpose to (emb_dim, vocab_size) + w_j = jax.lax.dynamic_slice(w.T, (0, v_start), (emb_dim, v_block_sz)) + else: + w_j = jax.lax.dynamic_slice(w, (0, v_start), (emb_dim, v_block_sz)) + + # Compute logits for the block + logits_bv = jnp.dot(x_b_norm, w_j) + + if config.logits_via_embedding and config.normalize_embedding_logits: + logits_bv = logits_bv / jnp.sqrt(emb_dim) + if config.final_logits_soft_cap: + logits_bv = logits_bv / config.final_logits_soft_cap + logits_bv = jnp.tanh(logits_bv) * config.final_logits_soft_cap + + if config.cast_logits_to_fp32: + logits_bv = logits_bv.astype(jnp.float32) + + lse_b__ = jnp.logaddexp(lse_b_, jax.nn.logsumexp(logits_bv, axis=-1)) + + labels_one_hot = jax.nn.one_hot( + labels_b - v_start, v_block_sz, dtype=logits_bv.dtype + ) + b_loss_sum_neg_logits__ = b_loss_sum_neg_logits_ - jnp.sum( + logits_bv * labels_one_hot, axis=-1 + ) + return lse_b__, b_loss_sum_neg_logits__ + + lse_b, b_loss_sum_neg_logits = jax.lax.fori_loop( + 0, + num_v_blocks, + v_loop_body, + ( + jnp.full((b_block_sz,), -jnp.inf, dtype=jnp.float32), + jnp.zeros((b_block_sz,), dtype=jnp.float32), + ), + ) + + segmentation_b = jax.lax.dynamic_slice( + flat_segmentation, (b_start,), (b_block_sz,) + ) + mask = (segmentation_b != 0).astype(jnp.float32) + + # Z-loss + z_loss_b = config.z_loss_multiplier * jnp.square(lse_b) * mask + total_z_loss += jnp.sum(z_loss_b) + + b_loss_sum_neg_logits = b_loss_sum_neg_logits * mask + lse_b_masked = lse_b * mask + + total_loss += jnp.sum(b_loss_sum_neg_logits) + jnp.sum(lse_b_masked) + + return total_loss, total_z_loss + + initial_acc = (0.0, 0.0) + total_loss, total_z_loss = jax.lax.fori_loop( + 0, + num_b_blocks, + b_loop_body, + initial_acc, + ) + + # For drop-in replacement, we return residuals as the current method does. + # We pack necessary values for the backward pass. + # Note that the backward pass would also need to be implemented for this method + # to be fully compatible with jax.custom_vjp. + residuals = ( + gathered_params, + flat_hidden, + flat_labels, + flat_segmentation, + batch_size, + seq_len, + emb_dim, + ) + return (total_loss, total_z_loss), residuals + def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation): batch_size, seq_len, emb_dim = hidden_states.shape vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling