2 changes: 2 additions & 0 deletions src/maxtext/configs/base.yml
@@ -1217,6 +1217,8 @@ force_q_layout: false
mhc_expansion_rate: 1
# The number of iterations for the Sinkhorn-Knopp algorithm.
sinkhorn_iterations: 20
# Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.
enable_mhc_k4_shortcut: True

################################## DeepSeek Engram ##################################
# Indices of transformer layers where Engram are integrated; leave empty [] to disable.
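Background on the new flag: by the Birkhoff-von Neumann theorem, every doubly stochastic matrix is a convex combination of permutation matrices, and for `mhc_expansion_rate: 4` there are exactly 4! = 24 of them. The `mhc.py` change further down uses this directly: a softmax over 24 logits weights the 24 permutation matrices, yielding a doubly stochastic mixing matrix without any Sinkhorn-Knopp iterations. A minimal standalone JAX sketch of the idea (illustrative only, not the MaxText code):

```python
import itertools

import jax
import jax.numpy as jnp

# All 24 permutation matrices of size 4x4, shaped (24, 4, 4).
perms = jnp.array(list(itertools.permutations(range(4))))
perm_mats = jnp.eye(4)[perms]

logits = jax.random.normal(jax.random.PRNGKey(0), (24,))
weights = jax.nn.softmax(logits)  # convex weights: non-negative, sum to 1
mix = jnp.einsum("n,nij->ij", weights, perm_mats)

# Doubly stochastic by construction: every row and column sums to 1.
assert jnp.allclose(mix.sum(axis=0), 1.0, atol=1e-6)
assert jnp.allclose(mix.sum(axis=1), 1.0, atol=1e-6)
```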
68 changes: 50 additions & 18 deletions src/maxtext/configs/types.py
@@ -506,7 +506,10 @@ class ModelArchitecture(BaseModel):
True,
description="Whether to apply scale on query and key normalizations (default True).",
)
v_norm_with_scale: bool = Field(True, description="Whether to apply scale on value normalization (default True).")
v_norm_with_scale: bool = Field(
True,
description="Whether to apply scale on value normalization (default True).",
)


class MTP(BaseModel):
@@ -685,14 +688,18 @@ class MoEGeneral(BaseModel):
num_experts: PositiveInt = Field(1, description="The total number of experts in each MoE layer.")
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
ragged_buffer_factor: float = Field(-1.0, description="Ragged buffer factor. If < 0, ragged buffer is worst case size.")
ragged_buffer_factor: float = Field(
-1.0,
description="Ragged buffer factor. If < 0, ragged buffer is worst case size.",
)
moe_expert_input_dim: int = Field(
-1,
description="Dimension of tokens entering the MoE layer. If < 0, defaults to emb_dim.",
)
base_moe_mlp_dim: int = Field(-1, description="Intermediate dimension at MoE layer.")
padded_base_moe_mlp_dim: Optional[int] = Field(
None, description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution."
None,
description="Padded intermediate dimension at MoE layer for efficient GMM_v2 kernel execution.",
)
load_balance_loss_weight: NonNegativeFloat = Field(0.0, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(
@@ -873,7 +880,8 @@ class HardwareAndMesh(BaseModel):
)
custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")
custom_mesh_and_rule: CustomRule = Field(
CustomRule.DEFAULT, description="Customized mesh and logical rules for granularity."
CustomRule.DEFAULT,
description="Customized mesh and logical rules for granularity.",
)
allow_split_physical_axes: bool = Field(False, description="Allow splitting physical axes for device mesh creation.")
enable_nnx: bool = Field(False, description="Whether to use NNX for model definition.")
@@ -882,7 +890,8 @@ class HardwareAndMesh(BaseModel):
pure_nnx_decoder: bool = Field(False, description="Whether to enable pure NNX decoder.")
pure_nnx: bool = Field(False, description="Whether to enable pure NNX mode.")
remove_size_one_mesh_axis_from_type: bool = Field(
True, description="Whether to remove size one mesh axis from type through jax.config."
True,
description="Whether to remove size one mesh axis from type through jax.config.",
)


@@ -903,7 +912,10 @@ class LayoutAndSharding(BaseModel):
description="Allowed percentage of non-sharded parameters.",
)
shard_optimizer_over_data: bool = Field(False, description="Enable ZeRO-1 optimizer sharding over the data axis.")
internal_compile: bool = Field(False, description="Use internal_compile to bypass open-source topology mappings.")
internal_compile: bool = Field(
False,
description="Use internal_compile to bypass open-source topology mappings.",
)
internal_compile_num_devices: int = Field(-1, description="Number of devices when using internal_compile.")
compile_xla_flags: str = Field("", description="Compiler options for compilation only.")

@@ -950,7 +962,8 @@ class PipelineParallelism(BaseModel):
"""Configuration for pipeline parallelism."""

pipeline_fsdp_ag_per_repeat: bool = Field(
False, description="Enable weight prefetching for circular pipeline parallelism."
False,
description="Enable weight prefetching for circular pipeline parallelism.",
)
num_layers_per_pipeline_stage: int = Field(1, description="Number of layers to place on each pipeline stage.")
num_pipeline_repeats: int = Field(
@@ -1194,7 +1207,10 @@ class OlmoGrainDataset(BaseModel):
``data_shuffle_seed``); only OLMo-specific fields are listed here.
"""

olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
olmo_index_path: PathStr = Field(
"",
description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.",
)
olmo_path_remap_from: PathStr = Field(
"",
description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.",
@@ -1279,19 +1295,24 @@ class Distillation(BaseModel):
distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.")
distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field(
"constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
"constant",
description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine').",
)
distill_temperature_end: Optional[float] = Field(
None, description="Target temperature at end of training. None keeps temperature fixed."
None,
description="Target temperature at end of training. None keeps temperature fixed.",
)
distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field(
"constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
"constant",
description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine').",
)
distill_beta_end: Optional[float] = Field(
None, description="Target beta_feature at end of training. None keeps beta fixed."
None,
description="Target beta_feature at end of training. None keeps beta fixed.",
)
distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field(
"constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
"constant",
description="Schedule type for beta annealing ('constant', 'linear', or 'cosine').",
)

# --- Learn to init related parameters --
@@ -1314,11 +1335,13 @@ class Distillation(BaseModel):
)

attn_module_name: Optional[str] = Field(
None, description="Attention nnx module attribute name to augment with LTI logic"
None,
description="Attention nnx module attribute name to augment with LTI logic",
)

lti_layer_indices: Optional[list[int]] = Field(
None, description="List of layer indices to apply LTI modifications. If None, applied to all layers."
None,
description="List of layer indices to apply LTI modifications. If None, applied to all layers.",
)
# ---------------------------------------

@@ -1365,6 +1388,10 @@ class ManifoldConstrainedHyperConnections(BaseModel):

mhc_expansion_rate: PositiveInt = Field(1, description="The number of parallel streams in Hyper Connection.")
sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")
enable_mhc_k4_shortcut: bool = Field(
True,
description="Whether to enable the permutation-based convex combination shortcut when mhc_expansion_rate is 4.",
)


class DilocoParams(BaseModel):
@@ -1655,7 +1682,8 @@ class Profiling(BaseModel):
tpu_num_chips_to_profile_per_task: int = Field(1, description="Specifies the number of TPU chips to profile per task.")
tpu_num_sparse_cores_to_trace: int = Field(2, description="Specifies the number of sparse cores to trace per TPU chip.")
tpu_num_sparse_core_tiles_to_trace: int = Field(
1, description="Specifies the number of tiles within each sparse core to trace on the TPU."
1,
description="Specifies the number of tiles within each sparse core to trace on the TPU.",
)
xprof_tpu_power_trace_level: XProfTPUPowerTraceMode = Field(
XProfTPUPowerTraceMode.POWER_TRACE_NONE,
@@ -2491,7 +2519,11 @@ def validate_and_set_hlo_dump_defaults():
)
for param_name, schedule, end_value in [
("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end),
("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end),
(
"distill_temperature",
self.distill_temperature_schedule,
self.distill_temperature_end,
),
("distill_beta", self.distill_beta_schedule, self.distill_beta_end),
]:
if schedule != "constant" and end_value is None:
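The schedule names validated here refer to standard annealing curves between a parameter's starting value and its required `*_end` value. The schedule implementation lives elsewhere in MaxText and is not part of this diff; the following is a generic, hypothetical sketch of what such schedules typically compute:

```python
import math

def anneal(start: float, end: float, step: int, total_steps: int, schedule: str) -> float:
  """Illustrative annealing helper (hypothetical, not the MaxText implementation)."""
  frac = min(step / max(total_steps, 1), 1.0)
  if schedule == "constant":
    return start
  if schedule == "linear":
    return start + (end - start) * frac
  if schedule == "cosine":
    # Half-cosine decay from start to end over the run.
    return end + (start - end) * 0.5 * (1.0 + math.cos(math.pi * frac))
  raise ValueError(f"Unknown schedule: {schedule}")
```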
@@ -3004,7 +3036,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = False

if self.use_batch_split_schedule:
if self.quantization and not self.quantization == "fp8_full":
if self.quantization and self.quantization != "fp8_full":
raise ValueError("Batch split quantization only supports `quantization=fp8_full`")

if self.opt_type == "muon" and self.decoder_block not in [
60 changes: 49 additions & 11 deletions src/maxtext/layers/mhc.py
@@ -14,17 +14,29 @@

"""DeepSeek Manifold-Constrained Hyper Connections (mHC) Layer."""

import itertools
from typing import Callable

from flax import nnx
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from maxtext.common.common_types import Array, Config
from maxtext.common.common_types import HyperConnectionType
from maxtext.layers.initializers import default_bias_init, default_scalar_init, nd_dense_init
from maxtext.layers.initializers import (
default_bias_init,
default_scalar_init,
nd_dense_init,
)
from maxtext.layers.normalizations import RMSNorm


def get_4x4_permutation_matrices():
  """Returns all 24 (= 4!) permutation matrices of size 4x4 as a (24, 4, 4) array."""
  perms = list(itertools.permutations(range(4)))
  perms_array = jnp.array(perms)
  return jnp.eye(4)[perms_array]
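Not part of the diff: a quick sanity check of the helper above. Each entry is a valid permutation matrix, so every row and column sums to exactly 1:

```python
mats = get_4x4_permutation_matrices()
assert mats.shape == (24, 4, 4)
assert jnp.allclose(mats.sum(axis=-1), 1.0)  # one 1 per row
assert jnp.allclose(mats.sum(axis=-2), 1.0)  # one 1 per column
```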


def get_functions(expansion_rate: int):
"""Creates functions to broadcast a single feature stream into multiple

@@ -118,6 +130,15 @@ def __init__(
out_sharding=(None,),
)

if self.k == 4 and self.config.enable_mhc_k4_shortcut:
res_out_dim = 24
res_beta_shape = (24,)
res_beta_sharding = (None,)
else:
res_out_dim = self.k * self.k
res_beta_shape = (self.k, self.k)
res_beta_sharding = (None, None)

# Weight matrices
scale_init = nd_dense_init(1.0, "fan_in", "normal")
in_axis = 0
@@ -126,7 +147,7 @@
self.res_alpha = nnx.Param(
scale_init(
self.rngs.params(),
(self.k * self.dim, self.k * self.k),
(self.k * self.dim, res_out_dim),
self.weight_dtype,
in_axis=in_axis,
out_axis=out_axis,
@@ -156,8 +177,8 @@ def __init__(

# Biases
self.res_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k, self.k), self.weight_dtype),
out_sharding=(None, None),
default_bias_init(self.rngs.params(), res_beta_shape, self.weight_dtype),
out_sharding=res_beta_sharding,
)
self.pre_beta = nnx.Param(
default_bias_init(self.rngs.params(), (self.k,), self.weight_dtype),
@@ -174,13 +195,30 @@ def res_mapping(self, x: Array):
res_alpha = jnp.asarray(self.res_alpha[...], self.dtype)
res_beta = jnp.asarray(self.res_beta[...], self.dtype)
res_alpha_scale = jnp.asarray(self.res_alpha_scale[...], self.dtype)
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
b, s, _ = h_res.shape
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

if self.k == 4 and self.config.enable_mhc_k4_shortcut:
# Apply projection: (b, s, k*d) @ (k*d, 24) -> (b, s, 24)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
intermediate = res_alpha_scale * h_res + res_beta[None, None, :]
# Use float32 for numerical stability during softmax
weights = jax.nn.softmax(intermediate.astype(jnp.float32), axis=-1).astype(self.dtype)
# Sum the 24 permutation matrices with the weights
permutation_matrices_4x4 = get_4x4_permutation_matrices().astype(self.dtype)
output = jnp.einsum(
"bsn,nkm -> bskm",
weights,
permutation_matrices_4x4,
precision=self.matmul_precision,
)
return output
else:
# Apply projection: (b, s, k*d) @ (k*d, k*k) -> (b, s, k*k)
h_res = jnp.einsum("bsm,mn -> bsn", x, res_alpha, precision=self.matmul_precision)
b, s, _ = h_res.shape
h_res = jnp.reshape(h_res, (b, s, self.k, self.k))
intermediate = res_alpha_scale * h_res + res_beta[None, None, :, :]
output = sinkhorn(intermediate, self.sinkhorn_iterations)
return output

def mapping(self, x: Array, alpha_scale: Array, alpha: Array, beta: Array, scale: int):
"""Helper function for both pre and post mappings."""
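The `sinkhorn` function invoked in the non-shortcut branch of `res_mapping` is defined elsewhere in mhc.py and is not shown in this diff. For orientation, Sinkhorn-Knopp projects a positive matrix toward the doubly stochastic set by alternating row and column normalization; a generic sketch under that assumption (not the MaxText implementation):

```python
import jax.numpy as jnp

def sinkhorn_sketch(logits: jnp.ndarray, num_iters: int) -> jnp.ndarray:
  """Illustrative Sinkhorn-Knopp: exponentiate, then alternate row/column normalization."""
  m = jnp.exp(logits)  # ensure strictly positive entries
  for _ in range(num_iters):
    m = m / m.sum(axis=-1, keepdims=True)  # rows sum to 1
    m = m / m.sum(axis=-2, keepdims=True)  # columns sum to 1
  return m
```

With `sinkhorn_iterations: 20` (the default above), 20 such sweeps are applied to each (k, k) block; the k = 4 shortcut replaces this iteration entirely with a single softmax over the 24 permutation matrices.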
1 change: 0 additions & 1 deletion tests/unit/grain_data_processing_test.py
@@ -464,7 +464,6 @@ def setUp(self):
tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.default"),
enable_checkpointing=False,
)
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)


@pytest.mark.external_training