Commit dcaf154

pre-release v0.3.0
1 parent 7feefc7 commit dcaf154

5 files changed: 162 additions & 28 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "transformer"
-version = "0.2.0"
+version = "0.3.0"
 description = "A polished PyTorch implementation of the current State-Of-The-Art(SOTA) Transformer"
 readme = "README.md"
 authors = [

transformer/attns.py

Lines changed: 8 additions & 1 deletion
@@ -43,6 +43,7 @@ def __init__(
         qk_norm: Optional[bool] = True,
         layer_idx: int = 0,
         rope_base: float = 10000.0,
+        pos_encoding: str = "RoPE",
         max_seq_len: int = 1024,
     ):
         super().__init__()
@@ -53,7 +54,12 @@ def __init__(
         self.qkv_proj = nn.Linear(self.d_model, self.d_model * 3, bias=attn_bias)
         self.out_proj = nn.Linear(self.d_model, self.d_model, bias=attn_bias)
 
-        self.rope = RoPE(max_seq_len, self.d_head, rope_base=rope_base)
+        if pos_encoding == "RoPE":
+            self.rope = RoPE(max_seq_len, self.d_head, rope_base=rope_base)
+        elif pos_encoding == "ALiBi":
+            raise ValueError("Under Development")
+        else:
+            raise ValueError(f"Unknown positional encoding: {pos_encoding}")
 
         self.scale = self.d_head**-0.5
 
         self.dropout = dropout if dropout is not None or dropout != 0.0 else None
@@ -198,6 +204,7 @@ def __init__(
         qk_norm: Optional[bool] = True,
         layer_idx: int = 0,
         rope_base: float = 10000.0,
+        pos_encoding: str = "RoPE",
         max_seq_len: int = 1024,
     ):
         super().__init__()
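
Example (not part of the diff): a minimal sketch of the new pos_encoding argument on the attention modules. The import path and the positional d_model/n_heads arguments follow the constructor shown above; the sizes are illustrative and the remaining keywords are assumed to keep their defaults.

    from transformer.attns import MHA

    # Only "RoPE" is implemented in this release; "ALiBi" and any other value
    # raise ValueError inside __init__ (see the branch added above).
    attn = MHA(
        512,                  # d_model (illustrative)
        8,                    # n_heads (illustrative)
        pos_encoding="RoPE",  # keyword added in this commit
        max_seq_len=1024,
    )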

transformer/config.py

Lines changed: 22 additions & 9 deletions
@@ -1,8 +1,11 @@
 import math
-from typing import Dict, Optional, Union, Type
+from typing import Dict, Optional, Type, Union
 
 from transformers import PretrainedConfig
 
+import torch
+import torch.nn as nn
+
 
 class TransformerConfig(PretrainedConfig):
     r"""
@@ -25,11 +28,17 @@ class TransformerConfig(PretrainedConfig):
 
         max_seq_len (int): Maximum sequence length for positional embeddings.
 
+        ffn_class (Union[Type[nn.Module], str], optional): Feed-Forward Network class or type.
+            - If ``str``, one of ``SwiGLU``, ``MLP``.
+            - If ``Type[nn.Module]`` then will be instantiated inside the model.
+              Should have the same API as ``SwiGLU`` and ``MLP``.
+            Default ``SwiGLU``
+
         attn_bias (bool, optional): Whether to use bias in attention Linear Projections. Default: ``False``
 
         attn_qk_norm (bool, optional): Whether to apply Normalization to Queries and Keys before the Attention Computation. Default: ``True``
 
-        norm_type (Union[Type[nn.Module], str], optional): Normalization class or type.
+        norm_class (Union[Type[nn.Module], str], optional): Normalization class or type.
             - If ``str``, one of ``rms_norm`` or ``layer_norm``.
             - If ``Type[nn.Module]`` then will be instantiated inside the model.
               Should have the same API as a torch Normalization Layer.
@@ -47,9 +56,9 @@ class TransformerConfig(PretrainedConfig):
 
         rope_base (float, optional): Base for the RoPE frequency computation. Default: ``10000.0``
 
-        attn_type (Union[Type[nn.Module], str], optional): Attention class or type.
+        attn_class (Union[Type[nn.Module], str], optional): Attention class or type.
             - If ``str``, one of ``MHA``, ``GQA``, ``CrossAttention``. For ``GQA``, also specify ``n_kv_heads``.
-            - If ``Type`` then will beinstantiated inside the model.
+            - If ``Type[nn.Module]`` then will be instantiated inside the model.
              Should have the same API as ``transformer.attn.MHA``.
             Default ``MHA``
 
@@ -75,8 +84,9 @@ def __init__(
         vocab_size: int = 50000,
         d_ff: Optional[int] = None,
         norm_design: str = "pre_norm",
-        norm_type: Union[Type[nn.Module], str] = "rms_norm",
-        attn_type: Union[Type[nn.Module], str] = "MHA",
+        norm_class: Union[Type[nn.Module], str] = "rms_norm",
+        ffn_class: Union[Type[nn.Module], str] = "SwiGLU",
+        attn_class: Union[Type[nn.Module], str] = "MHA",
         attn_bias: Optional[bool] = False,
         ffn_bias: bool = True,
         lm_head_bias: bool = False,
@@ -94,11 +104,14 @@ def __init__(
         self.n_layer = n_layers
         self.d_model = d_model
         self.n_heads = n_heads
-        self.n_kv_heads = n_kv_heads if attn_type == "GQA" else n_heads
+        self.n_kv_heads = n_kv_heads if attn_class == "GQA" else n_heads
         self.vocab_size = vocab_size
 
-        self.attn_type = attn_type
-        self.norm_type = norm_type
+        self.attn_class = attn_class
+        self.ffn_class = ffn_class
+        self.norm_class = norm_class
+
+        self.norm_design = norm_design
 
         self.d_ff = d_ff if d_ff is not None else math.ceil(d_model * 2.666)
 
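
Example (not part of the diff): a sketch of the renamed configuration keywords. The renamed names (attn_class, ffn_class, norm_class) come from the diff; the other argument names appear in the __init__ above, and all values are illustrative.

    from transformer.config import TransformerConfig

    config = TransformerConfig(
        d_model=512,
        n_heads=8,
        n_kv_heads=4,           # only honoured when attn_class == "GQA"
        attn_class="GQA",       # was attn_type in v0.2.0
        ffn_class="SwiGLU",     # keyword introduced in v0.3.0
        norm_class="rms_norm",  # was norm_type in v0.2.0
    )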

transformer/transformer.py

Lines changed: 130 additions & 16 deletions
@@ -1,33 +1,46 @@
 import math
-from typing import Any, Dict, Optional, Union, Tuple
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import CausalLMOutput
-
+from transformers.modeling_layers import GradientCheckpointingLayer
 from .attns import MHA
 from .config import TransformerConfig
 from .ffn import SwiGLU
 from .pos import RoPE
+from .utils import check_type
 
 
-class TransformerBlock(nn.Module):
+class TransformerBlock(GradientCheckpointingLayer):
     """
-    A Single Decoder Transformer Block consisting of Multi-Head Attention and Feed-Forward layers,
+    A Single Transformer Decoder Block consisting of Multi-Head Attention and Feed-Forward layers,
     each with Pre-Normalization (RMSNorm) and Standard Residual Connections.
 
     Args:
         config (TransformerConfig): Configuration object.
+        attn_kwargs (Dict, optional): Additional Arguments for the attention class passed from ``TransformerConfig.attn_class``.
+            It is only used if ``TransformerConfig.attn_class`` is ``Type[nn.Module]``.
+        ffn_kwargs (Dict, optional): Additional Arguments for the ffn class passed from ``TransformerConfig.ffn_class``.
+            It is only used if ``TransformerConfig.ffn_class`` is ``Type[nn.Module]``.
+        norm_kwargs (Dict, optional): Additional Arguments for the normalization class passed from ``TransformerConfig.norm_class``. It is always passed.
         layer_idx (int, optional): Index of this block (used for debugging/logging).
     """
 
-    def __init__(self, config, layer_idx: Optional[int] = 0):
+    def __init__(
+        self,
+        config,
+        attn_kwargs: Optional[Dict] = {},
+        ffn_kwargs: Optional[Dict] = {},
+        norm_kwargs: Optional[Dict] = {},
+        layer_idx: int = 0,
+    ):
         super().__init__()
         self.d_model, self.d_ff, self.n_heads, self.layer_idx = config.d_model, config.d_ff, config.n_heads, layer_idx
 
-        if config.attn_type == "MHA":
+        if config.attn_class == "MHA":
             self.attn = MHA(
                 self.d_model,
                 self.n_heads,
@@ -36,27 +49,107 @@ def __init__(self, config, layer_idx: Optional[int] = 0):
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
                 rope_base=config.rope_base,
+                pos_encoding=config.pos_encoding,
                 max_seq_len=config.max_seq_len,
             )
-        elif config.attn_type == "GQA":
+        elif config.attn_class == "GQA":
             self.attn = GQA(
                 self.d_model,
                 self.n_heads,
-                config.n_kv_heads,
+                n_kv_heads=config.n_kv_heads,
                 dropout=config.attn_dropout,
                 attn_bias=config.attn_bias,
                 qk_norm=config.attn_qk_norm,
                 layer_idx=layer_idx,
                 rope_base=config.rope_base,
+                pos_encoding=config.pos_encoding,
                 max_seq_len=config.max_seq_len,
             )
-        elif config.attn_type == "CrossAttention":
-            raise ValueError(f"Under Development: {config.attn_type}")
+        elif config.attn_class == "CrossAttention":
+            raise ValueError(f"Under Development: {config.attn_class}")
+        elif check_type(config.attn_class) == 0:
+            raise ValueError(f"Unknown attention type: {config.attn_class}")
+        elif check_type(config.attn_class) == 1:
+            self.attn = config.attn_class(
+                self.d_model,
+                self.n_heads,
+                config.attn_bias,
+                {
+                    "dropout": config.attn_dropout,
+                    "qk_norm": config.attn_qk_norm,
+                    "layer_idx": layer_idx,
+                    "rope_base": config.rope_base,
+                    "pos_encoding": config.pos_encoding,
+                    "max_seq_len": config.max_seq_len,
+                },
+                **attn_kwargs,
+            )
         else:
-            raise ValueError(f"Unknown attention type: {config.attn_type}")
+            raise RuntimeError(
+                f"TransformerConfig.attn_class should be str or Type[nn.Module] but found: {config.attn_class}"
+            )
 
-        self.ffn = SwiGLU(self.d_model, self.d_ff, bias=config.ffn_bias)
-        self.norm_attn, self.norm_ffn = nn.RMSNorm(self.d_model), nn.RMSNorm(self.d_model)
+        if config.ffn_class == "SwiGLU":
+            self.ffn = SwiGLU(self.d_model, self.d_ff, bias=config.ffn_bias)
+        elif config.ffn_class == "MLP":
+            self.ffn = MLP(self.d_model, self.d_ff, bias=config.ffn_bias)
+        elif config.ffn_class == "MoE":
+            raise ValueError(f"Under Development: {config.ffn_class}")
+        elif check_type(config.ffn_class) == 0:
+            raise ValueError(f"Unknown ffn class: {config.ffn_class}")
+        elif check_type(config.ffn_class) == 1:
+            self.ffn = config.ffn_class(self.d_model, self.d_ff, bias=config.ffn_bias, **ffn_kwargs)
+        else:
+            raise RuntimeError(
+                f"TransformerConfig.ffn_class should be str or Type[nn.Module] but found: {config.ffn_class}"
+            )
+
+        # FIXME
+        if config.norm_class == "rms_norm":
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                    nn.RMSNorm(self.d_model, **norm_kwargs),
+                )
+        elif config.norm_class == "layer_norm":
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                    nn.LayerNorm(self.d_model, **norm_kwargs),
+                )
+        elif check_type(config.norm_class) == 0:
+            raise ValueError(f"Unknown normalization class: {config.norm_class}")
+        elif check_type(config.norm_class) == 1:
+            if config.norm_design == "pre_norm" or config.norm_design == "post_norm":
+                self.norm_attn, self.norm_ffn = (
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                )
+            elif config.norm_design == "both":
+                self.pre_norm_attn, self.pre_norm_ffn, self.post_norm_attn, self.post_norm_ffn = (
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                    config.norm_class(self.d_model, **norm_kwargs),
+                )
+        else:
+            raise RuntimeError(
+                f"TransformerConfig.norm_class should be str or Type[nn.Module] but found: {config.norm_class}"
+            )
 
     def forward(
         self,
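
Example (not part of the diff): the call sites added above instantiate a user-supplied class as ffn_class(d_model, d_ff, bias=..., **ffn_kwargs), so a custom FFN only needs to match that signature. The class below is hypothetical and only illustrates the expected API.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GELUMLP(nn.Module):
        # Hypothetical FFN with the same constructor/forward API as SwiGLU and MLP.
        def __init__(self, d_model: int, d_ff: int, bias: bool = True, dropout: float = 0.0):
            super().__init__()
            self.up = nn.Linear(d_model, d_ff, bias=bias)
            self.down = nn.Linear(d_ff, d_model, bias=bias)
            self.drop = nn.Dropout(dropout)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down(self.drop(F.gelu(self.up(x))))

    # config.ffn_class = GELUMLP      # passed as Type[nn.Module], so check_type(...) == 1
    # ffn_kwargs = {"dropout": 0.1}   # forwarded by TransformerBlock via **ffn_kwargs
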
@@ -105,13 +198,27 @@ class Transformer(PreTrainedModel, GenerationMixin):
     config_class = TransformerConfig
     base_model_prefix = "transformer"
 
-    def __init__(self, config):
+    supports_gradient_checkpointing = True
+    _supports_flash_attn = True
+    _supports_sdpa = True
+
+    input_modalities = "text"  # Will add "image" for v0.4.0
+
+    def __init__(
+        self,
+        config,
+        attn_kwargs: Dict = {},
+        ffn_kwargs: Dict = {},
+        norm_kwargs: Dict = {},
+    ):
         super().__init__(config)
         self.config = config
         self.d_model = config.d_model
 
         self.emb = nn.Embedding(config.vocab_size, config.d_model)
-        self.blocks = nn.ModuleList([TransformerBlock(config, i) for i in range(config.n_layer)])
+        self.blocks = nn.ModuleList(
+            [TransformerBlock(config, attn_kwargs, ffn_kwargs, norm_kwargs, i) for i in range(config.n_layer)]
+        )
         self.norm_out = nn.RMSNorm(config.d_model)
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=config.lm_head_bias)
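
Example (not part of the diff): how the new per-block kwargs propagate. The module path and argument names follow the __init__ above; the eps value is illustrative.

    from transformer.config import TransformerConfig
    from transformer.transformer import Transformer

    config = TransformerConfig(norm_class="layer_norm")
    # norm_kwargs is forwarded to every TransformerBlock and applied to each norm layer it builds.
    model = Transformer(config, norm_kwargs={"eps": 1e-6})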

@@ -145,7 +252,7 @@ def forward(
             False,
         ),
         return_states: Optional[bool] = False,
-        loss_kwargs: Dict = None,
+        loss_kwargs: Dict = {},
         **kwargs: Dict,
     ) -> CausalLMOutput:
         """
@@ -196,6 +303,13 @@ def forward(
             loss=loss, logits=logits, hidden_states=(input_embs, hidden_states) if return_states else None
         )
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.emb
+
+    def set_input_embeddings(self, embeddings: nn.Embedding):
+        self.emb = embeddings
+
     def get_num_params(self) -> int:
         """Return the number of trainable parameters."""
         return sum(p.numel() for p in self.parameters() if p.requires_grad)
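
Example (not part of the diff): typical use of the new accessors, e.g. growing the vocabulary, with model as constructed in the earlier example; it relies on the setter above assigning the new table to self.emb. The lm_head would need a matching change, which is not shown.

    import torch.nn as nn

    old = model.get_input_embeddings()  # nn.Embedding(vocab_size, d_model)
    new = nn.Embedding(old.num_embeddings + 128, old.embedding_dim)
    new.weight.data[: old.num_embeddings] = old.weight.data
    model.set_input_embeddings(new)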

transformer/utils.py

Lines changed: 1 addition & 1 deletion
@@ -2,13 +2,13 @@
 import os
 import random
 import sys
-
 from typing import Dict, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+
 def check_type(x: Union[Type[nn.Module], str]):
     if isinstance(x, str):
         return 0
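
Example (not part of the diff): how check_type drives the dispatch added in transformer.py. Only the string branch (return 0) is visible in this hunk; the return value of 1 for an nn.Module subclass is inferred from the check_type(...) == 1 comparisons above and is therefore an assumption.

    import torch.nn as nn
    from transformer.utils import check_type

    check_type("rms_norm")    # 0 -> resolve the built-in class by name
    check_type(nn.LayerNorm)  # 1 (assumed) -> instantiate the given class directly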
