Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,12 @@ num_slices: -1

# Vocab Tiling Configs
# Enables a memory-saving optimization by computing the cross-entropy loss in chunks.
# The logits are tiled into `num_vocab_tiling` parts along the batch-sequence axis,
# reducing peak memory usage. This is highly recommended for models with large
# vocabularies (e.g., Gemma). Set to a value greater than 1 to enable.
# The logits are tiled into `num_vocab_tiling` parts along the vocabulary axis,
# and `num_of_batch_tiling` parts along the batch-sequence axis, reducing peak memory usage.
# This is highly recommended for models with large vocabularies (e.g., Gemma).
# Set to a value greater than 1 to enable.
num_vocab_tiling: 1
num_of_batch_tiling: 1

# Tokenizer
vocab_size: 32_000 # powers of 2 for sharding
Expand Down
13 changes: 12 additions & 1 deletion src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,20 @@ def validate_expert_shard_attention_option(expert_shard_attention_option: str) -
)


def validate_vocab_tiling(num_vocab_tiling: int, per_device_batch_size: int, max_target_length: int, enable_nnx: bool):
def validate_vocab_tiling(
    num_vocab_tiling: int,
    num_of_batch_tiling: int,
    per_device_batch_size: int,
    max_target_length: int,
    enable_nnx: bool,
):
  """Validate the tiling configuration used by the chunked cross-entropy loss.

  Args:
    num_vocab_tiling: number of tiles along the vocabulary axis (>1 enables).
    num_of_batch_tiling: number of tiles along the batch-sequence axis.
    per_device_batch_size: per-device batch size.
    max_target_length: sequence length.
    enable_nnx: whether the NNX module path is active.

  Raises:
    ValueError: if batch*sequence is not divisible by either tile count, or
      if vocab tiling is requested while NNX is enabled.
  """
  # Both tilings partition the flattened (batch * sequence) token axis, so
  # each tile count must divide it evenly.
  tokens_per_device = per_device_batch_size * max_target_length
  if tokens_per_device % num_vocab_tiling != 0:
    raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
  if tokens_per_device % num_of_batch_tiling != 0:
    raise ValueError(
        "Per device batch size times sequence length should be divisible by the number of batch tiles."
    )
  # TODO (chengnuojin) enable vocab tiling on NNX after NNX migration
  if enable_nnx and num_vocab_tiling > 1:
    raise ValueError("We currently don't support vocab tiling on NNX module.")

Expand Down
22 changes: 21 additions & 1 deletion src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,17 @@ class Tokenizer(BaseModel):
)
num_vocab_tiling: int = Field(
1,
description="Enables memory-saving optimization by tiling cross-entropy loss computation. >1 to enable.",
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the vocabulary axis. >1 to enable."
),
)
num_of_batch_tiling: int = Field(
1,
description=(
"Enables memory-saving optimization by tiling cross-entropy loss"
" computation along the batch-sequence axis. >1 to enable."
),
)


Expand Down Expand Up @@ -2461,6 +2471,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
):
raise ValueError("Per device batch size times sequence length should be divisible by the number of vocab tiles.")
if (
self.per_device_batch_size > 0
and (self.per_device_batch_size * self.max_target_length)
% self.num_of_batch_tiling
!= 0
):
raise ValueError(
"Per device batch size times sequence length should be divisible by"
" the number of batch tiles."
)
if self.num_vocab_tiling > 1 and self.enable_nnx:
raise ValueError("We currently don't support vocab tiling on NNX module.")
if self.context_parallel_size > 1 and self.context_parallel_strategy.lower() == "ring":
Expand Down
19 changes: 16 additions & 3 deletions src/maxtext/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,8 @@ def _apply_embedding(
return y

@nn.compact
def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode):
"""Applies final normalization and projects hidden states to logits."""

def normalize_hidden_states(self, y, deterministic, model_mode):
"""Applies final normalization and dropout to hidden states."""
cfg = self.config
if cfg.shard_mode == ShardMode.EXPLICIT:
norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
Expand All @@ -696,6 +695,20 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi
parameter_memory_host_offload=cfg.parameter_memory_host_offload,
)(y, out_sharding=norm_out_sharding)
y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic)
return y

@nn.compact
def apply_output_head(
self,
shared_embedding: nn.Module | nnx.Module,
y,
deterministic,
model_mode,
):
"""Applies final normalization and projects hidden states to logits."""

cfg = self.config
y = self.normalize_hidden_states(y, deterministic, model_mode)

if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
Expand Down
8 changes: 8 additions & 0 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def logits_from_hidden_states(self, hidden_states, deterministic, model_mode):
)
return logits

def normalize_hidden_states(self, hidden_states, deterministic, model_mode):
  """Apply the decoder's final normalization to ``hidden_states``.

  Thin delegation to ``self.decoder.normalize_hidden_states``; the actual
  normalization/dropout behavior lives on the decoder.
  """
  decoder = self.decoder
  return decoder.normalize_hidden_states(
      y=hidden_states, deterministic=deterministic, model_mode=model_mode
  )

def __call__(
self,
decoder_input_tokens: jnp.ndarray,
Expand Down
147 changes: 147 additions & 0 deletions src/maxtext/utils/vocabulary_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,153 @@ def chunked_cross_entropy_loss(gathered_params, hidden_states, labels, segmentat
(total_loss, total_z_loss), _ = _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation)
return total_loss, total_z_loss

def _b_v_chunked_cross_entropy_loss_fwd(
    gathered_params, hidden_states, labels, segmentation
):
  """Forward pass of cross-entropy (+ z-loss) tiled over batch-sequence AND vocab.

  Tiles the flattened batch*sequence axis into `config.num_of_batch_tiling`
  blocks and the vocabulary axis into `config.num_vocab_tiling` blocks, so the
  full (batch*seq, vocab) logits matrix is never materialized at once.

  NOTE(review): relies on names from the enclosing scope (`config`, `model`,
  `_reshape`, `create_sharding`, `_maybe_shard_with_name`,
  `chunked_hidden_spec`, `deterministic`) -- confirm all are defined there.

  Args:
    gathered_params: pytree with a "params" entry holding model parameters,
      including the output projection (or tied embedding) weights.
    hidden_states: activations of shape (batch, seq, emb_dim).
    labels: integer target ids; flattened below to (batch*seq,).
    segmentation: per-token segment ids; 0 marks padding (masked out).

  Returns:
    ((total_loss, total_z_loss), residuals) with losses summed over
    non-padding tokens; residuals are packed for a backward pass that is not
    implemented yet (see note near the return).
  """
  batch_size, seq_len, emb_dim = hidden_states.shape
  v_dim = config.vocab_size

  # Flattened token axis and the per-tile block sizes (floor division).
  b_dim = batch_size * seq_len
  b_block_sz = b_dim // config.num_of_batch_tiling
  v_block_sz = v_dim // config.num_vocab_tiling

  # NOTE(review): this checks divisibility by the derived *block sizes*, not
  # by the tile counts; when a tile count does not divide the dimension the
  # loops below silently run with more blocks than configured (e.g. b_dim=10,
  # tiling=4 -> 5 blocks of 2) -- confirm this is intended.
  if b_dim % b_block_sz != 0 or v_dim % v_block_sz != 0:
    raise ValueError(
        "Batch/sequence dimension and vocab dimension must be divisible by"
        " their block sizes."
    )

  num_b_blocks = b_dim // b_block_sz
  num_v_blocks = v_dim // v_block_sz

  # Flatten (batch, seq, ...) -> (batch*seq, ...) with explicit shardings on
  # the fused batch-sequence axis.
  flat_hidden = _reshape(
      hidden_states,
      (b_dim, emb_dim),
      create_sharding(
          model.mesh,
          ("activation_embed_and_logits_batch_sequence", "activation_embed"),
      ),
  )
  flat_labels = _reshape(
      labels,
      (b_dim,),
      create_sharding(
          model.mesh, ("activation_embed_and_logits_batch_sequence",)
      ),
  )
  flat_segmentation = _reshape(
      segmentation,
      (b_dim,),
      create_sharding(
          model.mesh, ("activation_embed_and_logits_batch_sequence",)
      ),
  )

  # Output-projection weights: either the tied embedding table (laid out
  # (vocab, emb_dim), transposed per-block below) or the dedicated dense
  # kernel (emb_dim, vocab).
  if config.logits_via_embedding:
    w = gathered_params["params"]["shared_embedding"]["embedding"]
  else:
    w = gathered_params["params"]["decoder"]["logits_dense"]["kernel"]

  def b_loop_body(i, carry):
    # Accumulate total loss / z-loss over one batch-sequence block.
    total_loss, total_z_loss = carry
    b_start = i * b_block_sz

    def v_loop_body(j, v_carry):
      # Carry: running log-sum-exp per token, and the running negative sum
      # of each token's target logit, both over vocab blocks seen so far.
      lse_b_, b_loss_sum_neg_logits_ = v_carry
      v_start = j * v_block_sz
      # NOTE(review): labels_b, x_b, and the normalization below do not
      # depend on j and could be hoisted out of the vocab loop.
      labels_b = jax.lax.dynamic_slice(flat_labels, (b_start,), (b_block_sz,))
      x_b = jax.lax.dynamic_slice(
          flat_hidden, (b_start, 0), (b_block_sz, emb_dim)
      )

      # Apply normalization to the batch block
      x_b_norm = model.apply(
          {"params": gathered_params["params"]},
          x_b,
          deterministic=deterministic,
          method="normalize_hidden_states",
      )
      x_b_norm = _maybe_shard_with_name(x_b_norm, chunked_hidden_spec)

      # Extract w_j
      if config.logits_via_embedding:
        # Attend on embedding table. Table is (vocab_size, emb_dim)
        # Transpose to (emb_dim, vocab_size)
        w_j = jax.lax.dynamic_slice(w.T, (0, v_start), (emb_dim, v_block_sz))
      else:
        w_j = jax.lax.dynamic_slice(w, (0, v_start), (emb_dim, v_block_sz))

      # Compute logits for the block
      logits_bv = jnp.dot(x_b_norm, w_j)

      if config.logits_via_embedding and config.normalize_embedding_logits:
        logits_bv = logits_bv / jnp.sqrt(emb_dim)
      if config.final_logits_soft_cap:
        # Soft-cap logits via tanh to bound their magnitude.
        logits_bv = logits_bv / config.final_logits_soft_cap
        logits_bv = jnp.tanh(logits_bv) * config.final_logits_soft_cap

      if config.cast_logits_to_fp32:
        logits_bv = logits_bv.astype(jnp.float32)

      # Streaming log-sum-exp: fold this vocab block into the running
      # normalizer without ever holding a full vocab row.
      lse_b__ = jnp.logaddexp(lse_b_, jax.nn.logsumexp(logits_bv, axis=-1))

      # one_hot produces all-zeros for indices outside [0, v_block_sz), so
      # each token's target logit is subtracted exactly once across the
      # whole vocab loop.
      labels_one_hot = jax.nn.one_hot(
          labels_b - v_start, v_block_sz, dtype=logits_bv.dtype
      )
      b_loss_sum_neg_logits__ = b_loss_sum_neg_logits_ - jnp.sum(
          logits_bv * labels_one_hot, axis=-1
      )
      return lse_b__, b_loss_sum_neg_logits__

    # Initial carry: lse = -inf (log of 0), target-logit accumulator = 0.
    lse_b, b_loss_sum_neg_logits = jax.lax.fori_loop(
        0,
        num_v_blocks,
        v_loop_body,
        (
            jnp.full((b_block_sz,), -jnp.inf, dtype=jnp.float32),
            jnp.zeros((b_block_sz,), dtype=jnp.float32),
        ),
    )

    # Mask out padding tokens (segmentation id 0).
    segmentation_b = jax.lax.dynamic_slice(
        flat_segmentation, (b_start,), (b_block_sz,)
    )
    mask = (segmentation_b != 0).astype(jnp.float32)

    # Z-loss
    z_loss_b = config.z_loss_multiplier * jnp.square(lse_b) * mask
    total_z_loss += jnp.sum(z_loss_b)

    b_loss_sum_neg_logits = b_loss_sum_neg_logits * mask
    lse_b_masked = lse_b * mask

    # Per-token cross entropy is lse - target_logit; summed over the block.
    total_loss += jnp.sum(b_loss_sum_neg_logits) + jnp.sum(lse_b_masked)

    return total_loss, total_z_loss

  initial_acc = (0.0, 0.0)
  total_loss, total_z_loss = jax.lax.fori_loop(
      0,
      num_b_blocks,
      b_loop_body,
      initial_acc,
  )

  # For drop-in replacement, we return residuals as the current method does.
  # We pack necessary values for the backward pass.
  # Note that the backward pass would also need to be implemented for this method
  # to be fully compatible with jax.custom_vjp.
  residuals = (
      gathered_params,
      flat_hidden,
      flat_labels,
      flat_segmentation,
      batch_size,
      seq_len,
      emb_dim,
  )
  return (total_loss, total_z_loss), residuals

def _chunked_cross_entropy_loss_fwd(gathered_params, hidden_states, labels, segmentation):
batch_size, seq_len, emb_dim = hidden_states.shape
vocab_tile_size = (batch_size * seq_len) // config.num_vocab_tiling
Expand Down
Loading