 import math
-from typing import Any, Dict, Optional, Union, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutput
-
+from transformers.modeling_layers import GradientCheckpointingLayer
 from .attns import MHA
 from .config import TransformerConfig
 from .ffn import SwiGLU
 from .pos import RoPE
+from .utils import check_type


-class TransformerBlock(nn.Module):
+class TransformerBlock(GradientCheckpointingLayer):
     """
-    A Single Decoder Transformer Block consisting of Multi-Head Attention and Feed-Forward layers,
+    A single Transformer decoder block consisting of Multi-Head Attention and Feed-Forward layers,
     each with Pre-Normalization (RMSNorm) and Standard Residual Connections.

     Args:
         config (TransformerConfig): Configuration object.
+        attn_kwargs (Dict, optional): Additional arguments forwarded to the attention class given by ``TransformerConfig.attn_class``.
+            Only used when ``TransformerConfig.attn_class`` is a ``Type[nn.Module]``.
+        ffn_kwargs (Dict, optional): Additional arguments forwarded to the FFN class given by ``TransformerConfig.ffn_class``.
+            Only used when ``TransformerConfig.ffn_class`` is a ``Type[nn.Module]``.
+        norm_kwargs (Dict, optional): Additional arguments forwarded to the normalization class given by ``TransformerConfig.norm_class``.
+            Always applied, whether the class is built-in or custom (see the example below).
         layer_idx (int, optional): Index of this block (used for debugging/logging).
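+
+        Example (a sketch, not from this commit; it assumes ``TransformerConfig`` accepts these fields as
+        constructor keywords, and ``MyAttention`` / ``window_size`` are hypothetical)::
+
+            config = TransformerConfig(attn_class=MyAttention)
+            block = TransformerBlock(config, attn_kwargs={"window_size": 256}, layer_idx=0)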
     """

-    def __init__(self, config, layer_idx: Optional[int] = 0):
+    def __init__(
+        self,
+        config,
+        attn_kwargs: Optional[Dict] = {},
+        ffn_kwargs: Optional[Dict] = {},
+        norm_kwargs: Optional[Dict] = {},
+        layer_idx: int = 0,
+    ):
         super().__init__()
         self.d_model, self.d_ff, self.n_heads, self.layer_idx = config.d_model, config.d_ff, config.n_heads, layer_idx

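+        # Attention dispatch: config.attn_class may be a built-in name ("MHA", "GQA") or a custom
+        # nn.Module subclass instantiated with attn_kwargs. check_type() appears to return 0 for an
+        # unrecognized string and 1 for a module class; GQA is assumed to be importable from .attns
+        # even though the import block above only lists MHA.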
-        if config.attn_type == "MHA":
+        if config.attn_class == "MHA":
             self.attn = MHA(
                 self.d_model,
                 self.n_heads,
@@ -36,27 +49,107 @@ def __init__(self, config, layer_idx: Optional[int] = 0):
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
                 rope_base=config.rope_base,
+                pos_encoding=config.pos_encoding,
                 max_seq_len=config.max_seq_len,
             )
-        elif config.attn_type == "GQA":
+        elif config.attn_class == "GQA":
             self.attn = GQA(
                 self.d_model,
                 self.n_heads,
-                config.n_kv_heads,
+                n_kv_heads=config.n_kv_heads,
                 dropout=config.attn_dropout,
                 attn_bias=config.attn_bias,
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
                 rope_base=config.rope_base,
+                pos_encoding=config.pos_encoding,
                 max_seq_len=config.max_seq_len,
             )
-        elif config.attn_type == "CrossAttention":
-            raise ValueError(f"Under Development: {config.attn_type}")
+        elif config.attn_class == "CrossAttention":
+            raise ValueError(f"Under Development: {config.attn_class}")
+        elif check_type(config.attn_class) == 0:
+            raise ValueError(f"Unknown attention type: {config.attn_class}")
+        elif check_type(config.attn_class) == 1:
+            self.attn = config.attn_class(
+                self.d_model,
+                self.n_heads,
+                config.attn_bias,
+                {
+                    "dropout": config.attn_dropout,
+                    "qk_norm": config.attn_qk_norm,
+                    "layer_idx": layer_idx,
+                    "rope_base": config.rope_base,
+                    "pos_encoding": config.pos_encoding,
+                    "max_seq_len": config.max_seq_len,
+                },
+                **attn_kwargs,
+            )
         else:
-            raise ValueError(f"Unknown attention type: {config.attn_type}")
+            raise RuntimeError(
+                f"TransformerConfig.attn_class should be str or Type[nn.Module] but found: {config.attn_class}"
+            )

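+        # FFN dispatch mirrors the attention dispatch above: a built-in name ("SwiGLU", "MLP", "MoE")
+        # or a custom nn.Module class instantiated with ffn_kwargs. This assumes MLP is importable
+        # from .ffn alongside SwiGLU.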
-        self.ffn = SwiGLU(self.d_model, self.d_ff, bias=config.ffn_bias)
-        self.norm_attn, self.norm_ffn = nn.RMSNorm(self.d_model), nn.RMSNorm(self.d_model)
+        if config.ffn_class == "SwiGLU":
+            self.ffn = SwiGLU(self.d_model, self.d_ff, bias=config.ffn_bias)
+        elif config.ffn_class == "MLP":
+            self.ffn = MLP(self.d_model, self.d_ff, bias=config.ffn_bias)
+        elif config.ffn_class == "MoE":
+            raise ValueError(f"Under Development: {config.ffn_class}")
+        elif check_type(config.ffn_class) == 0:
+            raise ValueError(f"Unknown ffn class: {config.ffn_class}")
+        elif check_type(config.ffn_class) == 1:
+            self.ffn = config.ffn_class(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
+        else:
+            raise RuntimeError(
+                f"TransformerConfig.ffn_class should be str or Type[nn.Module] but found: {config.ffn_class}"
+            )
+
+        # FIXME
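+        # Normalization dispatch: "pre_norm" / "post_norm" designs build one norm per sub-layer
+        # (self.norm_attn, self.norm_ffn); "both" keeps separate pre- and post-norms for each
+        # sub-layer. Any other norm_design value currently falls through without creating norms.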
+        if config.norm_class == "rms_norm":
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                )
+        elif config.norm_class == "layer_norm":
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                )
+        elif check_type(config.norm_class) == 0:
+            raise ValueError(f"Unknown normalization class: {config.norm_class}")
+        elif check_type(config.norm_class) == 1:
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                )
+        else:
+            raise RuntimeError(
+                f"TransformerConfig.norm_class should be str or Type[nn.Module] but found: {config.norm_class}"
+            )
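+        # Possible follow-up for the FIXME above (a sketch, not part of this commit): resolve the
+        # norm constructor once and build the copies from it, e.g.
+        #   norm_cls = {"rms_norm": nn.RMSNorm, "layer_norm": nn.LayerNorm}.get(config.norm_class, config.norm_class)
+        #   self.norm_attn, self.norm_ffn = norm_cls(self.d_model, **norm_kwargs), norm_cls(self.d_model, **norm_kwargs)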

     def forward(
         self,
@@ -105,13 +198,27 @@ class Transformer(PreTrainedModel, GenerationMixin):
     config_class = TransformerConfig
     base_model_prefix = "transformer"

-    def __init__(self, config):
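+    # Hugging Face integration flags: allow gradient_checkpointing_enable() and advertise that the
+    # attention layers can run under the SDPA / FlashAttention backends.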
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    input_modalities = "text"  # Will add "image" for v0.4.0
+
+    def __init__(
+        self,
+        config,
+        attn_kwargs: Dict = {},
+        ffn_kwargs: Dict = {},
+        norm_kwargs: Dict = {},
+    ):
         super().__init__(config)
         self.config = config
         self.d_model = config.d_model

         self.emb = nn.Embedding(config.vocab_size, config.d_model)
-        self.blocks = nn.ModuleList([TransformerBlock(config, i) for i in range(config.n_layer)])
+        self.blocks = nn.ModuleList(
+            [TransformerBlock(config, attn_kwargs, ffn_kwargs, norm_kwargs, i) for i in range(config.n_layer)]
+        )
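+        # The same attn_kwargs / ffn_kwargs / norm_kwargs dicts are shared by every block;
+        # per-layer overrides are not supported in this constructor.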
         self.norm_out = nn.RMSNorm(config.d_model)
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=config.lm_head_bias)

@@ -145,7 +252,7 @@ def forward(
             False,
         ),
         return_states: Optional[bool] = False,
-        loss_kwargs: Dict = None,
+        loss_kwargs: Dict = {},
        **kwargs: Dict,
    ) -> CausalLMOutput:
        """
@@ -196,6 +303,13 @@
            loss=loss, logits=logits, hidden_states=(input_embs, hidden_states) if return_states else None
        )

+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.emb
+
+    def set_input_embeddings(self, embeddings: nn.Embedding):
+        self.emb = embeddings
+
     def get_num_params(self) -> int:
         """Return the number of trainable parameters."""
         return sum(p.numel() for p in self.parameters() if p.requires_grad)