
Commit 6a17097

Merge pull request #435 from Modalities/llama3_like_weight_init
Llama3 like weight init
2 parents: e97578d + 2549e8b

5 files changed: 293 additions & 4 deletions

File tree

  src/modalities/models/gpt2/llama3_like_initialization.py (new)
  src/modalities/registry/components.py
  tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml
  tests/test_initialization_fsdpx.py
  tests/test_yaml_configs/llama3_config_initalization.yaml (new)

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import math
import re
from typing import Annotated, Callable

import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.utils.logger_utils import get_logger

logger = get_logger(name="llama3 initialization")


class Llama3InitializerConfig(BaseModel):
    num_layers: Annotated[int, Field(strict=True, gt=0)]
    n_embd: Annotated[int, Field(strict=True, gt=0)]
    depth_init: bool = True


class Llama3Initializer(ModelInitializationIF):
    """
    Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.
    """

    def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
        super().__init__()
        self.depth_init = depth_init

        self.regex_to_init = {
            # embedding weights
            r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
            # lm head weights
            r"transformer\.lm_head\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 1 / math.sqrt(n_embd),
                    "a": -3 / math.sqrt(n_embd),
                    "b": 3 / math.sqrt(n_embd),
                },
            ),
            # qkv projections
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 0.02,
                    "a": -2,
                    "b": 2,
                },
            ),
            # final attention projection in attention block
            r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": (
                        (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
                        if depth_init
                        else 0.02 / math.sqrt(2 * num_layers)
                    ),
                    "a": -2,
                    "b": 2,
                },
            ),
            # SwiGLU
            r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 0.02,
                    "a": -2,
                    "b": 2,
                },
            ),
            r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": (
                        (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
                        if depth_init
                        else 0.02 / math.sqrt(2 * num_layers)
                    ),
                    "a": -2,
                    "b": 2,
                },
            ),
        }

    def initialize_in_place(self, model: nn.Module):
        self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)

    @staticmethod
    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
        hits = {k: 0 for k in regex_to_init.keys()}

        for parameter_name, p in model.named_parameters():
            if parameter_name.endswith("bias"):
                raise ValueError(
                    f"Bias initialization is not allowed for Llama3Initializer. Found bias parameter: {parameter_name}"
                )
            match_count = 0
            for weight_regex in regex_to_init.keys():
                if re.fullmatch(weight_regex, parameter_name):
                    init_fn, arg_dict = regex_to_init[weight_regex]
                    if arg_dict["std"] is not None and callable(arg_dict["std"]):
                        # If std is a function, call it with the layer_id
                        layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name)
                        if layer_id_match is not None:
                            layer_id = int(layer_id_match.group(1))
                            arg_dict = arg_dict.copy()  # create a copy of the arg_dict to avoid mutating the original
                            arg_dict["std"] = arg_dict["std"](layer_id)
                        else:
                            raise ValueError(
                                f"Could not extract layer_id from parameter name {parameter_name} "
                                "for dynamic std calculation"
                            )
                    init_fn(p, **arg_dict)
                    match_count += 1
                    hits[weight_regex] += 1

            if match_count == 0:
                logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
            elif match_count > 1:
                raise ValueError(
                    f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
                )

        for k, count in hits.items():
            if count == 0:
                raise ValueError(
                    f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
                )
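The initializer is built from Llama3InitializerConfig and applied with initialize_in_place(model); in the repo it is wired up via the model_initialized component shown in the test config at the end of this commit. For intuition, a small sketch (not taken from the repo, just the formula above evaluated) of the depth-scaled standard deviations used for the output projections of the 4-layer test model:

import math

num_layers = 4  # n_layer of the test model below

# depth_init=True: std shrinks with the layer index for c_proj, V and W_2
per_layer_std = [0.02 / math.sqrt(2 * (layer_id + 1)) for layer_id in range(num_layers)]

# depth_init=False: one shared std based on the total number of layers
shared_std = 0.02 / math.sqrt(2 * num_layers)

print(per_layer_std)  # ~[0.0141, 0.0100, 0.0082, 0.0071]
print(shared_std)     # ~0.0071, equal to the deepest layer's per-layer value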

src/modalities/registry/components.py

Lines changed: 7 additions & 0 deletions
@@ -92,6 +92,7 @@
 )
 from modalities.models.gpt2.collator import GPT2LLMCollateFn
 from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
+from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
 from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
 from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
 from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
@@ -240,6 +241,12 @@ class ComponentEntity:
         ComposedInitializationRoutines.get_composed_model_initializer,
         ComposedModelInitializationConfig,
     ),
+    ComponentEntity(
+        "model_initialization",
+        "gpt2_llama3_like",
+        Llama3Initializer,
+        Llama3InitializerConfig,
+    ),
     # losses
     ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
     # optimizers
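This registry entry makes the initializer addressable from a YAML config via its component_key and variant_key. As a rough sketch mirroring the test config at the end of this commit (the values are that config's, not defaults), the loaded component specification is just a nested dict:

model_initializer_spec = {
    "component_key": "model_initialization",
    "variant_key": "gpt2_llama3_like",
    "config": {
        "num_layers": 4,      # validated by Llama3InitializerConfig (must be > 0)
        "n_embd": 256,        # embedding dim, used for the lm_head init bounds
        "depth_init": False,  # per-layer std scaling disabled in the test config
    },
}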

tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml

Lines changed: 2 additions & 2 deletions
The changed lines appear to differ from their replacements only in whitespace:

@@ -177,7 +177,7 @@ app_state_raw:
   component_key: app_state
   variant_key: raw
   config:
-    model:
+    model:
       instance_key: initialized_model
       pass_type: BY_REFERENCE
     optimizer:
@@ -288,7 +288,7 @@ optimizer:
     eps: 1e-8
     weight_decay: 1e-1
     weight_decay_groups_excluded: [embedding, layernorm]
-    wrapped_model:
+    wrapped_model:
       instance_key: initialized_model
       pass_type: BY_REFERENCE
 

tests/test_initialization_fsdpx.py

Lines changed: 90 additions & 2 deletions
@@ -18,8 +18,16 @@
 from torch.distributed.fsdp import StateDictType
 
 from modalities.__main__ import Main
-from modalities.config.config import ProcessGroupBackendType
-from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType
+from modalities.config.component_factory import ComponentFactory
+from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
+from modalities.config.pydantic_if_types import (
+    PydanticFSDP1ModuleType,
+    PydanticFSDP2ModuleType,
+    PydanticPytorchModuleType,
+)
+from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block
+from modalities.registry.components import COMPONENTS
+from modalities.registry.registry import Registry
 from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
 
 
@@ -493,3 +501,83 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]:
         model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
     )[0]
     return model_state
+
+
+class TestLlama3LikeInitialization:
+    @pytest.mark.parametrize("depth_init", [True, False])
+    def test_llama3_like_initialization(self, depth_init: bool):
+        config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml"
+        n_layer = 4
+        n_embd = 256
+        model = self._get_components(config_file_path=config_file_path, depth_init=depth_init)
+        self._test_wte(model=model)
+        self._test_lm_head(model=model, n_embd=n_embd)
+
+        for layer_id, (_, block) in enumerate(model.transformer["h"].items()):
+            self._test_qkv_proj(gpt2_block=block)
+            self._test_c_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)
+            self._test_swiglu_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)
+
+    def _get_components(self, config_file_path: Path, depth_init: bool) -> GPT2LLM:
+        config_dict = load_app_config_dict(
+            config_file_path=config_file_path,
+        )
+        config_dict["initialized_model"]["config"]["model_initializer"]["config"]["depth_init"] = depth_init
+        registry = Registry(COMPONENTS)
+        component_factory = ComponentFactory(registry=registry)
+
+        class ComponentsInstantiationModel(BaseModel):
+            initialized_model: PydanticPytorchModuleType
+
+        components: ComponentsInstantiationModel = component_factory.build_components(
+            config_dict=config_dict, components_model_type=ComponentsInstantiationModel
+        )
+        return components.initialized_model
+
+    def _test_wte(self, model: GPT2LLM):
+        assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-2)
+        assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-2)
+
+    def _test_lm_head(self, model: GPT2LLM, n_embd: int):
+        assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_embd), abs=1e-3)
+        assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_embd)
+        assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_embd)
+        assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)
+
+    def _test_qkv_proj(self, gpt2_block: GPT2Block):
+        layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn)
+        for layer in layers:
+            assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3)
+            assert layer.weight.max().detach().cpu() <= 2
+            assert layer.weight.min().detach().cpu() >= -2
+            assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)
+
+    def _test_c_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int):
+        layer = gpt2_block.attn.c_proj
+        if depth_init:
+            assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3)
+        else:
+            assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3)
+
+        assert layer.weight.max().detach().cpu() <= 2
+        assert layer.weight.min().detach().cpu() >= -2
+        assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)
+
+    def _test_swiglu_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int):
+        layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2)
+        for layer in layers:
+            if depth_init:
+                assert layer.weight.std().detach().cpu() == pytest.approx(
+                    0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3
+                )
+            else:
+                assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3)
+            assert layer.weight.max().detach().cpu() <= 2
+            assert layer.weight.min().detach().cpu() >= -2
+            assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)
+
+        layer = gpt2_block.mlp.W
+        assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3)
+        assert layer.weight.max().detach().cpu() <= 2
+        assert layer.weight.min().detach().cpu() >= -2
+        assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3)
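The max/min assertions pass because a and b in nn.init.trunc_normal_ are absolute cutoff values, not multiples of std: with std=0.02 and a=-2, b=2 the cutoffs sit roughly 100 standard deviations out (effectively no truncation), whereas the lm_head init truncates at exactly +/-3 sigma. A small standalone check, with an arbitrary tensor shape:

import torch
import torch.nn as nn

w = torch.empty(1024, 1024)
nn.init.trunc_normal_(w, mean=0.0, std=0.02, a=-2, b=2)

print(round(w.std().item(), 4))  # ~0.02, essentially untouched by the wide cutoffs
print(bool(w.abs().max() <= 2))  # True, matching the tests' max/min assertions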
tests/test_yaml_configs/llama3_config_initalization.yaml

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    model_initializer:
      component_key: model_initialization
      variant_key: gpt2_llama3_like
      config:
        num_layers: ${model_raw.config.n_layer}
        n_embd: ${model_raw.config.n_embd}
        depth_init: False


model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: "input_ids"
    poe_type: NOPE
    sequence_length: 128
    prediction_key: "logits"
    vocab_size: 2048 # 2K vocab for testing
    n_layer: 4 # 4 layers for testing
    n_head_q: 32
    n_head_kv: 8
    ffn_hidden: 128 # 128 ffn hidden dim for testing
    n_embd: 256 # 256 embedding dim for testing
    dropout: 0.0
    bias: false
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 500000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    attention_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    ffn_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    lm_head_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
