Merged

Changes from 5 commits
121 changes: 121 additions & 0 deletions src/modalities/models/gpt2/llama3_like_initialization.py
@@ -0,0 +1,121 @@
import math
import re
from functools import partial
from typing import Annotated

import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.utils.logger_utils import get_logger

logger = get_logger(name="llama3 initialization")


class Llama3InitializerConfig(BaseModel):
num_layers: Annotated[int, Field(strict=True, gt=0)]
n_embd: Annotated[int, Field(strict=True, gt=0)]
bias: bool


class Llama3Initializer(ModelInitializationIF):
"""
Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
"""

def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None:
super().__init__()

self.regex_to_init = {
Collaborator:
We also need regex patterns for attention_norm, ffn_norm, and the final lm_head_norm, right? Something like:

r"transformer\.h\.\d+\.(attention_norm|ffn_norm)\.weight": nn.init.ones_,
r"transformer\.lm_head_norm\.weight": nn.init.ones_,

Member Author (@le1nux, Mar 5, 2026):
module.reset_parameters()

We already call this here, and due to recursion it is also called for the RMSNorm:
https://github.com/pytorch/pytorch/blob/65762ca85745d786ab6b20e9cb060242b51e872d/torch/nn/modules/normalization.py#L407
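(Aside, not part of the PR: a minimal sketch, assuming torch >= 2.4 where nn.RMSNorm is available. Its reset_parameters() simply refills the weight with ones, which is why no extra regex entries are needed for the norm layers.)

import torch.nn as nn

norm = nn.RMSNorm(normalized_shape=256)
norm.weight.data.zero_()   # pretend the weight was perturbed
norm.reset_parameters()    # refills the weight with ones, as happens recursively via the model factory
assert bool((norm.weight == 1).all())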

# embedding weights
r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1),
# lm head weights
r"transformer\.lm_head\.weight": partial(
nn.init.trunc_normal_,
mean=0.0,
std=1 / math.sqrt(n_embd),
a=-3 / math.sqrt(n_embd),
b=3 / math.sqrt(n_embd),
),
# qkv projections
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": partial(
nn.init.trunc_normal_,
mean=0.0,
std=0.02,
a=-2,
b=2,
),
# final attention projection in attention block
r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial(
Collaborator:
This corresponds to the following, right? But there you can see that for the out projection it is std=init_std, which can be initialized differently and defaults to depth_init, because we pass weight_init_std here, which defaults to depth_init in TorchTitan. If we don't want depth init, then this matches the scaled out-projection logic for the case where depth_init is False in TorchTitan.

Member Author:
I implemented depth_init to be fully compliant (i.e., the output projections use std = 0.02 / sqrt(2 * num_layers)).

nn.init.trunc_normal_,
mean=0.0,
std=0.02 / math.sqrt(2 * num_layers),
a=-2,
b=2,
),
# SwiGLU
r"transformer\.h\.\w+\.mlp\.(W)\.weight": partial(
nn.init.trunc_normal_,
mean=0.0,
std=0.02,
a=-2,
b=2,
),
r"transformer\.h\.\w+\.mlp\.(V|W_2)\.weight": partial(
nn.init.trunc_normal_,
mean=0.0,
std=0.02 / math.sqrt(2 * num_layers),
a=-2,
b=2,
),
}
if bias:
self.regex_to_init = {
**self.regex_to_init,
**{
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial(
nn.init.trunc_normal_,
mean=0.0,
std=0.02,
a=-2,
b=2,
),
r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial(
nn.init.trunc_normal_,
mean=0.0,
std=0.02 / math.sqrt(2 * num_layers),
a=-2,
b=2,
),
r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_,
r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_,
},
}

def initialize_in_place(self, model: nn.Module):
self._init_by_fqn_regex(model, self.regex_to_init)

@staticmethod
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]):
hits = {k: 0 for k in regex_to_init.keys()}

for parameter_name, p in model.named_parameters():
match_count = 0
for weight_regex in regex_to_init.keys():
if re.fullmatch(weight_regex, parameter_name):
init_fn = regex_to_init[weight_regex]
init_fn(p)
match_count += 1
hits[weight_regex] += 1
if match_count == 0:
logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
Member:
should we add a flag which turns this into an error?

Member Author:
Since the norms are initialized within the model factory via reset_parameters, this would always throw an error.

elif match_count > 1:
raise ValueError(
f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
)

for k, count in hits.items():
if count == 0:
raise ValueError(
f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
)
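To make the FQN-regex mechanism above concrete, here is a minimal, self-contained sketch with a hypothetical toy module and toy regexes (not the real GPT2LLM or this PR's code): each parameter's fully qualified name from named_parameters() must match exactly one regex, whose init function is then applied in place.

import re
from functools import partial

import torch.nn as nn

toy_regex_to_init = {
    r"net\.\d+\.weight": partial(nn.init.trunc_normal_, mean=0.0, std=0.02, a=-2, b=2),
    r"net\.\d+\.bias": nn.init.zeros_,
}

wrapper = nn.Module()
wrapper.net = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # FQNs: net.0.weight, net.0.bias, net.1.weight, net.1.bias

for name, param in wrapper.named_parameters():
    matching = [rx for rx in toy_regex_to_init if re.fullmatch(rx, name)]
    assert len(matching) == 1, f"{name} must match exactly one regex"
    toy_regex_to_init[matching[0]](param)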
7 changes: 7 additions & 0 deletions src/modalities/registry/components.py
@@ -92,6 +92,7 @@
)
from modalities.models.gpt2.collator import GPT2LLMCollateFn
from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
@@ -240,6 +241,12 @@ class ComponentEntity:
ComposedInitializationRoutines.get_composed_model_initializer,
ComposedModelInitializationConfig,
),
ComponentEntity(
"model_initialization",
"llama3_like",
Llama3Initializer,
Llama3InitializerConfig,
),
# losses
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
# optimizers
@@ -177,7 +177,7 @@ app_state_raw:
component_key: app_state
variant_key: raw
config:
model:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
optimizer:
@@ -288,7 +288,7 @@ optimizer:
eps: 1e-8
weight_decay: 1e-1
weight_decay_groups_excluded: [embedding, layernorm]
wrapped_model:
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE

111 changes: 109 additions & 2 deletions tests/test_initialization_fsdpx.py
@@ -18,8 +18,16 @@
from torch.distributed.fsdp import StateDictType

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.pydantic_if_types import (
PydanticFSDP1ModuleType,
PydanticFSDP2ModuleType,
PydanticPytorchModuleType,
)
from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv


@@ -493,3 +501,102 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]:
model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
)[0]
return model_state


class TestLlama3LikeInitialization:
@pytest.mark.parametrize("has_bias", [True, False])
def test_llama3_like_initialization(self, has_bias: bool):
config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml"
n_layer = 4
n_embd = 256
model = self._get_components(config_file_path=config_file_path, has_bias=has_bias)
self._test_wte(model=model)
self._test_lm_head(model=model, n_embd=n_embd)

for _, block in model.transformer["h"].items():
self._test_qkv_proj(gpt2_block=block, has_bias=has_bias)
self._test_c_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer)
self._test_swiglu_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer)

def _get_components(self, config_file_path: Path, has_bias: bool) -> GPT2LLM:
config_dict = load_app_config_dict(
config_file_path=config_file_path,
)
config_dict["model_raw"]["config"]["bias"] = has_bias
config_dict["initialized_model"]["config"]["model_initializer"]["config"]["bias"] = has_bias
registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)

class ComponentsInstantiationModel(BaseModel):
initialized_model: PydanticPytorchModuleType

components: ComponentsInstantiationModel = component_factory.build_components(
config_dict=config_dict, components_model_type=ComponentsInstantiationModel
)
return components.initialized_model

def _test_wte(self, model: GPT2LLM):
assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-2)
assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-2)

def _test_lm_head(self, model: GPT2LLM, n_embd: int):
assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_embd), abs=1e-3)
assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_embd)
assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_embd)
assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)

def _test_qkv_proj(self, gpt2_block: GPT2Block, has_bias: bool):
layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn)
for layer in layers:
assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-2)
assert layer.weight.max().detach().cpu() <= 2
assert layer.weight.min().detach().cpu() >= -2
assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)

if has_bias:
assert layer.bias is not None
assert layer.bias.std().detach().cpu() == pytest.approx(0.02, abs=1e-2)
assert layer.bias.max().detach().cpu() <= 2
assert layer.bias.min().detach().cpu() >= -2
else:
assert layer.bias is None

def _test_c_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int):
layer = gpt2_block.attn.c_proj
assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-2)
assert layer.weight.max().detach().cpu() <= 2
assert layer.weight.min().detach().cpu() >= -2
assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)

if has_bias:
assert layer.bias is not None
assert layer.bias.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3)
assert layer.bias.max().detach().cpu() <= 2
assert layer.bias.min().detach().cpu() >= -2
else:
assert layer.bias is None

def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int):
layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2)
for layer in layers:
assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3)
assert layer.weight.max().detach().cpu() <= 2
assert layer.weight.min().detach().cpu() >= -2
assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)

if has_bias:
# all zero bias
assert layer.bias is not None and torch.all(layer.bias == 0)
else:
assert layer.bias is None

layer = gpt2_block.mlp.W
assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3)
assert layer.weight.max().detach().cpu() <= 2
assert layer.weight.min().detach().cpu() >= -2
assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)

if has_bias:
assert layer.bias is not None and torch.all(layer.bias == 0)
else:
assert layer.bias is None
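For reference, the concrete target values behind these assertions, given n_layer=4 and n_embd=256 from the test YAML below (a quick arithmetic sketch, not part of the PR):

import math

n_layer, n_embd = 4, 256
print(1.0)                              # wte std
print(1 / math.sqrt(n_embd))            # lm_head std: 0.0625
print(0.02)                             # qkv and mlp.W std
print(0.02 / math.sqrt(2 * n_layer))    # c_proj, mlp.V, mlp.W_2 std: ~0.00707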
60 changes: 60 additions & 0 deletions tests/test_yaml_configs/llama3_config_initalization.yaml
@@ -0,0 +1,60 @@
initialized_model:
component_key: model
variant_key: model_initialized
config:
model:
instance_key: model_raw
pass_type: BY_REFERENCE
model_initializer:
component_key: model_initialization
variant_key: llama3_like
config:
num_layers: ${model_raw.config.n_layer}
n_embd: ${model_raw.config.n_embd}
bias: ${model_raw.config.bias}


model_raw:
component_key: model
variant_key: gpt2
config:
use_meta_device: true
use_weight_tying: false
sample_key: "input_ids"
poe_type: NOPE
sequence_length: 128
prediction_key: "logits"
vocab_size: 2048 # 2K vocab for testing
n_layer: 4 # 4 layers for testing
n_head_q: 32
n_head_kv: 8
ffn_hidden: 128 # 128 ffn hidden dim for testing
n_embd: 256 # 256 embedding dim for testing
dropout: 0.0
bias: true
attention_config:
qkv_transforms:
- type_hint: RotaryTransform
config:
n_embd: ${model_raw.config.n_embd}
n_head: ${model_raw.config.n_head_q}
seq_length_dim: -2
base_freq: 500000
attention_implementation: pytorch_flash
activation_type: swiglu
attention_norm_config:
norm_type: pytorch_rms_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1.0e-05
ffn_norm_config:
norm_type: pytorch_rms_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1.0e-05
lm_head_norm_config:
norm_type: pytorch_rms_norm
config:
normalized_shape: ${model_raw.config.n_embd}
eps: 1.0e-05