-
Notifications
You must be signed in to change notification settings - Fork 16
Llama3 like weight init #435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
5f5e616
4dea496
b704331
34a8621
c7bcaaa
2a171aa
43a9d50
2549e8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| import math | ||
| import re | ||
| from functools import partial | ||
| from typing import Annotated | ||
|
|
||
| import torch.nn as nn | ||
| from pydantic import BaseModel, Field | ||
|
|
||
| from modalities.nn.model_initialization.initialization_if import ModelInitializationIF | ||
| from modalities.utils.logger_utils import get_logger | ||
|
|
||
| logger = get_logger(name="llama3 initialization") | ||
|
|
||
|
|
||
class Llama3InitializerConfig(BaseModel):
    """Pydantic config for :class:`Llama3Initializer`.

    Fields mirror the constructor arguments of ``Llama3Initializer`` and are
    validated strictly (no implicit type coercion, values must be positive).
    """

    # number of transformer blocks; used for depth-scaled init std (> 0)
    num_layers: Annotated[int, Field(strict=True, gt=0)]
    # embedding dimension; used for the lm_head init std (> 0)
    n_embd: Annotated[int, Field(strict=True, gt=0)]
    # whether the model's linear layers carry bias terms
    # (enables the additional bias init patterns)
    bias: bool
class Llama3Initializer(ModelInitializationIF):
    """Weight initializer following the Llama3 parameterization as implemented in TorchTitan.

    Maps a fully-qualified parameter name (FQN) regex to an in-place init
    function. Output projections (attention ``c_proj`` and the SwiGLU
    ``V``/``W_2``) use a depth-scaled std of ``0.02 / sqrt(2 * num_layers)``
    ("depth init"); the lm_head uses ``1 / sqrt(n_embd)``.

    Norm layers are not covered here — they are expected to be initialized
    elsewhere (e.g. via ``reset_parameters`` in the model factory), which is
    why unmatched parameters only trigger a warning, not an error.
    """

    def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None:
        """
        Args:
            num_layers: Number of transformer blocks (drives the depth-scaled std).
            n_embd: Embedding dimension (drives the lm_head std).
            bias: If True, also register init patterns for bias parameters.
        """
        super().__init__()

        # Depth-scaled std for output projections, per TorchTitan's depth_init.
        depth_std = 0.02 / math.sqrt(2 * num_layers)
        sqrt_n_embd = math.sqrt(n_embd)

        # NOTE(review): for the std=0.02 entries, a=-2 / b=2 are ABSOLUTE
        # truncation bounds (i.e. +-100 std), effectively no truncation —
        # the lm_head entry truncates at +-3 std instead. Confirm this
        # asymmetry is intended.
        self.regex_to_init = {
            # embedding weights: plain normal, std=1
            r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1),
            # lm head weights: std = 1/sqrt(n_embd), truncated at +-3 std
            r"transformer\.lm_head\.weight": partial(
                nn.init.trunc_normal_,
                mean=0.0,
                std=1 / sqrt_n_embd,
                a=-3 / sqrt_n_embd,
                b=3 / sqrt_n_embd,
            ),
            # qkv projections
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": partial(
                nn.init.trunc_normal_,
                mean=0.0,
                std=0.02,
                a=-2,
                b=2,
            ),
            # final attention projection in attention block (depth-scaled)
            r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial(
                nn.init.trunc_normal_,
                mean=0.0,
                std=depth_std,
                a=-2,
                b=2,
            ),
            # SwiGLU input projection
            # Fix: layer-index segment is numeric, so use \d+ (was \w+) for
            # consistency with the attention patterns above; \w+ would also
            # match non-index segments.
            r"transformer\.h\.\d+\.mlp\.(W)\.weight": partial(
                nn.init.trunc_normal_,
                mean=0.0,
                std=0.02,
                a=-2,
                b=2,
            ),
            # SwiGLU gate/output projections (depth-scaled)
            r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": partial(
                nn.init.trunc_normal_,
                mean=0.0,
                std=depth_std,
                a=-2,
                b=2,
            ),
        }
        if bias:
            self.regex_to_init.update(
                {
                    r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial(
                        nn.init.trunc_normal_,
                        mean=0.0,
                        std=0.02,
                        a=-2,
                        b=2,
                    ),
                    r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial(
                        nn.init.trunc_normal_,
                        mean=0.0,
                        std=depth_std,
                        a=-2,
                        b=2,
                    ),
                    r"transformer\.h\.\d+\.mlp\.(W)\.bias": nn.init.zeros_,
                    r"transformer\.h\.\d+\.mlp\.(V|W_2)\.bias": nn.init.zeros_,
                }
            )

    def initialize_in_place(self, model: nn.Module) -> None:
        """Initialize all matching parameters of ``model`` in place."""
        self._init_by_fqn_regex(model, self.regex_to_init)

    @staticmethod
    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]) -> None:
        """Apply each init function to the parameters whose FQN fully matches its regex.

        Args:
            model: Model whose named parameters are initialized in place.
            regex_to_init: Mapping from FQN regex to an in-place init callable.

        Raises:
            ValueError: If a parameter matches more than one regex, or if a
                regex matches no parameter at all (model/spec mismatch).
        """
        hits = dict.fromkeys(regex_to_init, 0)

        for parameter_name, p in model.named_parameters():
            # Collect all matches first so an ambiguous parameter is NOT
            # initialized before the error is raised (the previous version
            # applied every matching init and only then failed).
            matching = [rx for rx in regex_to_init if re.fullmatch(rx, parameter_name)]
            if not matching:
                # Norms etc. are initialized elsewhere; a warning suffices.
                logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
            elif len(matching) > 1:
                raise ValueError(
                    f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
                )
            else:
                regex_to_init[matching[0]](p)
                hits[matching[0]] += 1

        # Every regex must have fired at least once, otherwise the model
        # architecture does not match the expected Llama3-like layout.
        for k, count in hits.items():
            if count == 0:
                raise ValueError(
                    f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3."
                )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
# Wraps the raw GPT-2 model with the Llama3-like weight initializer.
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    model_initializer:
      component_key: model_initialization
      variant_key: llama3_like
      config:
        # Initializer parameters are derived from the raw model config so
        # they cannot drift out of sync with the architecture below.
        num_layers: ${model_raw.config.n_layer}
        n_embd: ${model_raw.config.n_embd}
        bias: ${model_raw.config.bias}

# Small GPT-2-style model used for testing the initializer.
model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    # Weight tying must be off so wte and lm_head are initialized separately.
    use_weight_tying: false
    sample_key: "input_ids"
    poe_type: NOPE
    sequence_length: 128
    prediction_key: "logits"
    vocab_size: 2048 # 2K vocab for testing
    n_layer: 4 # 4 layers for testing
    n_head_q: 32
    n_head_kv: 8 # grouped-query attention: 8 kv heads for 32 query heads
    ffn_hidden: 128 # 128 ffn hidden dim for testing
    n_embd: 256 # 256 embedding dim for testing
    dropout: 0.0
    bias: true
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 500000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    # RMSNorm everywhere; norms are initialized by the model factory,
    # not by the llama3_like initializer above.
    attention_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    ffn_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    lm_head_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we also need regex patterns for
attention_norm,ffn_norm, and thefinal lm_head_normnai ?. Something likeUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
modalities/src/modalities/models/model_factory.py
Line 269 in b704331
we already call this here.
and due to recursion we also call it for the RMSNorm.
https://github.com/pytorch/pytorch/blob/65762ca85745d786ab6b20e9cb060242b51e872d/torch/nn/modules/normalization.py#L407