-
Notifications
You must be signed in to change notification settings - Fork 16
Llama3 like weight init #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5f5e616
4dea496
b704331
34a8621
c7bcaaa
2a171aa
43a9d50
2549e8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| import math | ||
| import re | ||
| from typing import Annotated, Callable | ||
|
|
||
| import torch.nn as nn | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from modalities.nn.model_initialization.initialization_if import ModelInitializationIF | ||
| from modalities.utils.logger_utils import get_logger | ||
|
|
||
| logger = get_logger(name="llama3 initialization") | ||
|
|
||
|
|
||
class Llama3InitializerConfig(BaseModel):
    """Pydantic config for ``Llama3Initializer``.

    Validated via pydantic ``Field`` constraints: both integer fields are
    strict (no implicit coercion) and must be positive.
    """

    # Number of transformer blocks; used for the global residual scaling
    # factor 0.02 / sqrt(2 * num_layers) when depth_init is False.
    num_layers: Annotated[int, Field(strict=True, gt=0)]
    # Embedding dimension; sets the lm_head init std to 1 / sqrt(n_embd).
    n_embd: Annotated[int, Field(strict=True, gt=0)]
    # If True, residual projections are scaled per layer depth instead of
    # with the single num_layers-based factor.
    depth_init: bool = True
|
|
||
|
|
||
| class Llama3Initializer(ModelInitializationIF): | ||
| """ | ||
| Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. | ||
| """ | ||
|
|
||
    def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
        """Build the FQN-regex -> (init_fn, kwargs) table used by ``initialize_in_place``.

        Args:
            num_layers: Number of transformer blocks; used for the fixed
                residual-projection std when ``depth_init`` is False.
            n_embd: Embedding dimension; sets the lm_head std to 1/sqrt(n_embd)
                with truncation at +/- 3 std.
            depth_init: If True, residual projections (attn.c_proj and the
                SwiGLU V/W_2 weights) use a per-layer std
                0.02 / sqrt(2 * (layer_id + 1)); otherwise a fixed
                0.02 / sqrt(2 * num_layers).
        """
        super().__init__()
        self.depth_init = depth_init

        # NOTE: the "std" entry is either a float or a callable taking the
        # layer index; _init_by_fqn_regex resolves callables per parameter.
        self.regex_to_init = {
            # embedding weights
            r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
            # lm head weights
            r"transformer\.lm_head\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 1 / math.sqrt(n_embd),
                    "a": -3 / math.sqrt(n_embd),
                    "b": 3 / math.sqrt(n_embd),
                },
            ),
            # qkv projections
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 0.02,
                    "a": -2,
                    "b": 2,
                },
            ),
            # final attention projection in attention block
            r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    # depth-dependent std when depth_init, else a global factor
                    "std": (
                        (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
                        if depth_init
                        else 0.02 / math.sqrt(2 * num_layers)
                    ),
                    "a": -2,
                    "b": 2,
                },
            ),
            # SwiGLU input projection
            r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": 0.02,
                    "a": -2,
                    "b": 2,
                },
            ),
            # SwiGLU gate and output projections (residual-scaled like c_proj)
            r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": (
                        (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1)))
                        if depth_init
                        else 0.02 / math.sqrt(2 * num_layers)
                    ),
                    "a": -2,
                    "b": 2,
                },
            ),
        }
|
|
||
    def initialize_in_place(self, model: nn.Module):
        """Initialize `model`'s weights in place using the regex table built in ``__init__``."""
        self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)
|
|
||
| @staticmethod | ||
| def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool): | ||
| hits = {k: 0 for k in regex_to_init.keys()} | ||
|
|
||
| for parameter_name, p in model.named_parameters(): | ||
| if parameter_name.endswith("bias"): | ||
| raise ValueError( | ||
| f"Bias initialization is not allowed for Llama3Initializer. Found bias parameter: {parameter_name}" | ||
| ) | ||
| match_count = 0 | ||
| for weight_regex in regex_to_init.keys(): | ||
| if re.fullmatch(weight_regex, parameter_name): | ||
| init_fn, arg_dict = regex_to_init[weight_regex] | ||
| if arg_dict["std"] is not None and callable(arg_dict["std"]): | ||
| # If std is a function, call it with the layer_id | ||
| layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name) | ||
| if layer_id_match is not None: | ||
| layer_id = int(layer_id_match.group(1)) | ||
| arg_dict = arg_dict.copy() # create a copy of the arg_dict to avoid mutating the original | ||
| arg_dict["std"] = arg_dict["std"](layer_id) | ||
| else: | ||
| raise ValueError( | ||
| f"Could not extract layer_id from parameter name {parameter_name} " | ||
| "for dynamic std calculation" | ||
| ) | ||
| init_fn(p, **arg_dict) | ||
| match_count += 1 | ||
| hits[weight_regex] += 1 | ||
|
|
||
| if match_count == 0: | ||
| logger.warning(f"Parameter {parameter_name} did not match any regex for initialization") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we add a flag which turns this into an error?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the norms are initialized within the model factory via |
||
| elif match_count > 1: | ||
| raise ValueError( | ||
| f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed" | ||
| ) | ||
|
|
||
| for k, count in hits.items(): | ||
| if count == 0: | ||
| raise ValueError( | ||
| f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3." | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
# Wraps the raw GPT-2 model with the Llama3-like weight initializer.
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    model_initializer:
      component_key: model_initialization
      variant_key: gpt2_llama3_like
      config:
        # num_layers/n_embd mirror the raw model so std scaling is consistent.
        num_layers: ${model_raw.config.n_layer}
        n_embd: ${model_raw.config.n_embd}
        depth_init: False

# Small GPT-2 variant used for testing the initializer (RoPE + SwiGLU + RMSNorm,
# i.e. a Llama3-like parameterization; no biases, no weight tying).
model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: "input_ids"
    poe_type: NOPE
    sequence_length: 128
    prediction_key: "logits"
    vocab_size: 2048 # 2K vocab for testing
    n_layer: 4 # 4 layers for testing
    n_head_q: 32
    n_head_kv: 8 # grouped-query attention: 4 query heads per kv head
    ffn_hidden: 128 # 128 ffn hidden dim for testing
    n_embd: 256 # 256 embedding dim for testing
    dropout: 0.0
    bias: false # initializer rejects bias parameters, so keep biases disabled
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 500000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    # RMSNorms are initialized by the model factory (reset_parameters), not by
    # the Llama3Initializer regex table.
    attention_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    ffn_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    lm_head_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also need regex patterns for
`attention_norm`, `ffn_norm`, and the final `lm_head_norm`, no? Something like:
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
modalities/src/modalities/models/model_factory.py
Line 269 in b704331
we already call this here.
and due to recursion we also call it for the RMSNorm.
https://github.com/pytorch/pytorch/blob/65762ca85745d786ab6b20e9cb060242b51e872d/torch/nn/modules/normalization.py#L407