
Commit 376c94b

committed "." · 1 parent 966db65

37 files changed: 1328 additions & 1158 deletions

README.md (18 additions & 11 deletions)

````diff
@@ -7,13 +7,13 @@
 [![Stars](https://img.shields.io/github/stars/lof310/transformer)](#)
 [![Downloads](https://img.shields.io/github/downloads/lof310/transformer/total)](https://github.com/lof310/transformer/releases)
 
-A polished PyTorch implementation of the current State-Of-The-Art (SOTA) Transformer. Designed for clarity, reproducibility, and interoperability with HuggingFace Transformers, this repository provides a robust, fully configurable baseline for research and engineering. The codebase emphasizes readable, well-documented components so you can iterate on Feed-Forward, Attention, and Normalization blocks and other architectural variants with minimal friction.
+_A polished **PyTorch implementation** of the current **State-Of-The-Art (SOTA) Transformer**. Designed for clarity, reproducibility, and interoperability with **HuggingFace Transformers**, this repository provides a robust, **fully configurable** baseline for **research** and **engineering**. The codebase emphasizes **readable, well-documented components** so you can iterate on **Feed-Forward**, **Attention**, and **Normalization** blocks and other **architectural variants** with minimal friction._
 
 ## Features
 - **Fully Configurable** architecture (layers, heads, model dimensions, dropout, etc.)
-- HuggingFace-compatible API alignment.
-- Compact and easily extensible design for rapid prototyping and research experiments.
-- Clear, well-documented modules to facilitate experimentation with attention, FFNs, etc.
+- **HuggingFace-compatible** API alignment.
+- **Compact and easily extensible** design for rapid prototyping and research experiments.
+- **Clear, well-documented modules** to facilitate experimentation with attention, FFNs, etc.
 
 ## Download the code
 ```bash
@@ -46,7 +46,7 @@ config = TransformerConfig(
     n_layers = 12,
     n_heads: int = 32,
     d_model: int = 1536,
-    qk_norm: bool = False,
+    attn_qk_norm: bool = False,
     tied_weights: bool = False,
     seq_len: int = 1024,
     max_seq_len: int = 4096,
@@ -69,18 +69,25 @@ from transformer import TransformerConfig
 
 TransformerConfig(
     n_layers = 12,
-    d_model int = 1536,
+    d_model = 1536,
     n_heads = 32,
-    n_kv_heads = None, # GQA disabled
-    vocab_size int = 50000,
-    d_ff = None, # Chosen automatically: math.ceil(d_model * 2.666)
-    attn_type = "MHA",
+    n_kv_heads = None, # GQA disabled
+    vocab_size = 50000,
+    d_ff = None, # Chosen automatically, ratio 8/3 = 2.666
+    norm_design = "pre_norm",
+    norm_class = "rms_norm",
+    ffn_class = "SwiGLU",
+    attn_class = "MHA",
+    block_class = None, # transformer.TransformerBlock
     attn_bias = False,
     ffn_bias = True,
-    attn_qk_norm = True,
     lm_head_bias = False,
+    attn_qk_norm = True,
+    attn_dropout = 0.0,
     tied_weights = False,
     seq_len = 1024,
+    pos_encoding = "RoPE",
+    rope_base = 10000.0,
     max_seq_len = 4096
 )
```
````
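As a quick sanity check of the new surface, here is a minimal usage sketch built only from names in the README excerpt above; the exact constructor signature beyond those names is an assumption, not confirmed by this commit.

```python
import math

from transformer import TransformerConfig  # import path as shown in the README

# Minimal sketch based on the README excerpt above; note the rename
# from the first hunk: `qk_norm` is now `attn_qk_norm`.
config = TransformerConfig(
    n_layers=12,
    n_heads=32,
    d_model=1536,
    attn_qk_norm=True,  # was `qk_norm` before this commit
)

# With d_ff=None the intermediate width is derived from d_model at the
# 8/3 ratio mentioned in the comment: 1536 * 8 / 3 == 4096 exactly.
print(math.ceil(1536 * 8 / 3))  # 4096
```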

docs/build/doctrees/api.doctree
26.1 KB · Binary file not shown.
5.33 KB · Binary file not shown.
-2.59 KB · Binary file not shown.

docs/build/doctrees/guide.doctree
6.18 KB · Binary file not shown.

docs/build/doctrees/index.doctree
0 Bytes · Binary file not shown.

docs/build/html/_modules/transformer/attns.html
147 additions & 110 deletions · Large diff not rendered by default.

docs/build/html/_modules/transformer/config.html
78 additions & 46 deletions · Large diff not rendered by default.

docs/build/html/_modules/transformer/ffn.html (34 additions & 27 deletions)

```diff
@@ -189,25 +189,20 @@
 <input type="hidden" name="area" value="default">
 </form>
 <div id="searchbox"></div><div class="sidebar-scroll"><div class="sidebar-tree">
-<p class="caption" role="heading"><span class="caption-text">Getting Started</span></p>
-<ul>
+<ul>
 <li class="toctree-l1"><a class="reference internal" href="../../installation.html">Installation</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../quickstart.html">Quick Start</a></li>
 </ul>
-<p class="caption" role="heading"><span class="caption-text">Guide</span></p>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../guide.html">Transformer: A PyTorch SOTA Transformer Implementation</a></li>
 <li class="toctree-l1"><a class="reference internal" href="../../guide.html#configuration">Configuration</a></li>
 </ul>
-<p class="caption" role="heading"><span class="caption-text">API Reference</span></p>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../api.html">API Reference</a></li>
 </ul>
-<p class="caption" role="heading"><span class="caption-text">Usage Examples</span></p>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../examples.html">Usage Examples</a></li>
 </ul>
-<p class="caption" role="heading"><span class="caption-text">Project Info</span></p>
 <ul>
 <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">Contributing</a></li>
 </ul>
@@ -243,7 +238,7 @@
 </div>
 <article role="main" id="furo-main-content">
 <h1>Source code for transformer.ffn</h1><div class="highlight"><pre>
-from typing import Dict, Optional, Union
+from typing import Dict, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
```
```diff
@@ -256,10 +251,14 @@
     r"""
     SwiGLU feed-forward module
 
-    Args:
-        d_model (int): Model dimension.
-        d_ff (int): Intermediate dimension (should be even, as it's split into two halves).
-        bias (bool, optional): Whether to use bias in linear layers. Default: ``True``
+    :param d_model: Model dimension.
+    :type d_model: int
+
+    :param d_ff: Intermediate dimension (should be even, as it's split into two halves).
+    :type d_ff: int
+
+    :param bias: Whether to use bias in linear layers. Default: ``True``
+    :type bias: bool, optional
     """
 
 <div class="viewcode-block" id="SwiGLU.__init__">
@@ -278,13 +277,15 @@
         r"""
         Forward pass of SwiGLU.
 
-        Args:
-            x (torch.Tensor): Input tensor of shape :math:`(..., D)`
-            return_states (bool, optional): If True, return intermediate activations and input. Default: ``False``
+        :param x: Input tensor of shape :math:`(..., D)`
+        :type x: torch.Tensor
+
+        :param return_states: If True, return intermediate activations and input. Default: ``False``
+        :type return_states: bool, optional
 
-        Returns:
-            Union[torch.Tensor, Dict]: Output tensor :math:`(..., D)` or dict with intermediate states
-                containing the keys: "output", "y1", "y2" and "input".
+        :return: Output tensor :math:`(..., D)` or dict with intermediate states
+            containing the keys: "output", "y1", "y2" and "input".
+        :rtype: Union[torch.Tensor, Dict]
         """
         y1, y2 = self.W1(x).chunk(2, dim=-1)
         if return_states:
```
```diff
@@ -301,10 +302,14 @@
     r"""
     Classic MLP with GELU activation (as used in the original Transformer).
 
-    Args:
-        d_model (int): Model dimension.
-        d_ff (int): Intermediate dimension.
-        bias (bool, optional): Whether to use bias in linear layers. Default: ``True``
+    :param d_model: Model dimension.
+    :type d_model: int
+
+    :param d_ff: Intermediate dimension.
+    :type d_ff: int
+
+    :param bias: Whether to use bias in linear layers. Default: ``True``
+    :type bias: bool, optional
     """
 
 <div class="viewcode-block" id="MLP.__init__">
@@ -324,13 +329,15 @@
         r"""
         Forward pass of MLP.
 
-        Args:
-            x (torch.Tensor): Input tensor of shape :math:`(..., D)`
-            return_states (bool, optional): If True, return intermediate activations. Default: ``False``
+        :param x: Input tensor of shape :math:`(..., D)`
+        :type x: torch.Tensor
+
+        :param return_states: If True, return intermediate activations. Default: ``False``
+        :type return_states: bool, optional
 
-        Returns:
-            Union[torch.Tensor, Dict]: Output tensor :math:`(..., D)` or dict with intermediate states
-                containing the keys: "output", "h1", "h2" and "input".
+        :return: Output tensor :math:`(..., D)` or dict with intermediate states
+            containing the keys: "output", "h1", "h2" and "input".
+        :rtype: Union[torch.Tensor, Dict]
         """
         if return_states:
             h1 = self.net[0](x)
```
