-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathatom.xml
More file actions
487 lines (257 loc) · 289 KB
/
atom.xml
File metadata and controls
487 lines (257 loc) · 289 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
<title>不会魔法的小圆</title>
<subtitle>世界如此可爱</subtitle>
<link href="https://anti-entrophic.github.io/atom.xml" rel="self"/>
<link href="https://anti-entrophic.github.io/"/>
<updated>2026-04-20T06:36:41.406Z</updated>
<id>https://anti-entrophic.github.io/</id>
<author>
<name>不会魔法的小圆</name>
</author>
<generator uri="https://hexo.io/">Hexo</generator>
<entry>
<title>谱条件</title>
<link href="https://anti-entrophic.github.io/posts/10061.html"/>
<id>https://anti-entrophic.github.io/posts/10061.html</id>
<published>2026-04-20T06:29:20.000Z</published>
<updated>2026-04-20T06:36:41.406Z</updated>
<content type="html"><![CDATA[<h1 id="Introduction"><a href="#Introduction" class="headerlink" title="Introduction"></a>Introduction</h1><p>Greg Yang 大名鼎鼎的 Tensor Program,给出了模型如何高效 feature learning 的一个有效先验与相关理论实践。本篇 blog 介绍一个简化、或者说是更本质的版本,被概括为 Spectral Condition,谱条件。</p><h1 id="Spectral-Condition"><a href="#Spectral-Condition" class="headerlink" title="Spectral Condition"></a>Spectral Condition</h1><p>谱条件总共分为两个部分:1)如何设置初始化, 2)如何 scale 学习率</p><p>它们的目的分别是控制 $\lVert W \rVert$ 与 $\lVert \Delta W \rVert$ 为 $\Theta(1)$。控制在 $\Theta (1)$ 的先验其实很朴素,太大更新就会炸掉,太小则捕捉不到feature的信息,总之是为了保证模型进行适当的学习。</p><p>当然,这里还有很多问题,比如说范数怎么取,之类的,后面会逐步介绍。</p><h1 id="Sparse-Dense-Vector-amp-Natural-Norm"><a href="#Sparse-Dense-Vector-amp-Natural-Norm" class="headerlink" title="Sparse/Dense Vector & Natural Norm"></a>Sparse/Dense Vector & Natural Norm</h1><p>作者区分了两类向量:</p><h2 id="Sparse-vector"><a href="#Sparse-vector" class="headerlink" title="Sparse vector"></a>Sparse vector</h2><p>指的是,只有 $\Theta(1)$ 数量的分量非零,比如 one-hot 向量。它的 $\mathcal{l}_2$ 范数本身只有 $\Theta(1)$ 级别。</p><h2 id="Dense-vector"><a href="#Dense-vector" class="headerlink" title="Dense vector"></a>Dense vector</h2><p>有 $\Theta(m)$ 数量的分量非零,很多分量都在贡献长度。因此,$\mathcal{l}_2$ 范数满足</p><script type="math/tex; mode=display">\|v\|_2 \sim \sqrt{m}</script><h2 id="Natural-Norm"><a href="#Natural-Norm" class="headerlink" title="Natural Norm"></a>Natural Norm</h2><p>上述的 Dense Vector 就会有一个问题,它的 $\mathcal{l}_2$ 范数大小会随着网络宽度而增长,不能忠实地反映元素的数量级,而我们想要保持的什么 $\Delta W \sim \Theta(1)$ 这种性质都不应和网络宽度耦合在一起。</p><p>因此,作者定义了一个自然范数,用于修正这种偏差。即对于 Dense Vector 而言,它的自然 $\mathcal{l}_2$ 范数为:</p><script type="math/tex; mode=display">\|v\|_{\tilde{2}} := \frac{1}{\sqrt{m}} \|v\|_2 = \frac{1}{\sqrt{m}} \sqrt{\sum_{i=1}^m v_i^2} = \sqrt{\frac{1}{m}\sum_{i=1}^m v_i^2}</script><p>这其实就是 RMS norm,它只关心每个分量平均有多大。而对于 Sparse Vector 则不用修正</p><script type="math/tex; mode=display">\|v\|_{\tilde{2}} := \|v\|_2</script><h2 id="Natural-Spectral-Norm"><a href="#Natural-Spectral-Norm" class="headerlink" title="Natural Spectral Norm"></a>Natural Spectral Norm</h2><p>在定义了 Natural Norm 之后,我们自然可以诱导一个 Natural Spectral Norm</p><script type="math/tex; mode=display">\|A\|_{\tilde{\text{op}}} = \sup_{x \neq 0} \frac{\|Ax\|_{\tilde{2}}}{\|x\|_{\tilde{2}}}</script><p>对于输入输出均为 Dense 的算子而言,有:</p><script type="math/tex; mode=display">\|A\|_{\tilde{\text{op}}} = \frac{\sqrt{n}}{\sqrt{m}} \|A\|_{\text{op}}</script><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBQR5p1x_9zHTfZe4ksHRha2ctxA2seQAC5Q1rGzl6uEZ048wxLoHx4QEAAwIAA3cAAzsE.png" width="400px" /><p style="font-size: 10px;"> 这里原文疑似写反了......</p></center><p>而对于 embedding 这种,就是</p><script type="math/tex; mode=display">\|A\|_{\tilde{\text{op}}} = \frac{1}{\sqrt{m}} \|A\|_{\text{op}}</script><h1 id="Initialization"><a href="#Initialization" class="headerlink" title="Initialization"></a>Initialization</h1><p>我们希望矩阵的自然谱范数满足:</p><script type="math/tex; mode=display">\|W\|_{\tilde{\text{op}}} = \Theta(1)</script><h2 id="Linear-amp-lm-head"><a href="#Linear-amp-lm-head" class="headerlink" title="Linear & lm_head"></a>Linear & lm_head</h2><p>该条件等价于</p><script type="math/tex; mode=display">\|W\|_{\text{op}} = \Theta \big(\sqrt{\frac{m}{n}}\big)</script><p>对于一个正态分布初始化的 Gaussian 矩阵 $W \in \mathbb{R}^{m \times n}$,若元素标准差为 $\sigma$,通常有:</p><script type="math/tex; mode=display">\|W\|_{\text{op}} \approx \sigma(\sqrt{m} + \sqrt{n})</script><p>这个结论我们可以考虑随机矩阵中的 <a href="https://anti-entrophic.github.io/posts/10050.html#:~:text=%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0%E7%AD%89%E7%AD%89%E3%80%82-,Marchenko%2DPastur%20%E5%AE%9A%E5%BE%8B,-%E6%8F%8F%E8%BF%B0%E7%9A%84%E6%98%AF%E5%8D%8F">Marchenko-Pastur 定律</a>。对于 $W$ 这样一个均值为 $0$,方差为 $\sigma^2$ 的 i.i.d. 随机矩阵,在 $m, n \rightarrow \infty$ 且 $\gamma = \frac{n}{m}$ 收敛到一个常数的极限下,它的样本协方差矩阵</p><script type="math/tex; mode=display">S = \frac{1}{m}X^TX \in \mathbb{R}^{n \times n}</script><p>的特征值分布将依概率收敛于 Marchenko-Pastur 分布,支撑集是:</p><script type="math/tex; mode=display">\lambda_\pm = \sigma^2(1 \pm \sqrt{\gamma})^2</script><p>而我们知道奇异值的平方就是 $W^TW$ 的特征值,因此</p><script type="math/tex; mode=display">\begin{aligned}\|W\|_{\text{op}}^2 &= \lambda_{\text{max}}(W^TW) \approx m \cdot \sigma^2(1 + \sqrt{\frac{n}{m}})^2 \\\|W\|_{\text{op}} &\approx \sigma (\sqrt{m} + \sqrt{n})\end{aligned}</script><p>实在是一个非常漂亮的结论。当然还有其它解释方法,不过我觉得随机矩阵理论是一个很本质的解释。</p><p>而要让 $\sigma(\sqrt{m} + \sqrt{n}) \sim \sqrt{\frac{m}{n}}$,我们有:</p><script type="math/tex; mode=display">\sigma \sim \frac{1}{\sqrt{n}} \frac{\sqrt{m}}{\sqrt{m} + \sqrt{n}}</script><p>这里本来也不是精确的数值关系,原论文干脆搞了一个近似:</p><script type="math/tex; mode=display">\sigma \sim \frac{1}{\sqrt{n}} \min\{1, \sqrt{\frac{m}{n}}\}</script><p>大概差一个常数数值,用起来可能差不多吧,很难说。</p><p>对于 lm_head 层,我们有</p><script type="math/tex; mode=display">\sigma \sim \frac{1}{\sqrt{d}} \frac{\sqrt{V}}{\sqrt{V} + \sqrt{d}}</script><p>而我们知道 $V \gg d$,所以对于 lm_head,有 $\sigma \sim \frac{1}{\sqrt{d}}$</p><h2 id="Embedding"><a href="#Embedding" class="headerlink" title="Embedding"></a>Embedding</h2><p>对于 Embedding 层,该条件等价于</p><script type="math/tex; mode=display">\|W\|_{\text{op}} = \Theta (\sqrt{m}) \Rightarrow \sigma \sim \frac{\sqrt{d}}{\sqrt{d} + \sqrt{V}}</script><p>然而,这个推导是有问题的,原因在于对于自然谱范数</p><script type="math/tex; mode=display">\|W\|_{\tilde{\text{op}}} = \sup_{\|x\|=1} \|Wx\|_{\tilde{2}}</script><p>事实上满足这个 $\lVert x \rVert=1$ 的条件的集合,包含了大量模型永远不会遇到的方向,Embedding 层实际上只会遇到 one-hot 向量而已。所以实际上,我们只需考虑</p><script type="math/tex; mode=display">\|We_i\|_{\tilde{2}} = \frac{1}{\sqrt{d}} \|w_i\|_2</script><p>这里的 $w_i$ 其实就是 Embedding 矩阵中的一列,其中每个元素也有 $w_{ji} \sim \mathcal{N}(0, \sigma^2)$,因此有</p><script type="math/tex; mode=display">\mathbb{E}[\|w_i\|_2^2] = \sum_{j=1}^d \mathbb{E}[w_{ji}^2] = d\sigma^2</script><p>所以</p><script type="math/tex; mode=display">\|We_i\|_{\tilde{2}} \approx \sigma = \Theta(1)</script><h2 id="Code"><a href="#Code" class="headerlink" title="Code"></a>Code</h2><p>实践下来还有一个对于 $\text{c_proj}$ 层的缩放,主要是考虑到残差的叠加问题。还有 embedding 层也是,有说不要把 $\sigma$ 设为 $\Theta(1)$,而是更小一点保证稳定。Anyway,实践牵扯的东西很多,大家可以在自己的setting下试一下看看是否需要这些 trick。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> math</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">_spectral_linear_std</span>(<span class="params">fan_out: <span class="built_in">int</span>, fan_in: <span class="built_in">int</span></span>) -> <span class="built_in">float</span>:</span><br><span class="line"> <span class="string">""" This is a practical approximation from paper"""</span></span><br><span class="line"> <span class="keyword">return</span> (<span class="number">1.0</span> / math.sqrt(fan_in)) * <span class="built_in">min</span>(<span class="number">1.0</span>, math.sqrt(fan_out / fan_in))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">_spectral_linear_std_noapprox</span>(<span class="params">fan_out: <span class="built_in">int</span>, fan_in: <span class="built_in">int</span></span>) -> <span class="built_in">float</span>:</span><br><span class="line"> <span class="string">""" This is the theoretical spectral bound """</span></span><br><span class="line"> <span class="keyword">return</span> (<span class="number">1.0</span> / math.sqrt(fan_in)) * (math.sqrt(fan_out) / (math.sqrt(fan_in) + math.sqrt(fan_out)))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">_spectral_embedding_std_stable</span>(<span class="params">n_embd: <span class="built_in">int</span></span>) -> <span class="built_in">float</span>:</span><br><span class="line"> <span class="string">""" This is a practical approximation to keep stable training """</span></span><br><span class="line"> <span class="keyword">return</span> <span class="number">1.0</span> / math.sqrt(n_embd)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">_init_spectral_module</span>(<span class="params">module_name: <span class="built_in">str</span>, module: nn.Module, config</span>) -> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, nn.Linear):</span><br><span class="line"> fan_out = module.weight.size(<span class="number">0</span>)</span><br><span class="line"> fan_in = module.weight.size(<span class="number">1</span>)</span><br><span class="line"> <span class="keyword">if</span> config.spectral_linear_init_std == <span class="string">"approx"</span>:</span><br><span class="line"> std = _spectral_linear_std(fan_out, fan_in)</span><br><span class="line"> <span class="keyword">elif</span> config.spectral_linear_init_std == <span class="string">"noapprox"</span>:</span><br><span class="line"> std = _spectral_linear_std_noapprox(fan_out, fan_in)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">raise</span> ValueError(<span class="string">f"Invalid spectral linear init std: <span class="subst">{config.spectral_linear_init_std}</span>"</span>)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> module_name.endswith(<span class="string">"c_proj"</span>):</span><br><span class="line"> std = std / math.sqrt(<span class="number">2</span> * config.n_layer)</span><br><span class="line"> torch.nn.init.normal_(</span><br><span class="line"> module.weight,</span><br><span class="line"> mean=<span class="number">0.0</span>,</span><br><span class="line"> std=std,</span><br><span class="line"> )</span><br><span class="line"> <span class="keyword">if</span> module.bias <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> torch.nn.init.zeros_(module.bias)</span><br><span class="line"> <span class="keyword">elif</span> <span class="built_in">isinstance</span>(module, nn.Embedding):</span><br><span class="line"> <span class="keyword">if</span> config.spectral_embedding_init_std == <span class="string">"stable"</span>:</span><br><span class="line"> std = _spectral_embedding_std_stable(config.n_embd)</span><br><span class="line"> <span class="keyword">elif</span> config.spectral_embedding_init_std == <span class="string">"exact"</span>:</span><br><span class="line"> std = <span class="number">1.0</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">raise</span> ValueError(<span class="string">f"Invalid spectral embedding init std: <span class="subst">{config.spectral_embedding_init_std}</span>"</span>)</span><br><span class="line"> torch.nn.init.normal_(</span><br><span class="line"> module.weight,</span><br><span class="line"> mean=<span class="number">0.0</span>,</span><br><span class="line"> std=std,</span><br><span class="line"> )</span><br></pre></td></tr></table></figure><h1 id="Learning-rate-scaling"><a href="#Learning-rate-scaling" class="headerlink" title="Learning rate scaling"></a>Learning rate scaling</h1><p>TODO</p><h2 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h2><p>我自己的实验框架应用了 Spectral Condition 后确实拿到了收益;同时也踩了一些坑,比如说 mHC 引入的那些参数,不可一股脑应用 mup。自己试一遍会理解更深刻。</p><p>对于一般的线性层,做的是 feature $\rightarrow$ feature 的操作,那自然谱范数就是一个最合适的选择。我觉得一个很重要的思想是,我们应当把矩阵当成一个算子而不是一系列参数来看待,控制算子强度的 谱范数 会比 F范数 这些信息更核心。</p><p>对于像 embedding 这种层,我们需要注意其谱范数定义域的差异,即此时输入只有 $e_i$ 而非任意的 $\lVert x \rVert = 1$。不过我其实有点好奇,lmhead 应该和中间的 linear 享受同样的策略吗?因为它的输出其实是某种概率分布,而不是一个feature。控制它的谱范数似乎意义不大,只是会影响 softmax 的温度系数而已,也许这里有更好的控制策略偏置?总之,希望能在理解谱范数本质的基础上考虑去应用 mup,知道哪些 module 该套,哪些module 不该套,以及对于这些并非普通 linear 的 module 如何应用 mup 策略。</p>]]></content>
<summary type="html">确实有用</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="LLM" scheme="https://anti-entrophic.github.io/tags/LLM/"/>
</entry>
<entry>
<title>残差连接的数学视角(一):mHC</title>
<link href="https://anti-entrophic.github.io/posts/10060.html"/>
<id>https://anti-entrophic.github.io/posts/10060.html</id>
<published>2026-04-07T08:24:19.000Z</published>
<updated>2026-04-20T06:47:50.162Z</updated>
<content type="html"><![CDATA[<h1 id="Residual-Connection"><a href="#Residual-Connection" class="headerlink" title="Residual Connection"></a>Residual Connection</h1><p>残差连接我们都知道:</p><script type="math/tex; mode=display">x_{l+1} = x_l + \mathcal{F}(x_l, \mathcal{W}_l)</script><p>其中 $x_l$ 表示输入 $l$ 层的 hidden state,$W_l$ 表示第 $l$ 层的权重。</p><p>最早看到对于残差连接的理解,就是保证回传的梯度不会消失,因为:</p><script type="math/tex; mode=display">\begin{aligned}\frac{\partial x_{l+1}}{\partial x_l} &= I + \frac{\partial \mathcal{F}(x_l, \mathcal{W}_l)}{\partial x_l} \\\frac{\partial \mathcal{L}}{\partial x_l} &= \frac{\partial \mathcal{L}}{\partial x_{l+1}} \frac{\partial x_{l+1}}{\partial x_{l}} \\&= \frac{\partial \mathcal{L}}{\partial x_{l+1}} (I + \frac{\partial \mathcal{F}_l}{\partial x_l})\end{aligned}</script><p>我们可以看成,梯度也有一个类似的残差流。或者说:</p><script type="math/tex; mode=display">\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{L}}\prod_{i = l}^{L-1} (I + \frac{\partial \mathcal{F}_i}{\partial x_i})</script><p>这里的 $\prod_{i = l}^{L-1} (I + \frac{\partial \mathcal{F}_i}{\partial x_i})$ 不会像 $\prod_{i = l}^{L-1} \frac{\partial \mathcal{F}_i}{\partial x_i}$ 一样快速坍缩,因此保住了梯度。这种保梯度的观点当然没错,ResNet 本身可能也就是这个观点。</p><h1 id="Hyper-Connection"><a href="#Hyper-Connection" class="headerlink" title="Hyper Connection"></a>Hyper Connection</h1><p>在架构大炼金中,简单的 residual connection 也没有幸免,迎来了自己的升级版:</p><script type="math/tex; mode=display">x_{l+1} = \mathcal{H}_l^{\text{res}}x_l + \mathcal{H}_l^{\text{post}} \mathcal{F}(\mathcal{H}_l^{\text{pre}}x_l, \mathcal{W}_l)</script><p>这里,hidden state $x_l$ 的维度从 $C$ 被扩展为了 $n \times C$。 $\mathcal{H}_l^{\text{res}}\in \mathbb{R}^{n \times n}$ 负责混合残差流的 $x_l$ 的 $n$ 个通道;$\mathcal{H}_l^{\text{pre}}$ 和 $\mathcal{H}_l^{\text{post}}$ 则扮演了类似与 MLP 的 up/down projection 的作用。总而言之,是一种廉价的增加复杂度的方法,以及又一次缺乏数学约束的炼金。</p><p>当我们展开 HC,自然会看到它的问题:</p><script type="math/tex; mode=display">x_L = (\prod_{i=l}^{L-1} \mathcal{H}_l^{\text{res}})x_l + \sum_{i=1}^{L-1} (\prod_{j=i+1}^{L-1} \mathcal{H}_j^{\text{res}}) \mathcal{H}_i^{\text{post}} \mathcal{F}(\mathcal{H_i^{\text{pre}}x_i, \mathcal{W}_i})</script><p>(这里的连乘都默认递降,$\prod_{i=l}^{L-1} \mathcal{H}_l := \mathcal{H}_{L-1} \mathcal{H}_{L-2} \cdots \mathcal{H}_l$)</p><p>显然,梯度中的 $I$ 会变成 $\prod_{i=l}^{L-1} \mathcal{H}_l^{\text{res}}$,再没有显式结构能够阻止坍缩,会导致训练的不稳定。</p><div class="note warning flat"><p>事实上,HC 这块有很多变体,但我没有深入了解,仅当成一个整体用于抛砖引玉。下一篇文章在介绍 AttnRes 时,其中会有详细的对比。</p></div><h1 id="Manifold-Constrained-Hyper-Connections-mHC"><a href="#Manifold-Constrained-Hyper-Connections-mHC" class="headerlink" title="Manifold-Constrained Hyper-Connections (mHC)"></a>Manifold-Constrained Hyper-Connections (mHC)</h1><p>为了解决这个问题,mHC 就对 $\mathcal{H}_l^{\text{res}}$ 加了一个流形约束,希望保住残差连接本身保梯度的性质,或者我更喜欢说,保住 LLM 喜欢的那个偏置。mHC 将算子约束为双随机矩阵,满足:</p><script type="math/tex; mode=display">\mathcal{H}_l^{\text{res}} \textbf{1}_n = \textbf{1}_n, \, \textbf{1}_n^T \mathcal{H}_l^{\text{res}} = \textbf{1}_n^T, \, \mathcal{H}_l^{\text{res}} \geq 0</script><p>其构成的流形 $\mathcal{M}^{\text{res}}$ 被称为 Birkhoff 多面体。当通道数 n=1 时,这里的条件退化为 $\mathcal{H}_l^{\text{res}} = 1$, 与 $\mathcal{H}_l^{\text{res}}=I_n$ 保持一致。</p><p>这一双随机矩阵约束被视为是比 $\mathcal{H}_l^{\text{res}}=I_n$ 更好的。文中的 Section 4.1 提到的第三个性质中就道出了这一选择的数学本质:“置换矩阵的凸组合”。</p><h2 id="置换矩阵"><a href="#置换矩阵" class="headerlink" title="置换矩阵"></a>置换矩阵</h2><p>想象 mHC 中有 3 条通道,置换操作允许我们交换这 3 条通道的信息:</p><script type="math/tex; mode=display">P = \begin{pmatrix} 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix}</script><p>置换矩阵构成一个群,称为对称群 $S_n$。显然, 它是正交群 $O_n$ 的离散子群,因为 $P^TP=I$ 总成立,只不过,你没办法交换 0.5 个通道,置换群 $S_n$ 是不连续的。</p><h2 id="Birkhoff-von-Neumann-Theorem"><a href="#Birkhoff-von-Neumann-Theorem" class="headerlink" title="Birkhoff-von Neumann Theorem"></a>Birkhoff-von Neumann Theorem</h2><p>显然,我们不光需要置换矩阵,更需要一种类似通道信息混合的操作,执行非单一置换矩阵的“软交换”,就好像是我们想把第一个通道的 0.3 交换到第二个通道去,0.4 交换到第三个通道去。</p><p>根据 Birkhoff-von Neumann 定理,这种<strong>置换矩阵的凸组合,就构成了双随机矩阵;置换矩阵的凸包,就构成了双随机矩阵流形</strong>,即 Birkhoff 多面体。</p><p>从数学角度来看,双随机流形构成了一个半群,因为它不存在逆元,我们没法从通道混合后的分离出原来的输入。举个很直观的例子,对于:</p><script type="math/tex; mode=display">x=\begin{pmatrix}1\\0\end{pmatrix},\qquady=\begin{pmatrix}0\\1\end{pmatrix}</script><p>这两个输入,经过双随机矩阵</p><script type="math/tex; mode=display">A=\begin{pmatrix}\frac12&\frac12\\\frac12&\frac12\end{pmatrix}</script><p>后,</p><script type="math/tex; mode=display">Ax=Ay=\begin{pmatrix}\frac12\\\frac12\end{pmatrix}</script><p>也就是说,我们根据结果没法分辨输入是 $x$ 还是 $y$ 了。这里的双随机矩阵 $A$ 不可逆。其实哪怕可逆也不行,逆并不一定也是一个双随机矩阵。事实上,当且仅当 $A\in \mathcal{M}$ 是置换矩阵时,才有 $A^{-1} \in \mathcal{M}$。</p><p>可以这样直观理解:半群和信息的单向流动是相契合的,通道信息混合本身就是一种熵增操作。</p><h2 id="mHC-的做法"><a href="#mHC-的做法" class="headerlink" title="mHC 的做法"></a>mHC 的做法</h2><p>mHC 的 $\mathcal{H}_l^{\text{res}}$ 一开始是经过一个无约束的线性投影生成的</p><script type="math/tex; mode=display">\tilde{\mathcal{H}}_l^{\text{res}} = \alpha_l^{\text{res}} \cdot \text{mat}(\vec{x}_l' \phi_l^{\text{res}}) + b_l^{\text{res}}</script><p>有点麻烦,我直接贴论文吧:</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBP89pz71aqhCG2wnubC_n8kbx6Q_9mAAC9gtrG6FSgEamIVkUpJm2MAEAAwIAA3kAAzoE.png" width="400px" /><p style="font-size: 10px;"> 原论文 Section 4.2</p></center><p>在 mHC 中,这里采用了Sinkhorn-Knopp 算法,将生成的无约束的 $\tilde{\mathcal{H}}_l^{\text{res}}$ 再投影为一个双随机矩阵。这是一个迭代算法,mHC 采用了 20 步。尽管进行了kernel优化,但总体而言,mHC 引入了 6.7% 的时间延迟。</p><h2 id="Semi-group-的困境"><a href="#Semi-group-的困境" class="headerlink" title="Semi group 的困境"></a>Semi group 的困境</h2><p>正如我们在<a href="https://anti-entrophic.github.io/posts/10059.html">从李群的视角看 ROPE 旋转位置编码</a>那篇blog中提到的,正交群的旋转矩阵 $R(t)$ 是由 反对称矩阵 $X$ 通过矩阵指数生成的:$R = e^{tX}$。不禁要问,我们能找到双随机矩阵流形的生成元 $Q: \mathcal{H}=e^{tQ}$ 吗?</p><p>答案是有的,而且这个 $Q$ 意外的简单,是 Markov 转移速率矩阵,满足:</p><script type="math/tex; mode=display">Q\textbf{1} = \textbf{1}Q = 0</script><p>对 $ \mathcal{H}\textbf{1} = \textbf{1}, \textbf{1}^T\mathcal{H}=\textbf{1}^T$ 代入后求导即可</p><p>既然我们已经知道了双随机矩阵 $\mathcal{H}$ 的生成元是 $Q$,我们直接让模型不要输出 $\mathcal{H}$ 而是生成 $Q$ 不就可以了吗?再通过矩阵指数得到 $\mathcal{H}$ 即可:</p><script type="math/tex; mode=display">\mathcal{H} = e^{tQ} = I + tQ + \frac{1}{2!}(tQ)^2 + \frac{1}{3!}(tQ)^3 + \cdots</script><p>然而,这里的 $\mathcal{H}$ 不是群而是半群,整个 $\mathcal{M}$ 是一个凸包而不是一个光滑流形,因此我们没有办法用生成元法生成所有双随机矩阵,事实上有相当一部分 $\mathcal{H}$ 没有办法生成,是很难受的一个点。</p><h1 id="mHC-的变体"><a href="#mHC-的变体" class="headerlink" title="mHC 的变体"></a>mHC 的变体</h1><p>目前已经能找到一些 mHC 的变体,太卷了。</p><h2 id="mHC-lite"><a href="#mHC-lite" class="headerlink" title="mHC-lite"></a>mHC-lite</h2><p>ArXiv: <a href="https://arxiv.org/abs/2601.05732">mHC-lite: You Don’t Need 20 Sinkhorn-Knopp Iterations</a></p><p>它的想法是,与其把一个随机矩阵投到 Birkhoff 凸包内,不如我们将这个凸包的所有顶点作为一个基底,也就是所有的 $n!$ 种置换矩阵,然后直接学一个坐标 $a_k$,使得</p><script type="math/tex; mode=display">\mathcal{H}^{\text{res}} = \sum_{i=1}^{n!}a_i P_i,\quad a_i \geq 0, \quad \sum_i a_i = 1</script><p>最后这一步用一个 softmax 来解决。</p><p>这个问题在我看来有以下两个缺陷:</p><ul><li><p>基底的数量是 $n!$。尽管在 mHC 的 setting 下,$n=4$ 是可以接受的,但是会无法 scaling 更大的 $n$</p></li><li><p>过参数化。一个 $n \times n$ 的双随机矩阵,它的真实自由度其实只有 $(n-1)^2$,因为确定了前 $n-1$ 行和列,最后一行和一列就自动确定了。这里的问题是,mHC-lite 要求输出 $n!$ 个坐标参数,但实际的自由度只有 $(n-1)^2$。还是同样的问题,当 $n=4$ 时差距不严重,只是 $(4-1)^2=9$ 和 $4!=24$。但如果 $n=8$ 则有 $(8-1)^2 = 49$ 与 $8!=40320$ 的巨大差距。参数完全冗余。</p></li></ul><h2 id="KromHC"><a href="#KromHC" class="headerlink" title="KromHC"></a>KromHC</h2><p>ArXiv: <a href="https://arxiv.org/pdf/2601.21579">KromHC: Manifold-Constrained Hyper-Connections with Kronecker-Product Residual Matrices</a></p><p>一篇有些偷换概念的文章。它们选了一个带 Kronecker 乘法结构的子半群作为目标,认为可以通过学习一组小双随机矩阵的参数,再 kronecker 积乘起来得到一个大的双随机矩阵。</p><p>说的更具体一点,以 $n=8$ 为例,由于 $8 = 2 \times 2 \times 2$,就可以构造 3 个小的双随机矩阵 $U_i$,每个都是 $S_2$ 的凸组合,最后令 $\mathcal{H}^{\text{res}} = U_3 \otimes U_2 \otimes U_1$</p><p>显然,这里只能得到一个 $\mathcal{M}$ 的子集,并非所有的双随机矩阵都有这样的 Kronecker 积乘法结构。它的自由度只有 $3 \times (2-1)^2 = 3 \ll (8-1)^2 = 49$。至于这种结构偏置 LLM 是否接受,就是另话了。直觉上来看,这种约束似乎过强了。</p><h2 id="sHC"><a href="#sHC" class="headerlink" title="sHC"></a>sHC</h2><p>ArXiv: <a href="https://arxiv.org/pdf/2603.20896">Beyond the Birkhoff Polytope: Spectral-Sphere-Constrained Hyper-Connections</a></p><p>好文章。</p><h3 id="流程"><a href="#流程" class="headerlink" title="流程"></a>流程</h3><p>原来 mHC 的双随机矩阵约束满足:</p><script type="math/tex; mode=display">\mathcal{H}_l^{\text{res}} \textbf{1}_n = \textbf{1}_n, \, \textbf{1}_n^T \mathcal{H}_l^{\text{res}} = \textbf{1}_n^T, \, \mathcal{H}_l^{\text{res}} \geq 0</script><p>而 sHC 抛弃了最后的非负约束,替换为:</p><script type="math/tex; mode=display">\mathcal{H}_l^{\text{res}} \textbf{1}_n = \textbf{1}_n, \, \textbf{1}_n^T \mathcal{H}_l^{\text{res}} = \textbf{1}_n^T, \, \|\mathcal{H}_l^{\text{res}}\|_{\text{op}} = 1</script><p>这样的话,由于矩阵乘积满足</p><script type="math/tex; mode=display">\|\mathcal{H}_1\mathcal{H}_2\|_{\text{op}} \leq \|\mathcal{H}_1\|_{\text{op}}\|\mathcal{H}_2\|_{\text{op}} =1</script><p>并且,又因为 $\textbf{1}$ 是特征向量,所以 $\lVert \mathcal{H}_1\mathcal{H}_2\rVert_{\text{op}} \geq 1$,所以有:</p><script type="math/tex; mode=display">\|\mathcal{H}_1\mathcal{H}_2\|_{\text{op}} = 1</script><p>所以这个约束矩阵集合对乘法封闭。它实际上比 mHC 范围更大。</p><p>具体的求法比较复杂。作者定义</p><script type="math/tex; mode=display">J = \frac{1}{n} \textbf{1} \textbf{1}^T, \quad \mathbb{Z}_n = \{\mathcal{H} \in \mathbb{R}^{n \times n}: \mathcal{H}\textbf{1} = 0, \textbf{1}^T \mathcal{H} = 0\}</script><p>则:</p><script type="math/tex; mode=display">\mathcal{H}^{\text{res}} = J + \mathcal{H}^{\text{disp}}</script><p>其中 $\mathcal{H}^{\text{disp}} \in \mathbb{Z}_n$</p><p>接下来的关键一步是,由于 $\lVert \mathcal{H}^{\text{res}} \rVert_{\text{op}} = \max(\lVert J \rVert_{\text{op}}, \lVert \mathcal{H}^{\text{disp}}\rVert_{\text{op}})$,而 $\lVert J\rVert_{\text{op}}=1$,所以:</p><script type="math/tex; mode=display">\|\mathcal{H}^{\text{res}}\|_{\text{op}} = 1 \Longleftrightarrow \|\mathcal{H}^{\text{disp}}\|_{\text{op}} \leq 1</script><p>把约束转换到了 $\mathbb{Z}$ 空间上,太天才了。</p><p>然后对 $\mathcal{H}^{\text{disp}}$ 做 SVD,令:</p><script type="math/tex; mode=display">\mathcal{H}^{\text{disp}} = U\Sigma V^T</script><p>$\mathcal{H}^{\text{disp}}$ 的行空间和列空间都必须落在 $\textbf{1}^\perp$ 这个 $(n-1)$-维子空间里。于是作者固定了一个 $\textbf{1}^\perp$ 的正交基 $U_Z \in \mathbb{R}^{n \times (n-1)}$,写成:</p><script type="math/tex; mode=display">U = U_ZU_{\text{core}}, \quad V = U_Z V_{\text{core}}</script><p>其中</p><script type="math/tex; mode=display">U_{\text{core}}, V_{\text{core}} \in \mathbb{R}^{(n-1) \times (n-1)}</script><p>是正交矩阵。最终:</p><script type="math/tex; mode=display">\mathcal{H}^{\text{disp}} = (U_ZU_{\text{core}})\Sigma( U_Z V_{\text{core}})^T</script><p>只要 $\Sigma$ 的奇异值都落在 $[-1, 1]$ 内,就有:</p><script type="math/tex; mode=display">\|\mathcal{H}^{\text{disp}}\|_{\text{op}} \leq 1</script><h3 id="数学本质"><a href="#数学本质" class="headerlink" title="数学本质"></a>数学本质</h3><p>我们可以将 $J$ 视作 $J = uu^T, u:= \frac{1}{\sqrt{n}} \textbf{1}$,因为:</p><script type="math/tex; mode=display">\mathcal{H} \textbf{1} = \textbf{1} \Longleftrightarrow \mathcal{H} u = u</script><p>所以 $\text{span}\{u\}$ 是一个 $\mathcal{H}$ 的不变子空间。又因为:</p><script type="math/tex; mode=display">\textbf{1}^T \mathcal{H} = \textbf{1}^T \Longleftrightarrow \mathcal{H}^T u = u</script><p>任取 $x \in u^\perp$,即 $u^Tx = 0$,则:</p><script type="math/tex; mode=display">u^T(\mathcal{H}x) = (\mathcal{H}^Tu)^Tx = u^Tx = 0</script><p>所以 $\mathcal{H}x \in u^\perp$,$u^\perp$ 也是 $\mathcal{H}$ 的不变子空间,所以我们有</p><script type="math/tex; mode=display">\mathcal{H} \sim \begin{pmatrix}1 & 0 \\0 & M\end{pmatrix}</script><p>比起 mHC 的通道混合,sHC 更像是把空间分解成了 $\mathbb{R}^n = \text{span}\{\textbf{1}\} \oplus \textbf{1}^\perp$,前半部分做恒等映射,后半部分做谱约束的变换。</p><p>sHC 的全部自由度,都在后半部分 $(n-1) \times (n-1)$ 的谱范数约束的单位球内。</p><p>在这种视角下,$\mathcal{H}$ 对乘法封闭是很自然的,因为</p><script type="math/tex; mode=display">\mathcal{H}_1\mathcal{H}_2 \sim \begin{pmatrix}1 & 0 \\0 & M_1M_2\end{pmatrix}</script><p>而 $\lVert M_1M_2\rVert \leq \lVert M_1\rVert \lVert M_2\rVert \leq 1$,所以 $\mathcal{H}_1\mathcal{H}_2$ 自然仍满足约束。</p><p>考虑 sHC 的约束实际上等价于考虑这个 $M$。同样是 $(n-1)^2$ 的自由度,这当然不是巧合,我觉得 sHC 找到的数学结构其实已经很本质了。</p><h3 id="半群与最大子群"><a href="#半群与最大子群" class="headerlink" title="半群与最大子群"></a>半群与最大子群</h3><p>事实上,更准确地说,我们可以将任意 $M$ 视作极分解:</p><script type="math/tex; mode=display">M = QP</script><p>其中 $Q = UV^T \in O(n-1)$ 是正交部分;而</p><script type="math/tex; mode=display">P = V\Sigma V^T</script><p>是一个对称半正定矩阵,并且由于 $\lVert M\rVert_{\text{op}} \le 1$,它满足</p><script type="math/tex; mode=display">0 \preceq P \preceq I</script><p>这里,$Q$ 可以理解为纯旋转的可逆部分,而 $P$ 则代表了不可逆的收缩部分。也就是说,sHC 在 $\textbf{1}^\perp$ 上允许的并不是任意线性变换,而是一个<strong>正交变换之后再接一个收缩算子</strong>。从这个角度看,sHC 的约束实际上把残差流分成了两部分:</p><ul><li>一个可逆、守范数的“群”部分;</li><li>一个不可逆、带耗散的“半群”部分。</li></ul><p>这和 mHC 的情况很不一样。sHC 其实已经把问题从“如何在 Birkhoff 多面体里找一个好点”,转化成了“如何在一个收缩半群里找一个好算子”。它不再依赖 Birkhoff 多面体这种有棱有角的凸几何,而是等价于在 $(n-1)$-维空间上考虑</p><script type="math/tex; mode=display">\{M\in \mathbb{R}^{(n-1)\times(n-1)}: \|M\|_{\text{op}}\le 1\}</script><p>这个算子范数球。<br>当然,半群终究不是群。它虽然有一个很自然的最大子群,如果我们进一步要求 $M$ 可逆且 $M^{-1}$ 仍在球内,那么由于</p><script type="math/tex; mode=display">\|M\|_{\text{op}} \leq 1, \quad \|M^{-1}\|_{\text{op}} \leq 1</script><p>$M$ 的所有奇异值都只能等于 1,所以 $M$ 必须是一个正交矩阵。相当于我们在找:</p><script type="math/tex; mode=display">\mathcal{H} \sim\left\{\begin{pmatrix}1 & 0\\0 & Q\end{pmatrix}: Q\in O(n-1)\right\},</script><p>但大部分元素仍然是不可逆的,残差流混合天然是半群而非群这个麻烦仍然存在。不过,我觉得 sHC 能告诉我们,约束未必非得来自双随机矩阵,也可以来自更本质的谱稳定结构,就已经足够好了。我很喜欢。</p><h1 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h1><p>回顾这些 mHC 变体,它们其实都是在考虑该选择什么样的代数结构作为 Residual Stream 中多条通道之间的动态混合的约束。如果我们按照可行解空间从大到小排序,那会是 sHC > mHC/mHC-lite > KromHC。</p><p>我还是一贯的观点,这些工作其实很难说谁好谁坏,主要还是看 LLM 更喜欢哪种代数偏置,而 LLM 的偏好目前来说似乎是缺乏第一性的探测方法的。因此,对于各种方法本质代数结构的捕捉,并进行分类整理,我觉得是一件很重要的事情,是我们为数不多总结 LLM 偏置的方法之一。</p><p>下一篇可能想介绍一下 Kimi 的 AttnRes。和 mHC 这种在单层内横向执行混合的方法不同,AttnRes 是在深度上纵向进行混合,两者是相对比较正交的,但同样是很精彩的工作。</p><h1 id="How-to-site"><a href="#How-to-site" class="headerlink" title="How to site"></a>How to site</h1><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">@misc{zhang2026residual,</span><br><span class="line"> author={Yechen Zhang},</span><br><span class="line"> title={残差连接的数学视角(一):mHC},</span><br><span class="line"> year={2026},</span><br><span class="line"> month={April},</span><br><span class="line"> url={\url{https://anti-entrophic.github.io/posts/10060.html}}</span><br><span class="line">}</span><br></pre></td></tr></table></figure>]]></content>
<summary type="html">没几个月就已经卷完了 ┐(´д`)┌</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/tags/Model-Structure/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="LLM" scheme="https://anti-entrophic.github.io/tags/LLM/"/>
</entry>
<entry>
<title>从李群的视角看 ROPE 旋转位置编码</title>
<link href="https://anti-entrophic.github.io/posts/10059.html"/>
<id>https://anti-entrophic.github.io/posts/10059.html</id>
<published>2026-03-25T12:38:31.000Z</published>
<updated>2026-04-10T13:09:55.381Z</updated>
<content type="html"><![CDATA[<p>最近比较闲,学了点数学,李群李代数什么的,然后一下就联想到了ROPE,所以这次系统地从数学视角回顾了一下。我也是边写边查漏补缺,所以我尽量写得容易阅读并且适合学习,不仅仅是当成自己的笔记,可以带点科普性质,希望不那么熟悉相关内容的同学也能流畅阅读。</p><h1 id="群"><a href="#群" class="headerlink" title="群"></a>群</h1><p>学过群论的同学都知道,我们可以用群来描述旋转。我们以三角形的旋转群 $C3$ 作为例子,它包含了旋转 0°、旋转 120° 和 旋转 240° 这三个操作。群需要满足四个性质:</p><ol><li><p>封闭性:先旋转 120°,再旋转 120°,最后等于旋转 240°,这个操作仍然是原操作集合的一员。</p></li><li><p>单位元:需要有一个无用的操作(旋转 0°)</p></li><li><p>逆元:任何操作都有逆操作,比如旋转 120°的逆操作是旋转 240°,相当于反向转 120°</p></li><li><p>结合律:$(A\cdot B)\cdot C = A \cdot (B \cdot C)$</p></li></ol><p>但是,这是一种离散的有限群。如果我们考虑一个圆的旋转群呢?我们可以旋转任意小的角度。这种“连续”的群就称为李群</p><h1 id="李群"><a href="#李群" class="headerlink" title="李群"></a>李群</h1><p>从定义上看,李群是一个集合,它既是一个群、又是一个光滑流形。比如说二维平面上的所有旋转矩阵 $\begin{pmatrix}\cos\theta & -\sin\theta \\ \sin\theta & \cos \theta\end{pmatrix}$ 就构成一个李群 $SO(2)$</p><p>到这里大家肯定知道,为什么我会说从李群的视角看 ROPE 了。</p><h2 id="李代数"><a href="#李代数" class="headerlink" title="李代数"></a>李代数</h2><p>为了研究这个连续的李群,我们选择研究它在单位元附近的微元。毕竟只需要掌握了旋转 0.00…01°,多次应用就可以得到一个更大的旋转。我们将李群 $G$ 在单位元 $e$ 处的切空间 $T_eG$ 称为李代数 $\mathfrak{g}$。李代数就像是李群的微分版本,我们以旋转群 $SO(N)$ 为例来说明。</p><p>任取一个李群元素 $R(t) \in SO(N)$,我们现在考虑李代数的元素 $X$。如果说李群元素 $R(t)$ 是旋转角度,那么李代数元素 $X$ 就可以理解为角速度。有方程:</p><script type="math/tex; mode=display">\dot{R}(t) = R(t) \cdot X</script><p>其中 $X$ 代表的是相对于单位元的变化率,对于 $R(t)$ 而言,它的变化率需要把 $X$ 迁移到自己的坐标系上。</p><p>解微分方程,我们得到 $R(t) = e^{tX}$。</p><p>所以,对于 $SO$ 这个李群,它的李代数 $\mathfrak{so}$ 中的元素的矩阵指数,就是李群中的元素。有一种积分的感觉。我们可以对 $R(t)^TR(t) = I$ 关于时间 $t$ 求导:</p><script type="math/tex; mode=display">\begin{aligned}\dot{R}(t)^TR(t) + R(t)^T\dot{R}(t) = \dot{I} = 0\end{aligned}</script><p>代入 $t=0$,又因为我们知道 $R(0) = I$,$X = \dot{R}(0)$,所以:</p><script type="math/tex; mode=display">\begin{aligned}\dot{R}(0)^T \cdot I + I^T \cdot \dot{R}(0) &= 0 \\X^T + X &= 0 \\X^T &= -X\end{aligned}</script><p>所以,旋转群 $SO(N)$ 的正交性可以导出它的李代数 $\mathfrak{so}(N)$ 的元素必须是反对称的。</p><h2 id="李括号"><a href="#李括号" class="headerlink" title="李括号"></a>李括号</h2><p>李代数是一个向量空间,基础配置只有向量元素间的加法。但对于群而言,加法不是自然的,因为加法总满足交换律( $+1+2$ 与 $+2 +1$ 相同),群上的操作更像是不可交换的乘法,例如先翻转再旋转,和先旋转再翻转,会得到不同的结果——这就相当于是先乘翻转算子 $X$ 还是先乘旋转算子 $Y$ 的区别。因此,我们希望给李代数也安装上乘法。不过挪威数学家<a href="https://en.wikipedia.org/wiki/Sophus_Lie">索菲斯 · 李</a>告诉我们,我们只需要衡量乘法操作不可交换的程度即可,称为李括号:</p><script type="math/tex; mode=display">[X, Y] = XY - YX</script><p>怎么理解这个李括号呢?按照我的理解,它其实只是群操作在微观的切空间的投影。</p><p>设我们有一个群 $G$,并选定一个群元素 $g \in G$。我们定义一个映射 $\Psi_g: G \rightarrow G$,它的作用规则是:</p><script type="math/tex; mode=display">\Psi_g(x) = gxg^{-1}</script><p>它被称为共轭作用。该映射固定了单位元,$\Psi_g(e) = geg^{-1} = e$,我们可以研究单位元附近加上微扰的变化,也就是单位元 $e$ 处的切空间,即李代数 $\mathfrak{g}$。</p><p>不妨将这个微元当作一条经过单位元 $e$ 的曲线 $x(t)$,其中 $x(0) = e$,它的导数速度向量是 $Y = \dot{x}(0) \in \mathfrak{g}$</p><p>考虑它的共轭作用</p><script type="math/tex; mode=display">\Psi_g(x(t)) = g \cdot x(t) \cdot g^{-1}</script><p>对 $t$ 求导并令 $t=0$:</p><script type="math/tex; mode=display">\frac{d}{dt} \big|_{t=0} (g \cdot x(t) \cdot g^{-1}) = g \cdot \dot{x}(0) \cdot g^{-1} = gYg^{-1}</script><p>这个新向量 $gYg^{-1}$ 也在李代数 $\mathfrak{g}$ 中,因为李代数对共轭封闭。</p><p>我们可以将这个映射视为一个线性算子,实际上它被称为伴随表示:</p><script type="math/tex; mode=display">\text{Ad}_g(Y) = gYg^{-1}</script><p>伴随表示 $\text{Ad}_g$ 描述了群元素 $g$ 如何旋转或拉伸整个李代数空间。</p><p>由于李群是光滑的,我们可以进一步求导,考虑 $g$ 的微扰如何影响伴随表示算子。同样,我们视 $g$ 为随时间 $s$ 变化的曲线 $g(s)$,且 $g(0) = e$,速度向量为 $X= \dot{g}(0) \in \mathfrak{g}$</p><p>对 $\text{Ad}_{g(s)}(Y) = g(s)Yg(s)^{-1}$ 关于 $s$ 求导,并令 $s=0$</p><script type="math/tex; mode=display">\begin{aligned}\frac{d}{ds}\big|_{s=0}(g(s)Yg(s)^{-1}) &= \dot{g}(0)Yg(0)^{-1} + g(0)Y(\frac{d}{ds}g(s)^{-1})\big|_{s=0} \\&= X\cdot Y \cdot I + I \cdot Y \cdot (-g(0)^{-1}\dot{g}(0)g(0)^{-1}) \\&=XY-YX \\&= [X, Y]\end{aligned}</script><p>所以,李括号正是共轭作用的微分的微分。</p><p>我可以举几个实际的例子来帮助理解。以 $SO(2)$ 为例,若 $g$ 是李群的元素,代表旋转 $\theta$ 度;$Y$ 是李代数的元素,代表了角速度; 那么 $\text{Ad}_g(Y)$ 就是这个角速度经过了 $g$ 的变换,是从 $\theta$ 开始的角速度。$\text{ad}_XY$ 就是当 $\theta$ 无穷小时,这个角速度的变化率。</p><p>你会发现在 $SO(2)$ 中,好像 $\text{Ad}_g(Y)$ 这步是无意义的,不管从哪个 $\theta$ 开始,角速度都一样。因为 $SO(2)$ 是一个交换群(阿贝尔群),我们先转 $30°$ 再转 $60°$ 与先转 $60°$ 再转 $30°$ 是一样的。对于交换群而言,共轭作用等于恒等映射:</p><script type="math/tex; mode=display">\Psi_g(h) = h</script><p>从而,$\text{Ad}_g$ 变成了恒等算子;$\text{ad}_X(Y) = 0$,李括号为 0,交换操作无影响。</p><p>但在 $SO(3)$ 中就不一样了。我们假设 $g$ 是绕 Z 轴旋转 $\theta$, $Y$ 是绕 $X$ 轴的角速度,那么 $\text{Ad}_g(Y)$ 就会把这个绕 $X$ 轴的角速度绕 $Z$ 轴旋转 $\theta$。如果 $\theta=90°$,那么就变成了绕 $Y$ 轴的角速度,角速度矢量旋转了 $90°$。$\text{ad}_X(Y)$ 衡量的就是,当 $\theta$ 无限小时,这个角速度矢量的变化,也可以称为角加速度。</p><p>综上,李括号可以看作是群性质在微分切空间上的投影。例如,李括号还要求具有反对称性,即 $[X, Y] = -[Y, X]$,这个也很好理解,就来源于群乘法的不可交换性嘛。自然如果李括号恒为零,那就说明是个交换群。</p><p>李括号还要求双线性。这是因为李括号所操作的李代数的元素本来就在切空间上,作为一个向量空间自然有线性,李括号的定义 $[X, Y] = XY - YX$ 自然继承这种线性结构。</p><p>最后,回顾群的性质,还有一条结合律。群的结合律能够导出,共轭作用 $\Psi_g$ 是群的自同构,即:</p><script type="math/tex; mode=display">\Psi_g(xy) = g(xy)g^{-1} = (gxg^{-1})(gyg^{-1}) = \Psi_g(x)\Psi_g(y)</script><p>并且,它的微分 $\text{Ad}_g$ 也是李代数的自同构:</p><script type="math/tex; mode=display">\begin{aligned}\text{Ad}_g([Y, Z]) &= g[Y, Z]g^{-1} = g(YZ -ZY)g^{-1} \\&= g(YZ)g^{-1} - g(ZY)g^{-1} \\&= gYg^{-1}gZg^{-1} - gZg^{-1}gYg^{-1} \\&= \text{Ad}_g(Y)\text{Ad}_g(Z) - \text{Ad}_g(Z)\text{Ad}_g(Y) \\&= [\text{Ad}_g(Y), \text{Ad}_g(Z)]\end{aligned}</script><p>我们对上式关于 $g$ 求导,其速度向量为 $X$,则左式为:</p><script type="math/tex; mode=display">\frac{d}{dt}\text{Ad}_{g}([Y, Z]) = [X, [Y, Z]]</script><p>右式为:</p><script type="math/tex; mode=display">\begin{aligned}\frac{d}{dt}[\text{Ad}_g(Y), \text{Ad}_g(Z)] &= [\frac{d}{dt}\text{Ad}_g(Y), \text{Ad}_g(Z)] + [\text{Ad}_g(Y), \frac{d}{dt} \text{Ad}_g(Z)] \\&= [[X, Y], Z] + [Y, [X, Z]]\end{aligned}</script><p>即</p><script type="math/tex; mode=display">\begin{aligned}& [X, [Y, Z]] = [[X, Y], Z] + [Y, [X, Z]] \\\Longleftrightarrow &[X, [Y, Z]] + [Y, [Z, X]] + [Z, [X, Y]] = 0\end{aligned}</script><p>上式被称为雅可比恒等式。</p><p>总结一下,李括号的双线性、反对称性、雅可比恒等式,都是群、光滑流形的属性在微观尺度下的投影罢了。</p><h1 id="ROPE"><a href="#ROPE" class="headerlink" title="ROPE"></a>ROPE</h1><p>ROPE 旋转位置编码是现在大部分大模型位置编码的选型。它用一种绝对编码的方式为隐状态赋予了相对位置编码的效果。我们以一个隐状态维度 $d=2$ 为例,对于一个在位置 $m$ 的 token,其 $q, k$ 向量会被旋转 $m\theta$ 角度:</p><script type="math/tex; mode=display">\begin{pmatrix} q_0' \\ q_1' \end{pmatrix} =\begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}\begin{pmatrix} q_0 \\ q_1 \end{pmatrix}</script><p>这其实就是复数乘法 $q’ = q \cdot e^{i m\theta}$。由此,当我们计算 Attention Score 时,有:</p><script type="math/tex; mode=display">\begin{aligned}q^T_mk_n &= (R_mq)^T(R_nk) = q^TR_m^TR_nk \\&= q^TR_{-m}R_nk \\&= q^TR_{n-m}k\end{aligned}</script><p>也就是说,我们施加的绝对位置编码,在 QK 交互时的影响却仅限于相对位置关系,自动获得了一种平移不变性。</p><p>推广到 $d$ 维,RoPE 选择的是一种 两维 为一组的构造方式,形成了了一个巨大的块对角旋转矩阵,我们记为花体的 $\mathcal{R}_m \in SO(d)$:</p><script type="math/tex; mode=display">\mathcal{R}_m = \begin{pmatrix}\cos m\theta_1 & -\sin m\theta_1 & 0 & 0 &\\\sin m\theta_1 & \cos m\theta_1 & 0 & 0 &\\0 & 0 & \cos m\theta_2 & -\sin m\theta_2 &\\0 & 0 & \sin m\theta_2 & \cos m\theta_2 &\\&&&&\ddots\end{pmatrix}</script><p>这里就很值得说道说道了,这其实是一个非常强的归纳偏置,主要包含两个部分:一个是将旋转限定在了李群 $SO(N)$ 的最大环面上;另一个是限制了每个子群的旋转角度。</p><h2 id="极大环面"><a href="#极大环面" class="headerlink" title="极大环面"></a>极大环面</h2><p>第一点,我们可能会怀疑,RoPE 这样二维为一组地组织旋转,是不是引入了损耗?只限定在一些子平面上?</p><p>极大环面指的是,包含在李群 $G$ 中的、最大的连通的阿贝尔(交换)子群。例如 $SO(3)$ 的最大环面是 $SO(2)$,$SO(4)$ 的最大环面是 $SO(2) \times SO(2)$</p><p>极大环面定理告诉我们,李群 $G$ 中的任意一个元素 $g$,都共轭于极大环面 $T$ 中的某个元素 $t$</p><script type="math/tex; mode=display">g = h\cdot t \cdot h^{-1}, \quad \text{where} \quad t\in T, h\in G</script><p>也就是说,不管一个高维旋转矩阵看起来多么复杂、多么纠缠,只要我们做一次基变换,它本质上都只是在若干个独立平面上的简单旋转。</p><p>用我们更熟悉的矩阵语言来说,就是对于任意一个实反对称矩阵 $X \in \mathfrak{so}(d)$,一定存在一个正交基变换矩阵 $U\in O(d)$,使得 $U^TXU$ 变成 $2\times 2$ 的块对角矩阵:</p><script type="math/tex; mode=display">\Lambda = \begin{pmatrix}0 & -\lambda_1 & & & \\\lambda_1 & 0 & & & \\& & 0 & -\lambda_2 & \\& & \lambda_2 & 0 & \\& & & & \ddots\end{pmatrix}</script><p>其中 $\pm i \lambda_k$ 是 $X$ 的特征值</p><p>我们不妨假设先取 $SO(d)$ 上的两个旋转矩阵 $R_m, R_n \in SO(d)$,根据极大环面定理,我们总有:</p><script type="math/tex; mode=display">R_m = \text{exp}(mX) = \text{exp}(mU\Lambda U^T) = U\text{exp}(m\Lambda) U^T = U \mathcal{R}_m U^T</script><p>我们把这个全空间旋转代入 Attention 中求相关性分数的公式:</p><script type="math/tex; mode=display">\begin{aligned}\text{Score}(i, j) &= (R_iW_qx_i)^T(R_j W_k x_j) \\ &= x_i^T W_q^T R_i^T R_j W_k x_j \\&= x_i^T W_q^T (U \mathcal{R}_{-i}U^T)(U \mathcal{R}_j U^T) W_k x_j \\&= x_i^T (W_q^T U) \mathcal{R}_{j-i} (U^T W_k) x_j\end{aligned}</script><p>显然,我们只要令 $\tilde{W}_q = U^TW_q, \tilde{W}_k = U^TW_k$ 即可。原本 $W_q$ 和 $W_k$ 就没有任何约束,基变换早就隐式地融入到 $\tilde{W}_q$ 与 $\tilde{W}_k$ 中了。</p><h3 id="平移不变性"><a href="#平移不变性" class="headerlink" title="平移不变性"></a>平移不变性</h3><p>RoPE 有一个强调的好处是,它刻画了相对位置关系,具有平移不变性。</p><p>从群的角度来看,这代表着对应的旋转操作是可交换的。我们在全空间 $SO(d)$ 里取两个生成元 $X$ 和 $Y$,当且仅当 $[X, Y]=0$ 时,才有 $\text{exp}(X)\text{exp}(Y) = \text{exp}(X+Y)$</p><p>我需要强调,RoPE 的平移不变性并非来源于它将旋转操作限制到了 $SO(d)$ 的极大环面上,从而利用了 $SO(2)$ 的可交换性。实际上,对于 1D 的文本序列,$R_m$ 和 $R_n$ 只涉及一个生成元 $X \in \mathfrak{so}(d)$, 旋转矩阵分别是 $R_m = \text{exp}(mX)$ 和 $R_n = \text{exp}(nX)$。而 $[mX, nX] = 0$ 恒成立,因此总有:</p><script type="math/tex; mode=display">R_m^TR_n = \text{exp}(-mX)\text{exp}(nX) = \text{exp}((n-m)X) = R_{n-m}</script><p>它天然就具有完美的平移不变性!这种平移不变性来源于李群本身,而不是来源于 RoPE 的二维块对角结构。</p><h3 id="极大环面的优势"><a href="#极大环面的优势" class="headerlink" title="极大环面的优势"></a>极大环面的优势</h3><p>所以 RoPE 强制使用 $\mathcal{R}_m$ 而不是 $R_m$,在我看来实际上并没有限制表达能力,反而是帮模型省去了显式计算 $U$ 的过程,将这部分工作转嫁给了本就需要学习的 $W_q$ 和 $W_k$,把原来 $SO(d)$ 的 $\frac{d(d-1)}{2}$ 个自由度锁定为了对角线上的 $\frac{d}{2}$ 个自由度(甚至这 $\frac{d}{2}$ 个频率也是不学的,相当于给模型加了一个极其强大的结构化先验,迫使模型去适应这个空间)。不知道苏神当时有没有思考这些问题,还是只是靠个人的数学直觉觉得这种做法不错。</p><p>不过,我不确定 2D 序列如图像、视频的位置编码是怎么实现的,如果位置是 $(x, y, t)$ 的话,那它们采取的生成元是不同的?这个时候可能就是不可交换的,需要用到 RoPE 的极大环面偏置。我需要去请教一下专家,再完成剩下的部分。</p><h2 id="锁定的特征频率"><a href="#锁定的特征频率" class="headerlink" title="锁定的特征频率"></a>锁定的特征频率</h2><p>因此,RoPE 的归纳偏置的主要争议点,我觉得就在于它强行限定了各个子群的旋转频率上,实际上也就是强行设定了整个空间的特征值。其实这很难说是好事还是坏事,正因为我们强行编码了一组频率,我们才能通过调整特定频率来实现长文外推等功能,否则如果我们学的是一个 $SO(d)$ 的旋转矩阵,我们想要 scaling 的话根本无法下手。</p><h1 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h1><p>我们从代数结构的角度分析了一下 RoPE,理解了 RoPE 引入的关键偏置其实是固定的旋转频率。对我而言,这个认识就像是一个照妖镜,能够看出很多文章的修改究竟是否本质,还是改动实际上并没有增加表达力。我之后会再写一篇文章,基于本篇文章的认识,系统地讨论一些 RoPE 的变体甚至是所有位置编码的统一视角。包括但不限于 FoPE、GRAPE 和 PaTH,特别是今天(4.10)和 xr 哥聊,被介绍了 GRAPE。其实,感兴趣的同学可以直接把 这篇 blog 和 感兴趣的文章 同时喂给 AI,让它锐评一下就可以了。</p><h1 id="How-to-site"><a href="#How-to-site" class="headerlink" title="How to site"></a>How to site</h1><figure class="highlight text"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">@misc{zhang2026liegrouprope,</span><br><span class="line"> author={Yechen Zhang},</span><br><span class="line"> title={从李群的视角看 ROPE 旋转位置编码},</span><br><span class="line"> year={2026},</span><br><span class="line"> month={Mar},</span><br><span class="line"> url={\url{https://anti-entrophic.github.io/posts/10059.html}}</span><br><span class="line">}</span><br></pre></td></tr></table></figure>]]></content>
<summary type="html">Not that complex though</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Positional Embedding" scheme="https://anti-entrophic.github.io/tags/Positional-Embedding/"/>
<category term="Geometry" scheme="https://anti-entrophic.github.io/tags/Geometry/"/>
</entry>
<entry>
<title>Part II of Symplectic Geometry - Properties</title>
<link href="https://anti-entrophic.github.io/posts/10058.html"/>
<id>https://anti-entrophic.github.io/posts/10058.html</id>
<published>2026-03-23T05:16:17.000Z</published>
<updated>2026-03-23T10:45:22.034Z</updated>
<content type="html"><![CDATA[<h1 id="李代数"><a href="#李代数" class="headerlink" title="李代数"></a>李代数</h1><p>李代数指的是一个向量空间 $\mathfrak{g}$ 配备一个李括号 $[\cdot, \cdot]: \mathfrak{g} \times \mathfrak{g} \rightarrow \mathfrak{g}$,满足以下三条公理:</p><ol><li><p>双线性</p></li><li><p>反对称: $[x, y] = -[y, x]$</p></li><li><p>雅可比恒等式</p><script type="math/tex; mode=display">[x, [y, z]] + [y, [z, x]] + [z, [x, y]] = 0</script></li></ol><p>括号指的是,吃进去两个对象,吐出来同类型的第三个对象。</p><p>李括号的输入是两个向量场 $X, Y$,输出一个新的向量场 $Z = [X, Y]$,它的严谨定义为两个向量场代表的作用在光滑函数 $f$ 上的微分算子的交换子:</p><script type="math/tex; mode=display">[X, Y](f) \equiv X(Y(f)) - Y(X(f))</script><p>它衡量 “先沿 $X$ 走再沿 $Y$ 走” 与 “先沿 $Y$ 走再沿 $X$ 走” 的差异。如果 $[X, Y]=0$,则说明这两个场是<strong>对易</strong>的,它们的流可以像经纬网一样铺开;如果非零,说明流形在弯曲或者扭转,使得路径无法闭合。</p><h1 id="辛同构"><a href="#辛同构" class="headerlink" title="辛同构"></a>辛同构</h1><p>设 $(M, \omega)$ 为辛流形。由于 $\omega$,它自然诱导了一个从切丛到余切丛的同构 $\flat: TM \to T^*M$</p><p>即对于任意一个向量 $X \in V$,我们可以用 $\omega$ 把 $X$ 变成一个对偶向量 $\alpha$,记作</p><script type="math/tex; mode=display">\iota_X \omega = \alpha</script><p>具体为 $\alpha(Y) = \omega(X, Y), \forall Y \in V$</p><h2 id="哈密顿向量场"><a href="#哈密顿向量场" class="headerlink" title="哈密顿向量场"></a>哈密顿向量场</h2><p>给定一个光滑函数 $H$,它的外微分 $dH$ 是一个 1-形式:</p><script type="math/tex; mode=display">dH = \sum \frac{\partial H}{\partial x^i} dx^i</script><p>我们将辛同构应用到 $dH$ 上,定义向量场 $X_H$ 满足:</p><script type="math/tex; mode=display">\iota_{X_H}\omega = dH \quad \Rightarrow \quad \omega(X_H, Y) = dH(Y), \quad \forall Y</script><script type="math/tex; mode=display">\begin{aligned}(X_H)^TJY &= (\nabla H)^T Y \\(X_H)^TJ &= (\nabla H)^T \\J^T X_H &= \nabla H \\X_H &= J\nabla H\end{aligned}</script><p>我们会发现,$X_H$ 就对应了哈密顿方程的解。这是哈密顿方程的几何来源。</p><script type="math/tex; mode=display">X_H = \begin{pmatrix} 0 & I \\ -I & 0 \end{pmatrix} \begin{pmatrix} \frac{\partial H}{\partial q} \\ \frac{\partial H}{\partial p} \end{pmatrix} = \begin{pmatrix} \frac{\partial H}{\partial p} \\ -\frac{\partial H}{\partial q} \end{pmatrix}</script><h1 id="泊松括号"><a href="#泊松括号" class="headerlink" title="泊松括号"></a>泊松括号</h1><p>泊松括号由辛形式定义:</p><script type="math/tex; mode=display">\{f, g\} \equiv \omega(X_f, X_g)</script><p>它输入两个光滑函数 $f, g$,输出一个新的函数 $h = \{f, g\}$。我们可以推导如下性质:</p><script type="math/tex; mode=display">\omega(X_f, X_g) = (\iota_{X_f}\omega)(X_g) = df(X_g) = X_g(f)</script><p>这里最后一步是外微分的定义,我们分别展开左式与右式:</p><p>左式:</p><script type="math/tex; mode=display">df(X_g) = (\sum_i \frac{\partial f}{\partial x_i} dx^i)(\sum_{j} X_g^j \frac{\partial}{\partial x^j})</script><p>由于 $dx^i (\frac{\partial}{\partial x^j}) = \delta_j^i$,所以</p><script type="math/tex; mode=display">df(X_g) = \sum_{i,j} \frac{\partial f}{\partial x_i} X_g^i \delta_j^i = \sum_i X_g^i \frac{\partial f}{\partial x^i}</script><p>右式:</p><script type="math/tex; mode=display">X_g(f) = (\sum_j X_g^j \frac{\partial}{\partial x^j})f = \sum_j X_g^j \frac{\partial f}{\partial x^j}</script><p>所以 左式 $=$ 右式</p><h2 id="坐标表达式"><a href="#坐标表达式" class="headerlink" title="坐标表达式"></a>坐标表达式</h2><p>我们还可以代入坐标计算</p><script type="math/tex; mode=display">\begin{aligned}\{f, g\} &= \omega(J \nabla f, J \nabla g) = (J \nabla f)^T J (J \nabla g) = \nabla f^T J^T J J \nabla g \\&= \nabla f^T J \nabla g = \sum(\frac{\partial f}{\partial q_i} \frac{\partial g}{\partial p_i} - \frac{\partial f}{\partial p_i} \frac{\partial g}{\partial q_i})\end{aligned}</script><p>它衡量的是物理量 $f$ 随着由 $g$ 产生的流是如何变化的。例如若 $\{f, H\} = 0$($H$ 是哈密顿量/能量),则 $f$ 相对于 $H$ 是一个守恒量。</p><h1 id="李代数同态"><a href="#李代数同态" class="headerlink" title="李代数同态"></a>李代数同态</h1><div class="note warning flat"><p>该部分内容有丶丶超纲,我不会太严谨,因为我自己也没完全理解。但我需要这一部分来继续解释辛几何。</p></div><p>可以证明,映射 $f \mapsto X_f$ 保持括号结构 $X_{\{f, g\}} = -[X_f, X_g]$ (这里的负号是因为我之前定义了 $\iota_{X_H}\omega = dH$,只影响符号)</p><p>我们知道 $\iota_{X_{\{ f, g \}}} = d\{f, g\}$,所以我们欲证:</p><script type="math/tex; mode=display">\iota_{-[X_f, X_g]} = d\{f, g\}</script><h2 id="Cartan-公式"><a href="#Cartan-公式" class="headerlink" title="Cartan 公式"></a>Cartan 公式</h2><p>对于任意向量场 $X$ 和任意微分形式 $\omega$,有:</p><script type="math/tex; mode=display">\mathcal{L}_X\omega = d(\iota_X \omega) + \iota_X(d\omega)</script><p>其中 $\mathcal{L}_X$ 是李导数,定义为</p><script type="math/tex; mode=display">\mathcal{L}_X \omega = \frac{d}{dt} \big|_{t=0} \phi_t^* \omega</script><p>它描述了几何对象沿着向量场 $X$ 产生的流 $\phi_t$ 的变化率,它是对时间的导数。</p><p>Cartan 公式等于是在说,要算李导数,不用真的去解微分方程求流 $\phi_t$,只需要求两个静态的空间操作 ($d$ 和 $\iota_X$) 即可。</p><h3 id="Cartan-公式的证明"><a href="#Cartan-公式的证明" class="headerlink" title="Cartan 公式的证明"></a>Cartan 公式的证明</h3><p>证明略。需要分别证明对于 0-形式 和 1-形式成立,并且由于对外积满足莱布尼茨法则,所以它对所有形式都成立。</p><h2 id="证明"><a href="#证明" class="headerlink" title="证明"></a>证明</h2><p>利用 Cartan 公式,因为辛形式是闭形式,满足 $d\omega = 0$, 且 $X_f$ 是哈密顿向量场,有 $\iota_{X_f} \omega = df$,所以:</p><script type="math/tex; mode=display">\mathcal{L}_{X_f}\omega = d(df) + 0 = 0</script><p>利用李导数与内乘的交换关系:</p><script type="math/tex; mode=display">\iota_{[X, Y]} = [\mathcal{L}_X, \iota_Y] = \mathcal{L}_X \iota_Y - \iota_Y \mathcal{L}_X</script><p>将其作用在 $\omega$ 上:</p><script type="math/tex; mode=display">\begin{aligned}\iota_{[X_f, X_g]} \omega &= \mathcal{L}_{X_f} (\iota_{X_g} \omega) - \iota_{X_g} (\mathcal{L}_{X_f} \omega) \\&= \mathcal{L}_{X_f}(dg) - \iota_{X_g}(0) \\&= \mathcal{L}_{X_f}(dg)\end{aligned}</script><p>因为 $d$ 和 $\mathcal{L}_X$ 是可交换的 $(d \mathcal{L}_X = \mathcal{L}_X d)$;且对于函数 $g$,$\mathcal{L}_{X_f}g = X_f(g)$(在光滑函数上应用李导数的定义可推),所以:</p><script type="math/tex; mode=display">\mathcal{L}_{X_f}(dg) = d(\mathcal{L}_{X_f}g) = d(X_f(g))</script><p>我们在泊松括号章节曾推导性质 $\{f, g\} = -X_f(g)$,所以:</p><script type="math/tex; mode=display">\iota_{[X_f, X_g]} \omega = -d\{f, g\}</script><p>得证</p><h1 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h1><p>综上,辛形式提供了一个非退化的映射 $f \xrightarrow{\omega} X_f$,使得函数不再只是一个静止的值,而是存在动力学联系的流,为函数空间也赋予了一个李代数结构。</p><p>这是辛几何能用于动力学研究的数学基底。</p><p>不过,辛流形必须是偶数维的,那奇数维流形上就没有类似的动力学结构了吗?有,只不过那就变成了另一套被称作“接触几何”的东西。</p><p>辛几何虽然物理意义更直观,但接触几何描述的耗散系统可能更符合我的需求,所以辛几何就到此为止吧~ 学到了很多新知识。</p>]]></content>
<summary type="html">Probably make sense</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Symplectic Geometry" scheme="https://anti-entrophic.github.io/tags/Symplectic-Geometry/"/>
</entry>
<entry>
<title>Part I of Symplectic Geometry - The Basis</title>
<link href="https://anti-entrophic.github.io/posts/10057.html"/>
<id>https://anti-entrophic.github.io/posts/10057.html</id>
<published>2026-03-18T11:26:05.000Z</published>
<updated>2026-03-23T10:42:50.775Z</updated>
<content type="html"><![CDATA[<h1 id="双线性形式"><a href="#双线性形式" class="headerlink" title="双线性形式"></a>双线性形式</h1><h2 id="定义"><a href="#定义" class="headerlink" title="定义"></a>定义</h2><p>设 $V$ 是实数域 $\mathbb{R}$ 上的一个向量空间。定义 $V$ 上的一个双线性形式是一个函数 $B: V \times V \rightarrow \mathbb{R}$,它接受两个向量作为输入,输出一个实数;并且对每一个变量都保持线性,即任意向量 $u, v, w \in V$ 和任意标量 $\lambda \in \mathbb{R}$,满足:</p><p>对第一个变量线性</p><script type="math/tex; mode=display">\begin{aligned}B(u+v, w) &= B(u, w) + B(v, w) \\B(\lambda u, w) &= \lambda B(u, w)\end{aligned}</script><p>对第二个变量线性</p><script type="math/tex; mode=display">\begin{aligned}B(u, v+w) &= B(u, v) + B(u, w) \\B(u, \lambda v) &= \lambda B(u, v)\end{aligned}</script><p>数学上,我们常常以 “形式” 来称呼代数中输出为标量的函数。注意双线性形式并不限于实数域,任意域 $\mathbb{F}$ 均可定义双线性形式,如复数域 $\mathbb{C}$,那会导向不同的数学分支,暂不讨论。</p><h2 id="矩阵表示"><a href="#矩阵表示" class="headerlink" title="矩阵表示"></a>矩阵表示</h2><p>我们在 $V$ 中选定一组基 $\{e_1, e_2, \cdots, e_n\}$,将 $u, v$ 写成基向量的组合</p><script type="math/tex; mode=display">u = \sum x_i e_i,\quad v = \sum y_j e_j</script><p>根据双线性性质,将 $B(u, v)$ 展开</p><script type="math/tex; mode=display">B(u, v) = B(\sum_i^n x_i e_i, \sum_j^n y_j e_j) = \sum_i^n \sum_j^n x_i y_j B(e_i, e_j)</script><p>定义一个 $n \times n$ 的矩阵 $A$,使得 $A_{ij} = B(e_i, e_j)$,上式可以写成漂亮的矩阵乘法形式:</p><script type="math/tex; mode=display">B(u, v) = u^TAv</script><p>所以,其实每种双线性形式都对应了一个方阵 $A$;反过来每一个方阵 $A$ 都定义了一个双线性形式。研究双线性形式的性质,本质上就是研究这个矩阵 $A$ 的性质。</p><h2 id="黎曼几何与辛几何的分野"><a href="#黎曼几何与辛几何的分野" class="headerlink" title="黎曼几何与辛几何的分野"></a>黎曼几何与辛几何的分野</h2><p>对于任意一个双线性形式 $B$,我们可以构造两个新的双线性形式:</p><script type="math/tex; mode=display">\begin{aligned}B_{sym}(u, v) &= \frac{1}{2}(B(u, v) + B(v, u)) \\B_{skew}(u, v) &= \frac{1}{2}(B(u, v) - B(v, u))\end{aligned}</script><p>显然</p><script type="math/tex; mode=display">\begin{aligned}B_{sym}(u, v) &= B_{sym}(v, u) \\B_{skew}(u, v) &= -B_{skew}(v, u)\end{aligned}</script><p>且任意双线性形式 $B$ 可以唯一写成</p><script type="math/tex; mode=display">B(u, v) = B_{sym}(u, v) + B_{skew}(u, v)</script><h3 id="长度"><a href="#长度" class="headerlink" title="长度"></a>长度</h3><script type="math/tex; mode=display">|u|^2 = B(u, u) = B_{sym}(u, u) + B_{skew}(u, u)</script><p>根据 $B_{skew}$ 的性质,我们很容易知道 $B_{skew}(u, u)=0$,因此 $B_{skew}$,即反对称部分对长度的贡献为 0</p><h3 id="角度"><a href="#角度" class="headerlink" title="角度"></a>角度</h3><p>当我们考虑角度时,我们需要考虑余弦定理</p><script type="math/tex; mode=display">\begin{aligned}|u-v|^2 &= |u|^2 + |v|^2 - 2|u||v|\text{cos}\theta \\&= B(u, u) + B(v, v) - 2|u||v|\text{cos}\theta\end{aligned}</script><p>对比双线性性质展开的结果</p><script type="math/tex; mode=display">B(u-v, u-v) = B(u, u) - B(u, v) - B(v, u) + B(v, v)</script><p>我们知道这里有 $B(u, v) + B(v, u) = 2|u||v|\text{cos}\theta$,而 $B(u, v)+ B(v, u) = 2 B_{sym}(u, v)$,也就是说,夹角 $\theta$ 也与 $B_{skew}$ 无关。</p><p>黎曼几何,研究的就是 $B_{skew}=0$ 时的特例。</p><div class="note success flat"><p>$B_{skew}=0$ 也是完全可以成立的,这等同于引入了一个挠率张量,在几何上表现为改变了空间的联络。在复几何、弦论当中,$B_{skew}$ 绝不是多余的。</p></div><h3 id="面积"><a href="#面积" class="headerlink" title="面积"></a>面积</h3><p>面积的几何直觉来自:自己和自己张成的面积为 0,即 $B(u, u)=0$, 而我们知道 </p><script type="math/tex; mode=display">B(u, u) = B_{sym}(u, u) + \underbrace{B_{skew}(u, u)}_{0} = B_{sym}(u, u)</script><p>因此,如果我们想要定义一个符合直觉的面积,我们就被迫要令 $B_{sym}=0$,面积就只和剩下的 $B_{skew}$ 有关</p><p>$B_{skew}$ 的反对称形式,其实代表的是面积的有向性。辛几何,研究的就是 $B_{sym} = 0$ 的特例。</p><p>综上,我们可以明白两种不同的几何底层对应的不同双线性形式</p><script type="math/tex; mode=display">\begin{aligned}B(u, v) = B(v,u) \quad& \Leftrightarrow A^T = A \\B(u, v) = -B(v,u) \quad& \Leftrightarrow A^T = -A\end{aligned}</script><div class="note success flat"><p>如果我们允许 $B_{sym}\neq 0$,将会导向一种耗散系统(多辛几何 & 接触几何)。暂不展开,但我觉得我将会不可避免地碰到它。</p></div><h1 id="辛形式"><a href="#辛形式" class="headerlink" title="辛形式"></a>辛形式</h1><p>设 $V$ 是一个 $m$ 维实向量空间,$V$ 上的一个辛形式是一个双线性形式 $\omega:V \times V \rightarrow \mathbb{R}$,满足以下两个条件:</p><ol><li>反对称性<script type="math/tex; mode=display">\forall u, v \in V, \omega(u, v) = - \omega(v, u)</script></li></ol><p>这也隐含了 $\omega(v, v) = 0$</p><ol><li>非退化性<script type="math/tex; mode=display">\text{if} \,\, \forall v \in V, \omega(u, v) = 0, \text{then} \,\, u=0</script></li></ol><h2 id="性质"><a href="#性质" class="headerlink" title="性质"></a>性质</h2><p>如果一个向量空间 $V$ 上存在辛形式,那么 $V$ 的维数必须实偶数,即 $m=2n$</p><p>证明(代数法):取 $V$ 的一组基 $\{e_1, e_2, \cdots, e_n\}$,则辛形式 $\omega$ 唯一对应一个矩阵表示 $A_{ij} = \omega(e_i, e_j)$</p><p>由反对称性,$A^T = A$,而对于行列式 $\text{det}(A) = \text{det}(A^T)$</p><p>所以,$\text{det}(A) = \text{det}(-A) = (-1)^m\text{det}(A)$</p><p>又因为非退化性,$\text{det}(A) \neq 0$,所以 $1 = (-1)^m \Rightarrow m$ 为偶数</p><h1 id="辛流形"><a href="#辛流形" class="headerlink" title="辛流形"></a>辛流形</h1><p>我们将辛向量空间推广到光滑流形上。一个辛流形是一个二元组 $(M, \omega)$,其中 $M$ 是一个偶数维的光滑流形,$\omega$ 是 $M$ 上的一个微分 2-形式</p><h2 id="微分-2-形式"><a href="#微分-2-形式" class="headerlink" title="微分 2-形式"></a>微分 2-形式</h2><p>我们任取一个 $V$ 上的非零向量 $e_1$,由于 $\omega$ 是非退化的,所以至少存在一个向量 $v \in V$,使得 $\omega(e_1, v) \neq 0$</p><p>显然我们可以缩放 $v$,找到 $f_1$ 使得 $\omega(e_1, f_1) = 1$,这样我们就可以在 $V$ 中剥离出一个 2 维空间 $W_1 = \text{span}\{e_1, f_1\}$。我们可以在 $W_1$ 的正交补空间上继续剥离,总之整个辛空间的可以看作是若干个 2 维空间的直和,基为 $\{e_1, e_2, \cdots, e_n, f_1, f_2, \cdots, f_n\}$</p><p>显然有 </p><script type="math/tex; mode=display">\omega(e_i, e_j) = 0, \quad \omega(f_i, f_j)=0, \quad \omega(e_i, f_j) = \delta_{ij}, \quad\omega(f_i, e_j) = -\delta_{ij}</script><p>得到辛形式的矩阵为</p><script type="math/tex; mode=display">J = \begin{pmatrix} 0 & I_n \\ -I_n & 0 \end{pmatrix}</script><p>对于任意向量 $u, v \in V$,我们将其分解为基坐标形式</p><script type="math/tex; mode=display">\begin{aligned}u &= \sum_{i=1}^n (u_{q_i}e_i + u_{p_i} f_i) \\v &= \sum_{j=1}^n (v_{q_j}e_j + v_{p_j} f_j)\end{aligned}</script><p>计算</p><script type="math/tex; mode=display">\omega(u, v) = \omega\left( \sum_{i} (u_{q_i} e_i + u_{p_i} f_i), \sum_{j} (v_{q_j} e_j + v_{p_j} f_j) \right)</script><p>展开后有四种项:</p><ol><li><strong>$e_i$ 与 $e_j$ 的项</strong>:$\sum u_{q_i} v_{q_j} \omega(e_i, e_j) = 0$。</li><li><strong>$f_i$ 与 $f_j$ 的项</strong>:$\sum u_{p_i} v_{p_j} \omega(f_i, f_j) = 0$。</li><li><strong>$e_i$ 与 $f_j$ 的项</strong>:<script type="math/tex; mode=display">\sum_{i,j} u_{q_i} v_{p_j} \omega(e_i, f_j)</script>由于 $\omega(e_i, f_j) = \delta_{ij}$,只有 $i=j$ 时非零。<script type="math/tex; mode=display">= \sum_{i=1}^n u_{q_i} v_{p_i}</script></li><li><strong>$f_i$ 与 $e_j$ 的项</strong>:<script type="math/tex; mode=display">\sum_{i,j} u_{p_i} v_{q_j} \omega(f_i, e_j)</script>由于 $\omega(f_i, e_j) = -\delta_{ij}$。<script type="math/tex; mode=display">= \sum_{i=1}^n -u_{p_i} v_{q_i}</script></li></ol><p><strong>合并结果</strong>:</p><script type="math/tex; mode=display">\omega(u, v) = \sum_{i=1}^n (u_{q_i} v_{p_i} - u_{p_i} v_{q_i})</script><p>能看到 $p_i$ 分量与 $q_i$ 分量总是配对的,交错项在辛形式中的贡献为 0。我们将 $u_{q_i} v_{p_i} - u_{p_i} v_{q_i}$ 记为 $(dp_i \wedge dq_i)(u, v)$, 则 $\omega(u, v) = (\sum_{i=0}^n dp_i \wedge dq_i )(u, v)$</p><p>回到 微分 2-形式,它指的就是在流形上的每一个点 $x$,都能定义一个辛形式 $\omega_x: T_xM \times T_xM \rightarrow \mathbb{R}$,衡量当前位置切空间中任意两个向量的有向面积。</p><h2 id="闭形式"><a href="#闭形式" class="headerlink" title="闭形式"></a>闭形式</h2><p>$\omega$ 的外微分等于 0, $d \omega = 0$</p><p>这个有点难理解,这一性质似乎和很多东西有关。我暂时先放在这里,等我更熟悉了以后再回来记录。</p><h1 id="相空间"><a href="#相空间" class="headerlink" title="相空间"></a>相空间</h1><p>相空间就是流形的余切丛 $T^\ast Q$,辛几何可以自然地定义在相空间上</p>]]></content>
<summary type="html">我要成为辛几何高手</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Symplectic Geometry" scheme="https://anti-entrophic.github.io/tags/Symplectic-Geometry/"/>
</entry>
<entry>
<title>Functional Analysis - Quick Explanation</title>
<link href="https://anti-entrophic.github.io/posts/10055.html"/>
<id>https://anti-entrophic.github.io/posts/10055.html</id>
<published>2026-03-12T04:39:14.000Z</published>
<updated>2026-03-13T01:35:02.124Z</updated>
<content type="html"><![CDATA[<p>泛函分析是对我影响很大的一门学科,用调侃的话来说我的感受就是:“不懂泛函的人素质品味都很差”。但是,怎么清晰地和一些不大接触数学的同学解释呢?我自己想了一个比较通俗的版本,遂分享一下。</p><h1 id="Main-Content"><a href="#Main-Content" class="headerlink" title="Main Content"></a>Main Content</h1><p>所谓泛函分析,顾名思义就是研究泛函的一门学科。泛函可以通俗地理解为是一台机器,它吃进一个元素,吐出一个数字。只有这一个要求,再无其它(通常,我们还会要求具有线性,研究“线性泛函”,但我们这部分科普可以先暂时忽略)。</p><p>举个直观例子,我们每个人都可以是一个泛函。假设我现在准备了一大堆食材,黄瓜、番茄、鸡蛋,我们每个人都可以给这些食材打分。我可能给黄瓜打 90 分,小明可能会给黄瓜打 60 分。总之,我们都是接收一个“食材”,输出一个“好吃分数”,如 小明(黄瓜)=60,这就可以称作是一个泛函了。</p><p>但是这有什么用呢?我们可以换一把“尺子”。现在,我们用食材本身作为泛函,而输出的分数则作为两者的适配程度。因为番茄和鸡蛋很搭,所以我们使用 “番茄尺子” 去衡量鸡蛋的话,得分就是 番茄(鸡蛋) = 100;但如果我们换用“黄瓜尺子”,可能搭配度差一点,得分是 黄瓜(鸡蛋) = 20</p><p>我的下一个问题是,为什么 番茄(鸡蛋) = 100,而 黄瓜(鸡蛋) = 20?它们为什么不同?你会发现,这和这两把 “尺子” 的固有属性是有关的。实际上,也就是和对应的 “番茄” 与 “黄瓜” 的固有属性是有关的(在数学上,我们把 “番茄尺子” 和 “番茄” 称为对偶)。比如番茄中含有脂溶性的番茄红素,和鸡蛋一起营养价值更高,也即我们通过“番茄尺子”,标定了鸡蛋的“油脂”属性;而黄瓜则和鸡蛋没有任何联动,也能侧面说明,鸡蛋肯定不像黄瓜一样适合做沙拉。泛函分析,目标就是通过这些尺子,进而实际上去研究所有元素的属性,比如含糖量、比如油脂的多少,等等。我们还可以把所有的食材放在一个空间里,去研究整个 “食材空间” 的属性。</p><p>而在现实世界的很多问题中,我们面对的“食材”太复杂了,复杂到了包含了无穷多的维度。我们需要这样一项技术,去研究一些 “尺子” 或 “空间”,比如深度学习的矩阵、比如量子力学的哈密顿算符。</p><ul><li><p>在深度学习中,我们的神经网络就相当于在寻找一把完美的“高级尺子(泛函)”。你给它输入一张照片(一个包含了几百万像素的复杂元素),它吐出一个数字(比如 99% 概率是猫)。神经网络训练的过程,就是在微调这把尺子的刻度。</p></li><li><p>在量子力学中,微观粒子没有确定的位置,它们以“波函数”的形式存在于无限维的空间中。物理学家使用哈密顿算符,本质上也就是我们刚才说的复杂版“机器”,它把粒子的状态吃进去,吐出来一个我们能观测到的物理量(比如能量值)。</p></li></ul><p>总结一下,在泛函分析的视角里,我们不再研究一个具体的元素,而是研究这个元素空间上的尺子、及其读数。这些我们充分认识的尺子,就是我们用以构建对整个元素空间认识的最强工具,是一个更广泛的视角。</p>]]></content>
<summary type="html">有关泛函分析的简单科普</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Optimization" scheme="https://anti-entrophic.github.io/tags/Optimization/"/>
<category term="Functional Analysis" scheme="https://anti-entrophic.github.io/tags/Functional-Analysis/"/>
</entry>
<entry>
<title>Mousse - Rectifying the Geometry of Muon with Curvature-Aware Preconditioning</title>
<link href="https://anti-entrophic.github.io/posts/10054.html"/>
<id>https://anti-entrophic.github.io/posts/10054.html</id>
<published>2026-03-11T10:06:51.000Z</published>
<updated>2026-04-01T06:05:57.763Z</updated>
<content type="html"><![CDATA[<div class="note success flat"><p>本篇博客会解读一下本人最近发表的 《<a href="https://arxiv.org/abs/2603.09697">Mousse: Rectifying the Geometry of Muon with Curvature-Aware<br>Preconditioning</a>》 这篇工作,并分享一些最近在优化器方面工作的一些心得和理解。内容包含很多主观认识,我在该领域也还不太成熟,欢迎大家与我交流。</p></div><div class="note warning flat"><p>该博客仍在施工中;This is still a Work in Progress (WIP)</p></div><h1 id="Introduction"><a href="#Introduction" class="headerlink" title="Introduction"></a>Introduction</h1><p>上学期我的大部分时间都在思考一些有关预训练的问题。虽然在我看来,学界对于预训练的热情正在逐渐褪去(当然只限于 scaling 方面,各种架构反倒是愈加层出不穷,但也从侧面反映了当前模型正在逼近其边界),不过由于我天然的松弛感,我依然在无忧无虑地阅读着各类文献。其中,<a href="https://kexue.fm/" title="科学空间">苏神的博客</a>意外成为了我最喜欢阅读的内容之一,也间接成为了我继续写博客的契机,同时我也自然而然地开始关注到 attention、优化器等领域。特别地,苏神有关 muon 的几篇文章我觉得讲的是很明白的,有很多实际训练的经验以及理论上的认识,我也自己记录了<a href="https://anti-entrophic.github.io/posts/10047.html" title="那会儿的认识还很浅薄">一篇</a>以作拙劣的模仿。感兴趣的读者可以去阅读一下苏神的文章,并在建立对 muon 的初步认识后,再来一起讨论下 Mousse 所做的内容~</p><h1 id="Why-Muon-is-Good"><a href="#Why-Muon-is-Good" class="headerlink" title="Why Muon is Good?"></a>Why Muon is Good?</h1><p>Muon 优化器本质上和 AdamW 等优化器非常不一样。它既非自适应、也完全称不上二阶,仅仅只是在特征方向上将动量一刀切,做谱归一化;硬要说的话,它反而更像 SignSGD 与 Lion 这种强制性的归一化策略。早在我们在学校上优化课时就知道,牛顿法在优化领域属于是“宇宙万法的源头”;而 Muon 意料之外的优秀表现,恰恰代表了大模型优化过程存在两种截然不同的哲学理念:由牛顿法引导的<strong>曲率派</strong> 与 以归一化主导的<strong>归一派</strong>。在我看来,后者得以异军突起的根本原因在于,我们对于曲率的预估是有误差的。从 AdamW 到 Shampoo 的升级,可以看成是一种把对 Hessian 的预估从 Diagnol 提升到了 Layer-wise Kronecker Product 。然而即便如此,不论是 EMA 也好,还是 micro batch size 也罢,都注定了这种估计是有误差的,是有损的。当这种误差大到,引入曲率没有提升甚至负提升时,那就还不如直接做一次简单的一刀切,仅仅做简单的正则化,人为设定一种强大的先验来保持训练的稳定。</p><p>优化器的公式是简单的,但是其背后的,LLM 真实的 loss landscape 却是无比复杂的。再好的数学性质,也需要与大模型的实际情况相匹配,此即所谓的归纳偏置(inductive bias),适合 LLM 的才是最好的。从结果来看,优化器的归一派无疑是 LLM 乐于接受的先验。也许是因为参数空间存在巨大的冗余,模型既可以在精准导航下找到一处崎岖的山谷,也可以在归一化的先验下,找到一个所谓各向同性的位置。它们可能在 loss 层面上表现出来一样好,但背后却落入了完全不同的损失曲面。也许这也是为什么,用 Muon 优化器训练的模型用 AdamW 续训,或是反过来,效果都不好的原因了: 两者最后收敛邻域的空间结构完全不同。AdamW 也许能停在一处悬崖,如果它的二阶矩能够识别到该方向曲率巨大而自动放缓步长,但此时 Muon 不管不顾地迈出的一步单位步长却可能摧毁模型的表现;而用 Muon 训练的模型,也许天然地就更亲睐一种谱各向同性的流形,自然地往这一子空间去收敛。</p><p>当然,除了归一化之外,Muon 最大的进步就是选择在谱空间执行归一化操作(并且引入 NS 迭代降低了计算复杂度),使其成为了一个 <strong>矩阵级</strong> 的优化器。矩阵级的优化器强于 element-wise 的优化器目前来看应该是一种共识,本质原因可能是因为矩阵算子背后代表着更好的流形约束。</p><p>而 Mousse 所做的内容就是,尝试结合 牛顿派(Shampoo) 的曲率调节 与 归一派(Muon) 的谱约束的优点,看看能不能取得更好的效果。</p><h1 id="Preliminary"><a href="#Preliminary" class="headerlink" title="Preliminary"></a>Preliminary</h1><p>我们先来思考一个简单的问题。我们都知道,梯度下降的更新公式是:</p><script type="math/tex; mode=display">\theta := \theta - \eta G</script><p>我们用一种容易理解的方式,不妨假设 $\theta$ 的量纲是 秒。而梯度本身的物理意义是,参数空间在各方向上移动一点点,对于 Loss 的影响。我们假设 Loss 的量纲是 米,那么梯度的量纲就是 米/秒</p><p>秒 := 秒 - 米/秒,量纲不同的物理量怎么能计算呢?所以里面必然存在一处,连接参数空间与梯度空间的桥梁。如果我们设参数空间为 $\mathcal{M}$,从泛函分析的角度来看,梯度吃进一个变化量 $\text{d} \theta$,然后吐出一个标量 $\text{d} L$,它可以被视为一个泛函,而梯度空间就是切空间 $T_\theta \mathcal{M}$ 的对偶空间 $T_\theta^\ast \mathcal{M}$;或者从微分几何的视角全局来看,梯度空间是参数空间的余切丛 $T^\ast \mathcal{M}$。优化器的核心任务,就是<strong>寻找一个映射,把居住在余切丛的梯度,搬运回切丛</strong>,变成更新量,即</p><script type="math/tex; mode=display">\theta := \theta - \eta \cdot \text{optimizer}(G)</script><p>这种映射自然存在其限制。直观点说就是,需要在 $G$ 的指导下,找到一个 $\Delta W$;而为了保证更新的稳定性,我们又会对 $\Delta W$ 施加范数约束。前者,我们需要考虑的是 $\min_{U \in \mathcal{T}_W\mathcal{W}} \langle G, U \rangle$;后者则是 $\Vert U\Vert \leq 1$。我们分开来讨论。</p><h2 id="Natural-Pairing"><a href="#Natural-Pairing" class="headerlink" title="Natural Pairing"></a>Natural Pairing</h2><p>前者这里有一个符号上的概念需要澄清。$G$ 生活在梯度空间 $\mathcal{G}$,而 $U$ 生活在切空间 $\mathcal{T}_W \mathcal{W}$,它们的“内积”怎么定义呢?一种不那么准确的理解方式是,我们把 $\mathcal{G}$ 和 $\mathcal{T}_W \mathcal{W}$ 都嵌入 $\mathbb{R}^N$ 空间,然后再人为选一种内积。但这样会解释不通,比如说没法用内积去诱导 $\mathcal{T}_W \mathcal{W}$ 上的算子范数,很奇怪。另一种更标准的解释是,这里的 $\langle \cdot \,, \cdot \rangle$ 指的是自然配对(Natural Pairing),它是一个更底层、不需要任何几何结构的概念。</p><p>具体而言,这里的 $G$ 可以看作是切空间的对偶空间上的一个泛函,它吃进一个切空间的元素 $\text{d}\theta$,吐出来一个标量 $\text{d}L$。这种元素与对偶的泛函天然就可以配对计算:$\langle G, U \rangle := G(U)$。正如我们开头说的量纲的例子,$U$ 的单位是 秒,$G$ 的单位是 米/秒,它们的乘积天然具有物理意义。</p><p>而内积,则是在自然配对之上,叠加了人为添加的规则,使得同空间中的任意两个向量可以执行配对操作。比如说 $\mathbb{R}^2$ 空间里的两个元素 $(1,2)$ 和 $(2,3)$,我们将 $(1,2)$ 变成一个对偶向量,比如说转置操作即可,$(1, 2)^T$ 就成为了一个 $\mathbb{R}^2$ 的泛函。$\langle (1, 2), (2, 3) \rangle = (1, 2)^T(2, 3) = 8$。也就是说:</p><script type="math/tex; mode=display">\text{Inner Product} (u, v) = \text{Natural Pairing}(\underbrace{\text{Metric}(u)}_{\text{变身为对偶}}, v)</script><p>除了转置,还有拉伸、旋转等操作,都可以成为这里构造对偶向量的 Metric。总之,我们在这里引入一点点微分几何的严谨定义。对于矩阵计算而言,我们取 $\langle G, U\rangle = \text{Tr}(G^TU)$</p><h2 id="Norm-Constraint"><a href="#Norm-Constraint" class="headerlink" title="Norm Constraint"></a>Norm Constraint</h2><p>后者就是一个范数约束。根据选择范数的不同,最终会导出不同的优化器。如果我们选择 Frobenius 范数,那么就会推出 SGD 的更新公式;如果我们选择 逐元素的 $l_\infty$ 范数,就会导出 Adam 的更新公式。这是一个统一的框架。</p><p>Muon 的更新之处在于,它选择了谱范数作为约束。谱范数的计算方式是矩阵的最大奇异值,也即最大拉伸强度,所以我在论文里都使用算子范数 $\Vert\cdot\Vert_{op}$ 来表示了。而谱范数有一个非常优良的性质是 $\Vert\cdot\Vert_{op} \leq \Vert\cdot\Vert_F$,因此它是一个更松的约束,也许能找到更优解。</p><h2 id="Muon"><a href="#Muon" class="headerlink" title="Muon"></a>Muon</h2><p>总之,上述介绍了一种看待优化器的统一视角,即将优化器视为一种映射操作。Muon 则可以表达为:</p><script type="math/tex; mode=display">\Delta W_{\text{Muon}} = \min_{U \in \mathcal{T}_W \mathcal{W}} \langle G, U \rangle, \quad \text{s.t.} \Vert U\Vert_{op} \leq 1</script><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBOstpsSZ3QY84ymuwRk2iRefmPVfEXgACOwxrGxPiiUWdMjAcXAcSagEAAwIAA3gAAzoE.png" width="400px" /><p style="font-size: 10px;"> 标准 Muon 算法</p></center><p>上述问题的解为 $\Delta W_{\text{Muon}} = -\text{msign}(G) = -UV^T$,其中 $G = U\Sigma V^T$ 是梯度的谱分解。它强行把所有的奇异值归一化到了1,这一极其强大的先验虽然加速了模型收敛,但把整个谱空间视为各向同性依然很奇怪。不同的奇异方向,它的崎岖程度显然是不同的,我们应当参考牛顿法,在崎岖的方向放缓步长,而在平坦的方向增大步长。在原作者的博客中,作者称这一步谱正交化的有效性可以归结为“神的怜爱”。当然,我理解这是一种调侃的说法,不过这确实是一种,强大但有效的先验。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBOsxpsSbLFe6CnED-VKIlxQwZG2KBcQACPQxrGxPiiUV_BdXTwIbx-AEAAwIAA3kAAzoE.png" width="600px" /><p style="font-size: 10px;"> “为何选择谱正交化 —— divine benevolence” (https://kellerjordan.github.io/posts/muon/)</p></center><h1 id="Mousse"><a href="#Mousse" class="headerlink" title="Mousse"></a>Mousse</h1><p>Mousse 就是针对上一问题的改进。Muon 的约束可以被视为 Stiefel 流形($\Delta W_{\text{Muon}}^T \Delta W_{\text{Muon}} = VU^TUV^T = I$,而我们要做的就是引入曲率:</p><script type="math/tex; mode=display">\text{vec}(\Delta W^T) H \text{vec}(\Delta W) = C</script><p>这个目标可以被视为一个启发式的规则 —— 如果曲率比较大,对应 Hessian 的元素也会比较大,那么相应的 $\Delta W$ 就应该变小;反之亦然。从几何的角度来看,我们是对坐标轴进行了一次坐标变换。既然 Muon 理论上应该应用在各向同性的谱空间上,那我们就预先做一次拉伸,将“椭球”拉成“圆球”后再执行 NS 迭代,最后反向拉伸回去即可。</p><p>上述约束我们可以将其视为 $H^{\frac{1}{2}}\text{vec}(\Delta W)$ 在做内积,是一个被 $H^{\frac{1}{2}}$ 白化的空间上的 Stiefel 流形。对 Hessian 的近似,我们选用 Shampoo-style 的<strong>层级Kronecker积近似</strong> ($H \approx (R \otimes L)^{\frac{1}{2}}$)。根据夹心公式,我们有:</p><script type="math/tex; mode=display">((R\otimes L)^{\frac{1}{2}})^{\frac{1}{2}} \text{vec}(\Delta W) = L^{\frac{1}{4}}\Delta W R^{\frac{1}{4}}</script><p>令 $P = L^{\frac{1}{4}}$, $Q=R^{\frac{1}{4}}$,则现在的优化约束变为:</p><script type="math/tex; mode=display">\min_{\Delta W}\text{Tr}(G^T \Delta W) \quad \text{s.t.} \|P\Delta W Q\|_{op} \leq 1</script><p>令 $Y = P\Delta W Q$, 那么 $\Delta W = P^{-1} Y Q^{-1}$,代回原式:</p><script type="math/tex; mode=display">\begin{aligned}\text{Tr}(G^T P^{-1} Y Q^{-1}) &= \text{Tr}(Q^{-1} G^T P^{-1} Y)\\&= \text{Tr}([P^{-T} G^T Q^{-T}]^T Y) \\&= \text{Tr}([P^{-1} G^T Q^{-1}]^T Y)\end{aligned}</script><p>令 $\tilde{G} = P^{-1}GQ^{-1}$,则原优化问题变为:</p><script type="math/tex; mode=display">\min_Y \text{Tr}(\tilde{G}^TY), \quad \text{s.t.} \|Y\|_{op} \leq 1</script><p>通过与 muon 的式子做比较,我们发现它们的形式是完全一样的,直接套用结果:</p><script type="math/tex; mode=display">\begin{aligned}Y &= -\text{msign}(\tilde{G}) \\\Rightarrow \Delta W &= -L^{-\frac{1}{4}}\text{msign} (L^{-\frac{1}{4}}GR^{-\frac{1}{4}})R^{-\frac{1}{4}}\end{aligned}</script><p>所以我们只是在 Muon 的基础上进行了一次投影与逆投影,在一个更符合 Muon 假设的空间上执行谱归一化。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBOs9psS42yi2qJOtadZgRSOhwpZUYbgACRQxrGxPiiUW2Gj7pH8jeDQEAAwIAA3kAAzoE.png" width="600px" /><p style="font-size: 10px;"> 算法实现对比 </p></center><h2 id="Experiments-amp-Ablation"><a href="#Experiments-amp-Ablation" class="headerlink" title="Experiments & Ablation"></a>Experiments & Ablation</h2><p>TODO,可以先看原文的内容~</p><h1 id="Conclusion"><a href="#Conclusion" class="headerlink" title="Conclusion"></a>Conclusion</h1><p>做这项工作的初心,还是想要尝试在 “谱约束的稳定性” 与 “二阶优化的几何自适应性” 间搭一座桥梁。我肯定不是第一个这么做的人,但我觉得这项工作肯定还是能为别人走出一点经验。正如文章开头所写的那样,曲率派与归一派,究竟两者哪个更好,其实并没有定论,至少我们的实验证明了,往 Muon 中融入曲率是有用的。并且也尽量去平衡了时间与显存的开销,使得 Mousse 成为一个实际上可被应用的 optimizer</p><p>我觉得这个方向其实还有很多可以做的。比如我在论文的 Future Works 中所写的,还有很多 shampoo 与 muon 方向的改进有集成进来的希望,也许能进一步推动 Mousse 的能力边界。以及既然 Mousse 已经结合了曲率,那它应该会在微调 AdamW 训练的模型上表现得比起 和AdamW水火不容的Muon 更好。</p><p>最后我想说,优化器的调参是一项很 tricky 的工作,我在能力与算力允许的范围内尽量做了充分的比较。我也很期待更多真实使用场景下的反馈。</p><h1 id="补充"><a href="#补充" class="headerlink" title="补充"></a>补充</h1><h2 id="NS5-of-Muon"><a href="#NS5-of-Muon" class="headerlink" title="NS5 of Muon"></a>NS5 of Muon</h2><p>现在的muon,基本上采用的都是 5 次NS迭代,其中每次都有不同的参数,在 <a href="https://github.com/microsoft/dion" title="Dion">Dion</a> 库中的参数是</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">ns_consts = [</span><br><span class="line"> (<span class="number">4.0848</span>, -<span class="number">6.8946</span>, <span class="number">2.9270</span>),</span><br><span class="line"> (<span class="number">3.9505</span>, -<span class="number">6.3029</span>, <span class="number">2.6377</span>),</span><br><span class="line"> (<span class="number">3.7418</span>, -<span class="number">5.5913</span>, <span class="number">2.3037</span>),</span><br><span class="line"> (<span class="number">2.8769</span>, -<span class="number">3.1427</span>, <span class="number">1.2046</span>),</span><br><span class="line"> (<span class="number">2.8366</span>, -<span class="number">3.0525</span>, <span class="number">1.2012</span>),</span><br><span class="line">]</span><br></pre></td></tr></table></figure><p>绘制一下它的缩放性能,会发现实际的 NS 迭代并没有忠实地还原 muon 本身的设计思路。对于过小的奇异值,NS迭代并没有能力将其拉起来。也就是说,muon 优化器实际上还有一个隐含的效果,就是选择性地忽略奇异值较小的“噪声”方向,这是由于工程实践带来的。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBFcFpJpQ9v7XxL2_n3wGnY6xlh8pWCQACFQtrGxfdOUWOtD9VMabTGwEAAwIAA3kAAzYE.png" width="400px" /><p style="font-size: 10px;">有效范围大概在 1e-2 ~ 1</p></center><p>用 dion 库训的一个非 embedding 参数量为 80M 的模型。在实际训练过程中,Mean Singular Value 可能就是 1e-2 ~ 1e-4 左右的水平</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBFcJpJpaaWy6ngdaiMO1etrOzTy75zwACFgtrGxfdOUWJIBX9JvHtKgEAAwIAA3cAAzYE.png" width="800px" /><p style="font-size: 10px;">Newton-Schulz 迭代前的动量的奇异值的均值</p></center><p>经过NS5迭代后,确实可以看到有些方向是显著偏小的。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAEBFcZpJpiJz7PIOpttXX3k3fAMo_8xXQACGgtrGxfdOUXAWSwNbarw3wEAAwIAA3cAAzYE.png" width="800px" /><p style="font-size: 10px;">Newton-Schulz 迭代后的动量的最小奇异值,有些方向并没有拉到1。不过,其实已经放大一点了,原来可能只有 1e-7。实际上执行 NS 迭代前会进行缩放,所以实际能不能拉到1只和 Condition Number 有关。</p></center><p>总而言之,感觉 NS5 忽略噪声方向也挺合理的,这可能也可以解释为什么ns系数雕花以及用单精度SVD直接算没有什么收益。</p>]]></content>
<summary type="html">新论文分享~</summary>
<category term="Optimizer" scheme="https://anti-entrophic.github.io/categories/Optimizer/"/>
<category term="Optimizer" scheme="https://anti-entrophic.github.io/tags/Optimizer/"/>
<category term="Linear Algebra" scheme="https://anti-entrophic.github.io/tags/Linear-Algebra/"/>
</entry>
<entry>
<title>Nesterov</title>
<link href="https://anti-entrophic.github.io/posts/10052.html"/>
<id>https://anti-entrophic.github.io/posts/10052.html</id>
<published>2025-11-13T13:45:37.000Z</published>
<updated>2025-11-20T05:51:42.042Z</updated>
<content type="html"><![CDATA[<p>前段时间被人问到了 Nesterov,我以为我很了解,结果发现我自己也讲不明白 Nesterov 的 “Ilya version” 是怎么来的。遂花了点时间想了一下,感觉纯是闹麻了。下面就让我来批判一下。</p><h1 id="Nesterov-的假设"><a href="#Nesterov-的假设" class="headerlink" title="Nesterov 的假设"></a>Nesterov 的假设</h1><p>Nesterov 的思想很简单,我用 $\theta + \eta M$ 处的梯度来更新 $\theta$,保证了下次更新方向是一个考虑了动量的综合,避免因为动量“冲过头”。</p><p>在实践上,我们不必真的在 $\theta + \eta M$ 处计算梯度,有一个 Ilya 提出的“近似”版本:</p><p>标准的 nesterov 更新过程:</p><script type="math/tex; mode=display">\left\{\begin{aligned}\hat{M}_{t} &= \beta_1 M_{t-1} + (1-\beta_1) G(\theta_t + \eta M_{t-1}) \\V_{t} &= \beta_2 V_{t-1} + (1-\beta_2) G^2(\theta_t)\\\theta_{t+1} &= \theta_t - \eta \frac{\hat{M}_{t}}{\sqrt{V_{t}+\epsilon}} \end{aligned}\right.</script><p>其中,二阶矩还是得用的 $\theta_t$ 这一点的?直观上考虑到它作为正则化的作用。</p><p>通过公式变换:</p><script type="math/tex; mode=display">\begin{aligned}\theta_{t+1} &= \theta_t - \eta \frac{M_{t+1}}{\sqrt{V_{t+1}+\epsilon}} = \theta_t - \eta \frac{\beta_1 M_t + (1-\beta_1) G(\theta_t + \eta M_t)}{\sqrt{V_{t+1}+\epsilon}} \\&= \theta_t - \eta \frac{\beta_1 M_t + (1-\beta_1)G(\theta_t)} {\sqrt{V_{t+1}+\epsilon}} - \eta \frac{1-\beta_1}{\sqrt{V_{t+1}+\epsilon}} [G(\theta_t + \eta M_t) - G(\theta_t)]\end{aligned}</script><p>前面两项就是标准的梯度下降。这里我们需要计算两次梯度,我们都知道,这是不可能的。</p><p>因此实际的 Nesterov 过程采用的是一种近似,更新公式是:</p><script type="math/tex; mode=display">\left\{\begin{aligned}M_{t} &= \beta_1 M_{t-1} + (1-\beta_1) G(\theta_t) \\\hat{M}_{t} &= \beta_1 M_{t} + (1-\beta_1) G(\theta_t)\\V_{t} &= \beta_2 V_{t-1} + (1-\beta_2) G^2(\theta_t)\\\theta_{t+1} &= \theta_t - \eta \frac{\hat{M}_{t}}{\sqrt{V_{t}+\epsilon}} \end{aligned}\right.</script><p>也就是说,在动量上应用了两次梯度的滑动平均。这有道理吗?我们可以考虑它相比标准 Adam 的增量:</p><script type="math/tex; mode=display">\begin{aligned}\Delta &= [(\beta_1 M_t +(1-\beta_1)G(\theta_t)) - M_t] \\&= (1-\beta_1)(G(\theta_t)-M_t)\end{aligned}</script><p>对比一下理想公式的增量,我们就能知道这里的假设是</p><script type="math/tex; mode=display">[G(\theta_t + \eta M_t) - G(\theta_t)] \approx G(\theta_t)-M_t</script><p>这个合理吗?其实很不直观,只能尝试强行解释。对左式泰勒:</p><script type="math/tex; mode=display">\eta H(\theta_t)M_t \approx G(\theta_t)-M_t</script><p>先不管有的没的可不可逆,强行推:</p><script type="math/tex; mode=display">M_t \approx (\eta H(\theta_t)+\beta_1 I)^{-1} G(\theta_t)</script><p>而我们的更新量就是:</p><script type="math/tex; mode=display">\eta M_t \approx (H(\theta_t) + \frac{\beta_1}{\eta}I)^{-1} G(\theta_t)</script><p>这是一项阻尼强度为 $\frac{\beta_1}{\eta}$ 的 hessian 曲率矫正。所以这个假设的核心是,动量可以视为一种对真实梯度的白化。</p><p>首先第一点,从理论的角度来说,这个假设有点强过头了,它不仅要求方向,实际上连大小也限制了。看起来,它直接要求了动量就是牛顿法更新量,这个我只能说哈哈了。</p><p>其次第二点,从实践的角度来说,一般 $H(\theta_t)$ 的元素的均值可能是在 $1e^{-5}$ 左右?相比于 $\beta_1 = 0.9, \eta = 1e^{-6}$ 这种水平的设置,<strong>阻尼太大了!!!</strong> 可以说是完全抹平了曲率的影响。</p><p>所以只有当上述这个很强的假设成立时,近似才成立。</p><h1 id="实践"><a href="#实践" class="headerlink" title="实践"></a>实践</h1><p>我没仔细看 NAdam 之类的论文,在我自己心里反正是很难评了。我在一些别的比较 solid 的文章里,比如说 <a href="https://arxiv.org/pdf/2509.02046," title="Fantastic Pretraining Optimizers and Where to Find Them">《Fantastic Pretraining Optimizers and Where to Find Them》</a> 也见过一些 NAdam 的结果,好像还行。我自己试过在 muon 上加上 Nesterov(用的是 <a href="https://github.com/microsoft/dion/tree/main," title="Dion: Distributed Orthonormal Updates">Dion</a> 的代码),可以说是毫无作用。</p><p>一个简单的修正方法是,再加一个超参 $k$ 上去,令 $M_t \approx k \beta_1 (\eta H(\theta_t)+\beta_1 I)^{-1} G(\theta_t)$,这样起码解放了大小的限制。回推上去,等价于 $\hat{M}_{t} = \beta_1 M_{t} + (1-\beta_1 + k\beta_1) G(\theta_t)$。</p><p>这里仍然有很多问题:</p><ul><li><p>$k$ 是否是一个可以设置的超参,在整个训练过程中近似不变?</p></li><li><p>$k$ 作为一个 scala 是否合理?还是说它必须是一个矩阵,哪怕是一个对角阵?</p></li></ul><p>如果把 $k$ 当作一个scala的话,那其实等价于调节 $\beta_1$,也就是 $\frac{\beta_1}{1-\beta_1} = \frac{\beta_1’}{1-\beta_1’ + k \beta_1’}$罢了。可以通过观测 $k$,提出一个 Nesterov 解耦的 $\beta_{11} , \beta_{12}$,其实怎么说都行,但感觉非常无聊。</p><p>不太明白 Ilya 当时怎么想的。也许读者有什么更深入的见解,欢迎来信和我分享,不吝赐教。</p>]]></content>
<summary type="html">何意味</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Optimization" scheme="https://anti-entrophic.github.io/tags/Optimization/"/>
</entry>
<entry>
<title>Functional Analysis 1 - Dual Space Solution of Optimal Control</title>
<link href="https://anti-entrophic.github.io/posts/10051.html"/>
<id>https://anti-entrophic.github.io/posts/10051.html</id>
<published>2025-11-13T05:54:01.000Z</published>
<updated>2025-11-13T14:14:01.917Z</updated>
<content type="html"><![CDATA[<h1 id="序言"><a href="#序言" class="headerlink" title="序言"></a>序言</h1><p>记录一下我在 SJTU 上的一堂好课 《Applied Functional Analysis》,前半学期是由 <a href="https://tobiasdiez.de/" title="Home Page">Diez Tobias</a> 教授讲授的 ,在很多地方给了我很大的启发。虽然我也没受过什么科班的数学训练,但是这门课很大程度上打开了我的视野,给了我另一个从泛函角度分析数学问题的视角(<code>nn.module</code> 当然可以视作一种泛函)。</p><p>本篇 blog 记录了本学期 Tobias 教授最后一堂课中的一个 example question,很有意思。</p><h1 id="Setting"><a href="#Setting" class="headerlink" title="Setting"></a>Setting</h1><p>考虑这样一个火箭升空问题:我们需要设定一个最小消耗的燃料方案,来保证火箭能够到达指定高度 $h$。为了建模该问题,我们设火箭的升空力为 $F(t)$,所处高度为 $y(t)$,一个简化的模型是:</p><script type="math/tex; mode=display">\left\{\begin{aligned}my(t)'' &= F(t) - mg \\y(0) &= 0 = y'(0) \\y(T) &= h\end{aligned}\right.</script><p>并且假设燃料消耗量为 $\int_0^T |F(t)|\text{d}t$. 这里当然做了很多简化,比如说应该是 $m(t)$ 而不是 $m$,不过不重要,只是考虑这样一种setting。</p><p>则对牛二积分可得:</p><script type="math/tex; mode=display">my(t)' - my(0)' = \int_0^t F(\tau)\text{d}\tau - mgt = my(t)'</script><p>对上式,我们再进行一次积分。注意到 </p><script type="math/tex; mode=display">\frac{d}{dt} \int_0^t (t-\tau) F(\tau)dt = (t-\tau)F(\tau) \big|_{\tau=t} + \int_0^t \frac{\partial}{\partial t} [(t-\tau)F(\tau)]d\tau = 0 + \int_0^tF(\tau)d\tau</script><p>或者分步积分也行。总之积分结果为</p><script type="math/tex; mode=display">my(t) = \int_0^t (t-\tau) F(\tau)d\tau - \frac{mg}{2}t^2</script><p>代入 $t=T$ 时刻的约束,得</p><script type="math/tex; mode=display">mh = \int_0^T (T-\tau) F(\tau)d\tau - \frac{mg}{2}T^2</script><p>以上就是我们的约束条件,整个问题可以转化为这样一个问题:</p><p>对于一个固定的时间 $T$,找到一个预算方案 $F \in \mathbb{L}^1([0, T])$,使得</p><script type="math/tex; mode=display">\begin{aligned}&\text{minimize} &\, &\int_0^T|F(\tau)|\text{d}\tau = \|F\|_1 \\&\text{subject to } &\, &mh = \int_0^T (T-\tau) F(\tau)d\tau - \frac{mg}{2}T^2\end{aligned}</script><p>若我们从泛函的角度去理解该问题,则可以考虑一个 Hilbert Space $\mathcal{H}$,令 $w = \{T-\tau\} \in \mathcal{H}$, $u=\{F(\tau)\} \in \mathcal{H}$ and $c = mh + \frac{mg}{2}T^2 \neq 0 \in \mathbb{R}$,将问题转化为 </p><script type="math/tex; mode=display">\begin{aligned}&\text{minimize} &\, &\|u\| \\&\text{subject to } &\, &\langle w| u \rangle = c\end{aligned}</script><p>以上,就是整个问题的定义,接下来的任务就是考虑怎么求解该问题。</p><div class="note success flat"><p>读者可能不熟悉泛函分析,所谓 Hilbert Space 可以理解为,带有内积诱导的范数($L^2$范数)的 Banach Space,我们常用的几何视角可以安全地搬到 Hilbert Space 上考虑。</p><p>为了处理内积,我们选择将问题放在 $\mathcal{H}$ 中。但这会导致一个问题,我们的优化目标是 $\lVert F \rVert_1$ 而非 $\lVert F \rVert_2$,本质上是不一样的。此问题先暂时按下不表,后续我们会看到如何将其放到 Dual Space 上处理。 </p></div><h1 id="Solution"><a href="#Solution" class="headerlink" title="Solution"></a>Solution</h1><h2 id="寻找解空间"><a href="#寻找解空间" class="headerlink" title="寻找解空间"></a>寻找解空间</h2><p>一个看似很简单的方法是使用 Cauchy-Schwarz 不等式,$\langle w| u \rangle = c \leq \lVert w \rVert \lVert u\rVert$,所以 $\lVert u \rVert \geq \frac{c}{\lVert w \rVert}$,直接秒了。但是我们后续会发现,它没法进一步将 $\lVert \cdot \rVert_2$ 解扩展到 $\lVert \cdot \rVert_1$了;并且 Cauchy-Schwarz 假设 $w$ 与 $u$ 是continous的,解空间被限制在 $\mathbb{L}^1$ 的子集 $\mathbb{C}([0, T])$ 上。</p><p>仍然回到泛函分析的视角,来看看怎么理解该问题。</p><p>考虑一个简单的 $(a, b, c)^T \cdot (x, y, z) = m$ 的问题,它对应的即是一个三维空间中的平面方程 $ax + by + cz = m$,法向量是 $(a, b, c)$</p><p>对于我们的 $\langle w| u \rangle = c$ 约束也是一样的,$u$ 的解空间也是一个超平面,法向量是 $w$,我们需要解决这样两个问题:</p><ul><li><p>这个超平面上有点吗?解是否存在?</p></li><li><p>我们如何描述这个超平面?</p></li></ul><p>我们在法向量 $w$ 对应的直线 $\text{span}(w)$ 上建立一个线性泛函 $\alpha \in \mathcal{H}^\mid$,使得 $\alpha(w) = c$. 应用 Hahn-Banach 定理,我们将其扩展到定义在整个 $\mathcal{H}$ 上的线性泛函 $\alpha_0$,它仍然满足 $\alpha_0(w) = c$</p><p>应用 Riesz 表示定理,将线性泛函与内积联系起来,即必然存在一个唯一的 $u_0 \in \mathcal{H}$,使得 $\forall v \in \mathcal{H}, \alpha_0(v) = \langle v | u_0 \rangle$。将其应用在向量 $w$ 上,我们可以得到 </p><script type="math/tex; mode=display">\alpha_0(w) = \langle w | u_0 \rangle = c</script><p>因此,$u_0$ 就是我们构造出来的一个,对应于约束 $\langle w | u \rangle = c$ 的可行解。</p><p>以此为锚点,我们可以简化问题:</p><script type="math/tex; mode=display">\left\{\begin{aligned}\langle w | u \rangle &= c \\\langle w | u_0 \rangle &= c\end{aligned}\right.</script><script type="math/tex; mode=display">\Rightarrow \langle w | u - u_0 \rangle = 0</script><p>所以,我们将原问题 “寻找 $u$,使得 $\langle w | u \rangle = c$” 转化为了 寻找 $u-u_0$,使得 $u-u_0\in w^{\perp}$”。我们只需要找到一个合法解,然后就可以表示整个解空间了。在这个解空间里,我们可以再考虑去找到 $\lVert \cdot \rVert_1 $ 最小的可行解。</p><h2 id="对偶空间嵌入"><a href="#对偶空间嵌入" class="headerlink" title="对偶空间嵌入"></a>对偶空间嵌入</h2><p>然而,上述推导是有问题的。首先,我们找到的解空间是 $\mathbb{L}^2$ 空间的一个子空间,很多 $\mathbb{L}^1$ 中的解早在我们将问题搬入 Hilbert Space 的那一刻就已经被扔掉了。</p><p>其次,我们之前推导全部建立在 Hilbert Space 的假设上,However,$\mathbb{L}^1([0, T])$ 恰恰不是一个 Hilbert Space,Riesz 表示定理等在此情况下都会失效,所以推导过程其实也是不成立的。</p><p>我们考虑将所有 $F \in \mathbb{L}^1([0, T])$ 嵌入到对偶空间 $\mathbb{L}^1([0, T])^\mid = \mathbb{L}^\infty([0, T])$ 上去处理,定义 </p><script type="math/tex; mode=display">\mathcal{F}(g) = \int_0^T F(\tau)g(\tau)\text{d}\tau, \mathcal{F}(\cdot) \in (\mathbb{L}^\infty([0, T]))^\mid</script><p>我们最小化的目标可以从 $\lVert F \rVert_1$ 转化为 $\lVert \mathcal{F} \rVert^\mid = \sup \{|\mathcal{F}(g)|: g \in \mathbb{L}^\infty([0, T]), \lVert g \rVert_\infty = 1 \}$</p><p>该问题的性质如下:</p><ul><li><p>该嵌入是单射(injective)的:因为如果 $\int_0^TF(\tau)g(\tau)\text{d}\tau = 0$ 对所有 $g$ 成立,则 $F$ 必须是零函数。否则我们很容易针对任意一点 $F(t_0) \neq 0$ 构造一个 bump function 来使其积分不为 0。</p></li><li><p>该嵌入是保范数(perserving the norm)的,可以证明 $\lVert \mathcal{F} \rVert^\mid = \lVert F \rVert_1$(通过分别证明 $\lVert \mathcal{F} \rVert^\mid \leq \lVert F \rVert_1$ 与 $\lVert \mathcal{F} \rVert^\mid \geq \lVert F \rVert_1$, 这是最关键的证明,后续补一下,笔记有点不清楚……)</p></li><li><p>该嵌入不是满射(not surjective)的,对偶空间要大的多</p></li></ul><p>以上性质保证了,最小化 $\lVert F \rVert_1$ 的原目标,我们可以通过最小化泛函 $\mathcal{F}$ 的范数 $\lVert \mathcal{F} \rVert^\mid$ 来实现。并且,它的解空间实际上是大于实际的解空间 $\mathbb{L}^1$ 的(也就是说,我们还有可能找到 $\mathbb{L}^1$ 之外的解)</p><p>而我们的约束条件 $\int_0^T (T-\tau) F(\tau)d\tau = c$ 可以看作是 “由 $F$ 诱导的线性泛函 $\mathcal{F}$, 作用在函数 $w$ 上” 的结果,要求 $\mathcal{F}(w) = c$</p><p>我们的目标是:</p><script type="math/tex; mode=display">\begin{aligned}&\text{minimize} &\, & \|\mathcal{F}\|^\mid \\&\text{subject to } &\, &\mathcal{F}(w) = c\end{aligned}</script><p>看起来很简单是不是?实际上就是很简单,根据 Hahn-Banach 定理的推论,我们能立刻知道所求的 $\lVert \mathcal{F}_{min}\rVert = \frac{|c|}{\lVert w \rVert_\infty} = \frac{|c|}{T}$,这就是我们所求的最小燃料消耗量。</p><p>但是方案呢?根据 Holder 不等式:</p><script type="math/tex; mode=display">|\mathcal{F}(w)| = |\int_0^T F(\tau)w(\tau)\text{d}\tau| \leq \|F\|_1 \|w\|_\infty</script><p>实际上我们考虑的就是取等的情况,$|\int_0^T F(\tau)w(\tau)\text{d}\tau| = \lVert F \rVert_1 \lVert w \rVert_\infty$,这要求我们将函数 $F(\tau)$ 的全部能量集中在 $w(\tau)$ 最大的那一点上。而 $w(\tau)$ 最大的一点即是 $t=0$ 时 $w(0) = T$</p><p>也就是说,推力 $F$ 需要集中在 $t=0$ 时施加,它是一个 $\delta$ 函数</p><script type="math/tex; mode=display">F(t) = A \cdot \delta(t)</script><p>代入约束条件,即可求得 $A = \frac{c}{T}$,也即 $F(t) = (\frac{mh}{T} + \frac{mgT}{2}) \delta(t)$</p><p>该方案的燃料消耗量 $\lVert F \rVert_1 = |\frac{c}{T}|$,和我们应用 Hahn-Banach 推论找到的最小值吻合。</p><p>严格来说,$\delta$ 函数并不是一个 $L^1$ 函数(这里不展开),这也符合我们的对偶空间嵌入策略,找到了一个 $L^1$ 之外的更好的广义解。这个解也有很好的物理解释:我们一开始就加速到最快,之后都用最快的速度来飞行。这是在这个极度简化的物理模型下的特殊情况(没有阻力、不考虑燃料的重力等等)。</p><h1 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h1><p>以上,是一个物理与泛函分析结合的例子。回到我的学习领域深度学习,我们也常常面临这种“预算分配”的问题:学习率的scheduler,batch size 的 scheduler 引出的 token 量的分配,等等。</p><p>深度学习最大的问题是,我们没有牛顿定律这样的金科玉律,根本没有办法预言我们当前的分配会产生什么样的后果。只能后验地根据 Hessian、RMS Norm 等观测值,来推测目前的 “飞行轨迹” 进行到了哪一步。硬要建模的话,我了解的也就只有 朗之万动力学 那一种;但很多实践已经证明了,朗之万动力学是有缺陷的。“有缺陷”指的是,它与 LLM 的归纳偏置不符,例如我的一个很大感触就是,噪音肯定不能建模为高斯噪声,至少也是某种重尾分布。</p><p>But who knows,这些问题牵涉的变量太多了,并且实验需要大量资源,太穷了。学习过程看起来只能建模为随机微分方程,会存在深度学习的牛顿定律或者相对论吗?只能等待着某位超级智者,未来有一天能解决深度学习领域的终极问题吧。</p>]]></content>
<summary type="html">A Tribute to an Inspiring Class</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Optimization" scheme="https://anti-entrophic.github.io/tags/Optimization/"/>
<category term="Functional Analysis" scheme="https://anti-entrophic.github.io/tags/Functional-Analysis/"/>
</entry>
<entry>
<title>Hessian 谱的 "Bulk + Spikes" 结构</title>
<link href="https://anti-entrophic.github.io/posts/10050.html"/>
<id>https://anti-entrophic.github.io/posts/10050.html</id>
<published>2025-09-12T05:19:18.000Z</published>
<updated>2026-04-09T02:49:46.620Z</updated>
<content type="html"><![CDATA[<h1 id="Hessian-谱的-“Bulk-Spikes”-结构"><a href="#Hessian-谱的-“Bulk-Spikes”-结构" class="headerlink" title="Hessian 谱的 “Bulk + Spikes” 结构"></a>Hessian 谱的 “Bulk + Spikes” 结构</h1><p>Hessian 我们都知道,一个二阶导:</p><script type="math/tex; mode=display">H_{ij} = \frac{\partial^2 L(\theta)}{\partial \theta_i \partial \theta_j}</script><p>Hessian 的特征值描述了特征向量方向的曲率。正特征值对应山谷(局部极小),负特征值对于山峰,值的绝对值大小则反映了曲率。</p><p>根据 <a href="https://scispace.com/pdf/eigenvalues-of-the-hessian-in-deep-learning-singularity-and-mdd7gvpcx6.pdf" title="Eigen Values of the Hessian in Deep Learning: Singularity and Beyond">Yann Lecun 的说法</a>,hessian非常奇异(singular),特征值大量集中在零附近,并且存在少量独立的较大的特征值。其中,将大特征值称为 Spike,将大量的零特征值称为 Bulk</p><p>并且,他还通过实验给出了一些对这一现象的直观理解:</p><ul><li>增大模型参数,特征值似乎越靠近零点。这种集中可能反映了模型的过参数化冗余。</li></ul><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAALpDmjDskCUeGfPlYy3VgM9mNF0JrDOAAKcrjEbbUEhRjuSYcLpuhwWAQADAgADeAADNgQ.png" width="400px" /><p style="font-size: 10px;">模型越大,Hessian越奇异</p></center><ul><li>数据分布越复杂,大特征值会更极端。通常认为大特征值可能与数据本身主要结构和信息密切相关。</li></ul><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAALpD2jDsxw2x_hrDG3eCmOOCKIN0BOMAAKerjEbbUEhRiNJHdWFxW9fAQADAgADdwADNgQ.png" width="800px" /><p style="font-size: 10px;">数据越难,Hessian的大特征值绝对数值越大(原论文的图也太糊了)</p></center><p>不过,这篇文章写的相当随意,就像我的博客一样,实验规模也很小。但确实是很适合我这种懒人的省流版,更详细的工作不妨参考 <a href="https://www.stat.berkeley.edu/~mmahoney/pubs/hessian-more-realistic-models-neurips21.pdf" title="NIPS21">Hessian Eigenspectra of More Realistic Nonlinear Models</a> 这篇。</p><h2 id="Wigner半圆定律"><a href="#Wigner半圆定律" class="headerlink" title="Wigner半圆定律"></a>Wigner半圆定律</h2><p>Wigner半圆定律描述了一类特定的随机矩阵特征值谱的渐进分布,具体地:</p><p>$W$ 是一个 $N \times N$ 的实对称随机矩阵(Wigner矩阵),其满足以下条件:</p><ul><li><p>对角线及上三角部分的元素 $W_{ij}(i\leq j)$ 是独立同分布 (i.i.d.) 的随机变量</p></li><li><p>这些随机变量的均值为零,即 $\mathcal{E}[W_{ij}]=0$</p></li><li><p>方差为 $\sigma^2$,即 $\mathcal{E}[W_{ij}^2]=\sigma^2$</p></li></ul><p>结论是,当矩阵维度 $N \rightarrow \infty$ 时,矩阵 $\frac{1}{\sqrt{N}}W$ 的特征值的经验谱分布依概率收敛于 Wigner 半圆分布。其概率密度函数为:</p><script type="math/tex; mode=display">\rho_{sc}(x) = \left\{\begin{aligned}&\frac{1}{2\pi \sigma^2} \sqrt{4\sigma^2-x^2}, &\quad \text{if} \, |x| \leq 2\sigma \\&0 , &\quad \text{if} \, |x| > 2\sigma\end{aligned}\right.</script><p>这个分布的支撑集在 $[-2\sigma, 2\sigma]$ 区间内,其形状顾名思义是一个半圆形。</p><p>不过,这一理论和实际相去甚远,说明 Hessian 矩阵的结构远比一个简单的 Wigner 矩阵复杂,还依赖于数据、网络架构和损失函数等等。</p><h2 id="Marchenko-Pastur-定律"><a href="#Marchenko-Pastur-定律" class="headerlink" title="Marchenko-Pastur 定律"></a>Marchenko-Pastur 定律</h2><p>描述的是协方差矩阵的特征值分布</p><p>令 $X$ 是一个 $M \times N$ 的随机矩阵,其元素 $X_{ij}$ 是独立同分布的随机变量,均值为0,方差为 $\sigma^2$。考虑由它构成的样本协方差矩阵 $S = \frac{1}{M}X^TX$</p><p>在 $M, N \rightarrow \infty$ 且比率 $\gamma = \frac{N}{M}$ 收敛到一个常数的极限下,矩阵 $S$ 的特征值经验谱分布依概率收敛于 Marchenko-Pastur 分布。其概率密度函数为:</p><script type="math/tex; mode=display">\rho_{mp}(x) = \left\{\begin{aligned}& \frac{1}{2\pi \sigma^2 \gamma x} \sqrt{(\lambda_+-x)(x-\lambda_-)} &\quad \text{if} \, x \in [\lambda_-, \lambda_+] \\&0 &\quad \text{otherwise}\end{aligned} \right.</script><p>分布的支撑集为 $\lambda_{±} = \sigma^2 (1 ± \sqrt{\gamma})^2$</p><h3 id="与-Hessian-的联系"><a href="#与-Hessian-的联系" class="headerlink" title="与 Hessian 的联系"></a>与 Hessian 的联系</h3><p>这个和 Hessian 的联系在于,我们考虑损失函数为 </p><script type="math/tex; mode=display">L(\theta) = \sum_{i=1}^n \mathcal{l}(z_i, y_i)</script><p>我们暂且将模型当成一个黑箱参数,计算 Hessian</p><script type="math/tex; mode=display">\begin{aligned}H_{mj} &= \frac{\partial^2 \mathcal{l}}{\partial \theta_m \partial \theta_j} = \frac{\partial}{\partial \theta_m} (\frac{\partial \mathcal{l}}{\partial \theta_j}) = \frac{\partial}{\partial \theta_m} (\sum_{k=1}^c \frac{\partial \mathcal{l}}{\partial z_k} \frac{\partial z_k}{\partial \theta_j}) \\&= \sum_{k=1}^c [\frac{\partial}{\partial \theta_m} (\frac{\partial \mathcal{l}}{\partial z_k}) \cdot \frac{\partial z_k}{\partial \theta_j} + \frac{\partial \mathcal{l}}{\partial z_k} \cdot \frac{\partial}{\partial \theta_m}(\frac{\partial z_k}{\partial \theta_j})]\end{aligned}</script><p>这里包含两个部分,我们分开来看。</p><p>第二部分不需要再额外求解:</p><script type="math/tex; mode=display">\text{第二部分} = \sum_{k=1}^c \frac{\partial \mathcal{l}}{\partial z_k} \cdot \frac{\partial^2 z_k}{\partial \theta_m \partial \theta_j}</script><ul><li><p>$\frac{\partial \mathcal{l}}{\partial z_k}$ 就是损失对第 $k$ 个 logit 的梯度</p></li><li><p>$\frac{\partial^2 z_k}{\partial \theta_m \partial \theta_j}$ 是第 $k$ 个 logit $z_k$ 关于参数 $\theta$ 的 Hessian 矩阵的 $(m, j)$ 元素。我们可以把它记为 $(\nabla_{\theta}^2 z_k)_{mj}$</p></li></ul><p>第一部分需要再使用链式法则:</p><script type="math/tex; mode=display">\text{第一部分} = \sum_{k=1}^c [\frac{\partial}{\partial \theta_m} (\frac{\partial \mathcal{l}}{\partial z_k})] \cdot \frac{\partial z_k}{\partial \theta_j}</script><p>这里 $\frac{\partial \mathcal{l}}{\partial z_k}$ 对 $\theta_m$ 求导时,需要注意到 $z$ 是 $\theta$ 的函数,所以应用链式法则:</p><script type="math/tex; mode=display">\frac{\partial}{\partial \theta_m} (\frac{\partial \mathcal{l}}{\partial z_k}) = \sum_{s=1}^c \frac{\partial (\frac{\partial \mathcal{l}}{\partial z_k})}{\partial z_s} \frac{\partial z_s}{\partial \theta_m} = \sum_{s=1}^c \frac{\partial^2 \mathcal{l} }{\partial z_s \partial z_k} \frac{\partial z_s}{\partial \theta_m}</script><p>看起来有点复杂,但是:</p><ul><li><p>$\frac{\partial z_s}{\partial \theta_m}$ 是雅可比矩阵 $J$ 的元素 $J_{sm}$</p></li><li><p>$\frac{\partial^2 \mathcal{l}}{\partial z_s \partial z_k}$ 是损失关于 logitis 的 Hessian 矩阵</p></li><li><p>$\frac{\partial z_k}{\partial \theta_j}$ 是雅可比矩阵 $J$ 的元素 $J_{kj}$</p></li></ul><p>所以第一部分其实可以写成</p><script type="math/tex; mode=display">\text{第一部分} = \sum_{k=1}^c \sum_{s=1}^c \frac{\partial z_s}{\partial \theta_m} \cdot \frac{\partial^2 \mathcal{l}}{\partial z_s \partial z_k} \cdot \frac{\partial z_k}{\partial \theta_j} = \sum_{k=1}^c \sum_{s=1}^c (J^T)_{ms} (H_{\mathcal{l}})_{sk} J_{kj} = J^TH_{\mathcal{l}}J</script><p>综上,Hessian可以写成</p><script type="math/tex; mode=display">\begin{aligned}H(\theta) &= \nabla^2L(\theta) = \sum_{i=1}^n \nabla^2 \mathcal{l}(z_i, y_i) \\&= \sum_{i=1}^n [J_i^T H_{\mathcal{l}, i}J_i + \sum_k (\frac{\partial \mathcal{l}}{\partial z_{ik}})(\nabla^2 z_{ik})]\end{aligned}</script><p>其中</p><ul><li><p>$J_i=\nabla_{\theta}z_i$ 是第 $i$ 个样本的 logits 关于参数 $\theta$ 的雅可比矩阵</p></li><li><p>$H_{\mathcal{l}, i} = \nabla_z^2 \nabla(z_i, y_i)$ 是损失函数 $\mathcal{l}$ 关于 logits $z_i$ 的 Hessian 矩阵。</p></li><li><p>$\frac{\partial \mathcal{l}}{\partial z_{ik}}$ 是损失对第 $k$ 个 logit 的偏导数。</p></li><li><p>$\nabla^2 z_{ik}$ 是第 $k$ 个 logit 关于参数 $\theta$ 的 Hessian</p></li></ul><h3 id="Gauss-Newton-Estimation"><a href="#Gauss-Newton-Estimation" class="headerlink" title="Gauss-Newton Estimation"></a>Gauss-Newton Estimation</h3><p>(p.s. 谁懂这两个名字一起出现的救赎感)</p><p>接下来要说明,上式中的第二部分可以近似抛弃。</p><p>首先直观地理解,对于交叉熵损失函数而言 $\frac{\partial \mathcal{l}}{\partial z_{ik}} = p_{ik} - y_{ik}$,在模型训练后期,这一项可能很小,因此可以抛弃。</p><p>然而,还有一个更神奇的东西,叫 Fisher 信息矩阵(FIM),一个定义是真实 Hessian $H$ 在模型自身预测的概率分布下的期望值。也就是说,我们假设真实标签 $y_i$ 是从模型自己的输出分布 $p(y|x_i ; \theta)$ 中采样得到的。</p><script type="math/tex; mode=display">F = \mathbb{E}_{y_i \sim p(y | x_i;\theta)}[H(\theta)]</script><p>我们来计算第二部分在这个期望下的值。我们知道 $\nabla^2 z_{ik}$ 并不依赖于标签 $y_i$,因此</p><script type="math/tex; mode=display">\mathbb{E}[\sum_{i,k}(p_{ik}-y_{ik})(\nabla^2 z_{ik})] = \sum_{i,k} \mathbb{E}[p_{ik}-y_{ik}](\nabla^2 z_{ik})</script><p>最关键的一步是,因为我们假设 $y_i$ 是从模型的概率分布 $p_i$ 中采的,所以 $y_{ik}$ 取值为1的概率是 $p_{ik}$,因此 $\mathbb{E}[p_{ik} - y_{ik}] = 0$</p><p>我们将第一部分称为 广义高斯-牛顿项 (GGN),它的期望就等于 FIM。也就是说,近似操作本身不是无脑省略,而是得到了一个非常有意义的 FIM(我甚至觉得,它比 Hessian 更好)</p><h3 id="广义高斯牛顿项与协方差矩阵"><a href="#广义高斯牛顿项与协方差矩阵" class="headerlink" title="广义高斯牛顿项与协方差矩阵"></a>广义高斯牛顿项与协方差矩阵</h3><p>扯远了,回到 Marchenko-Pastur 定律,接下来证明广义高斯牛顿项 $J^TH_{\mathcal{l}}J$ 就可以看成是一种协方差矩阵</p><p>对于交叉熵损失函数:</p><ul><li><p>$\frac{\partial \mathcal{l}}{\partial z_k} = p_k - y_k$</p></li><li><p>$(H_{\mathcal{l}})_{kj} = \frac{\partial^2 \mathcal{l}}{\partial z_k \partial z_j} = \frac{\partial p_k}{\partial z_j} = \frac{\partial \text{softmax}(z)_k}{\partial z_j} = p_k(\delta_{kj}-p_j)$</p></li></ul><p>因此,我们有:</p><script type="math/tex; mode=display">H_{\mathcal{l}} = \text{diag}(p) - pp^T</script><p>也就是说,$H_{\mathcal{l}}$ 可以视作一个均值为 $p$ 的 one-hot 向量 $y$ 的协方差矩阵,因为</p><script type="math/tex; mode=display">\mathbb{E}[yy^T] - \mathbb{E}[y]\mathbb{E}[y]^T = \text{diag}(p) - pp^T</script><p>作为一个协方差矩阵,$H_{\mathcal{l}}$ 是半正定的。又因为它本身是对称的,因此必然存在一个矩阵 $H_{\mathcal{l}}^{\frac{1}{2}}$,使得</p><script type="math/tex; mode=display">H_{\mathcal{l}} = (H_{\mathcal{l}}^{\frac{1}{2}})^TH_{\mathcal{l}}^{\frac{1}{2}}</script><p>所以,对于我们的GGN项,可以把 $H$ 分解掉</p><script type="math/tex; mode=display">G = J^T H_{\mathcal{l}} J = (H_{\mathcal{l}}^{\frac{1}{2}}J)^T H_{\mathcal{l}}^{\frac{1}{2}}J</script><p>因此,我们唯一需要做的近似就是,$\tilde{J} = H_{\mathcal{l}}^{\frac{1}{2}}J$ 是一个随机变量,由它构成的协方差矩阵 $G$ 的特征谱密度自然应该遵循 Marchenko-Pastur 定律。</p><h3 id="总结"><a href="#总结" class="headerlink" title="总结"></a>总结</h3><p>总的来说,Marchenko-Pastur 的解释能力要强得多,这从推导就能看得出来。</p><p>参数 $\gamma = \frac{N}{M}$ 通常远小于1,$N$ 是批次大小(因为 $\tilde{J}^TH_{\mathcal{l}}\tilde{J}$ 是 $\sum_{i=1}^n [J_i^T H_{\mathcal{l}, i}J_i]$ 堆叠起来的,不要忘了),$M$ 是参数维度。这也能解释 Yann Lecun 的实验结果:参数维度越大,$\gamma$ 就越小,支撑集就越小,所以 bulk 就越紧凑</p><p>更进一步,如果真实的协方差矩阵是一个低秩矩阵(信号)加上一个随机矩阵(噪声),那么其样本协方差矩阵的特征值谱就会呈现一个 MP 分布的 Bulk 和脱离 Bulk 的刺,与实验观察到的结果一致。</p>]]></content>
<summary type="html">随机矩阵是对的</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
<category term="Optimization" scheme="https://anti-entrophic.github.io/tags/Optimization/"/>
</entry>
<entry>
<title>Scaling Law</title>
<link href="https://anti-entrophic.github.io/posts/10049.html"/>
<id>https://anti-entrophic.github.io/posts/10049.html</id>
<published>2025-08-21T11:07:32.000Z</published>
<updated>2025-09-12T06:12:48.423Z</updated>
<content type="html"><![CDATA[<p>偷学了一手 scaling law,遂记录一下</p><h1 id="Critical-Batch-Size"><a href="#Critical-Batch-Size" class="headerlink" title="Critical Batch Size"></a>Critical Batch Size</h1><p>原文链接:<a href="https://arxiv.org/pdf/1812.06162" title="OpenAI">An Empirical Model of Large-Batch Training</a></p><p>在我们使用小 batch size 去计算梯度的时候,实际上可以视为对整体梯度的无偏估计,但存在一个噪声。batch size 越大,则噪声方差越小。然而,这种增长可能是有限的 —— 当 batch size 已经足够大时,增加 batch size 的收益可能较小。因此,有必要确定一个综合考虑下最合适的 batch size</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAALVpmin8CeCqQay0WOljnEfPssB5SlpAAIksDEb3alBRSXqUrD27jwcAQADAgADeAADNgQ.png" width="500px" /><p style="font-size: 10px;">小 batch size(蓝线)会产生更多的噪声,大 batch size (橙线)则相对稳定</p></center><p>对于一个模型 $\theta$,损失函数可以视为 $L(\theta) = \mathbb{E}_{x\sim \rho(x)}[L_x(\theta)]$,梯度估计为</p><script type="math/tex; mode=display">G_{est}(\theta) = \frac{1}{B} \sum_{i=1}^B \nabla_\theta L_{x_i} (\theta)</script><p>设为无偏估计,则 $\mathbb{E}_{x_1 \cdots B \sim p} [G_{est}(\theta)] = G(\theta)$</p><p>而梯度方差(更准确的来说,协方差)则为</p><script type="math/tex; mode=display">\begin{aligned}\text{Cov}_{x_1 \cdots B \sim p}(G_{est}(\theta)) &= \mathbb{E}[(G_{est}(\theta) - \mathbb{E}[G_{est}(\theta)]) (G_{est}(\theta) - \mathbb{E}[G_{est}(\theta)])^T] \\&= \mathbb{E} [(\frac{1}{B} \sum_{i=1}^B \nabla_\theta L_{x_i} (\theta) - G(\theta))(\frac{1}{B} \sum_{i=1}^B \nabla_\theta L_{x_i} (\theta) - G(\theta))^T] \\&= \frac{1}{B^2} \mathbb{E}[\sum_{i=1}^B \sum_{j=1}^B (\nabla_\theta L_{x_i} (\theta) - G(\theta))(\nabla_\theta L_{x_j} (\theta) - G(\theta))^T] \\\end{aligned}</script><p>当 $ i \neq j$ 时,由于视 $x_i$ 与 $x_j$ 为 i.i.d ,所以</p><script type="math/tex; mode=display">\begin{aligned}\mathbb{E}[(\nabla_\theta L_{x_i} (\theta) - G(\theta))(\nabla_\theta L_{x_j} (\theta) - G(\theta))^T] &= \mathbb{E}[\nabla_\theta L_{x_i} (\theta) - G(\theta)] \mathbb{E}[(\nabla_\theta L_{x_j} (\theta) - G(\theta))^T] \\&= (\mathbb{E}[\nabla_\theta L_{x_i} (\theta)] - G(\theta))(\mathbb{E}[\nabla_\theta L_{x_j} (\theta)] - G(\theta))^T \\&= 0\end{aligned}</script><p>因此当且仅当 $i = j$ 时有贡献</p><script type="math/tex; mode=display">\begin{aligned}\text{Cov}_{x_1 \cdots B \sim p}(G_{est}(\theta)) &= \frac{1}{B^2} \mathbb{E}[B \sum_{i=1}^B (\nabla_\theta L_{x_i} (\theta) - G(\theta))^2] \\&= \frac{1}{B} \Sigma(\theta)\end{aligned}</script><p>其中 $\Sigma (\theta)$ 为单条样本的梯度协方差。这告诉我们,梯度噪声方差是随 batch size 线性收敛的。</p><p>再考虑模型优化过程,讨论简化的SGD的情况。设 $G=\nabla L(\theta)$ 为真实梯度,$H$ 为 Hessian 矩阵,$\eta$ 为学习率,批次梯度为 $g = G_{est}$,做泰勒展开</p><script type="math/tex; mode=display">L(\theta -\eta g) \approx L(\theta) - \eta G^T g + \frac{1}{2} \eta^2 g^THg</script><p>两边取期望</p><script type="math/tex; mode=display">\begin{aligned}\mathbb{E}[L(\theta - \eta g)] &= \mathbb{E}[L(\theta) - \eta G^T g + \frac{1}{2} \eta^2 g^THg] \\&= \mathbb{E}[L(\theta)] - \eta G^T \mathbb{E}[g] + \frac{1}{2} \eta^2 \mathbb{E} [g^THg] \\&= L(\theta) - \eta G^T G + \frac{1}{2} \eta^2 \mathbb{E} [\text{tr}(H g g^T)] \\&= L(\theta) - \eta ||G||^2 + \frac{1}{2} \eta^2 (\text{tr}(H \mathbb{E} [ g g^T]))\end{aligned}</script><p>代入协方差 </p><script type="math/tex; mode=display">\begin{aligned}\text{Cov}(g) &= \mathbb{E}[g g^T] - \mathbb{E}[g]\mathbb{E}[g^T] \\\frac{1}{B} \Sigma &= \mathbb{E}[g g^T] - G G^T \\\mathbb{E}[g g^T] &= \frac{1}{B} \Sigma + G G^T\end{aligned}</script><p>所以 </p><script type="math/tex; mode=display">\begin{aligned}\mathbb{E}[L(\theta - \eta g)] &= L(\theta) - \eta ||G||^2 + \frac{1}{2} \eta^2 (\text{tr}(H (\frac{1}{B} \Sigma + G G^T))) \\ &= L(\theta) - \eta ||G||^2 + \frac{1}{2} \eta^2 (\frac{\text{tr}(H\Sigma)}{B} + \text{tr}(H G G^T)) \\ &= L(\theta) - \eta ||G||^2 + \frac{1}{2} \eta^2 (G^THG + \frac{\text{tr}(H\Sigma)}{B})\end{aligned}</script><p>对学习率求导</p><script type="math/tex; mode=display">\begin{aligned}\eta_{opt} &= \frac{||G||^2}{G^THG + \frac{\text{tr}(H\Sigma)}{B}} \\&= \frac{\frac{||G||^2}{G^THG}}{1 + \frac{1}{B}\frac{\text{tr}(H\Sigma)}{G^THG}}\end{aligned}</script><p>如果梯度没有噪声,则 $\Sigma=0$,最优学习率 $\eta_{max} = \frac{||G||^2}{G^THG}$</p><p>令 $B_{noise} = \frac{\text{tr}(H\Sigma)}{G^THG}$,则</p><script type="math/tex; mode=display">\eta_{opt} = \frac{\eta_{max}}{1 + \frac{B_{noise}}{B}}</script><p>此时有对 loss 的最优提升:</p><script type="math/tex; mode=display">\Delta L_{opt}(B) = \frac{\Delta L_{max}}{1 + \frac{B_{noise}}{B}}</script><p>可以看到,当 $B$ 远小于 $B_{noise}$ 时,提升 $B$ 会得到较大的线性提升;而当 $B$ 远大于 $B_{noise}$ 时,提升 $B$ 对训练的影响会变小。</p><p>如果我们用一步 $\text{d}S$ 能取得的最优 loss 下降值为 $\text{d}L$,则如果我们用 batch size 为 $B$,则需要走 $(1 + \frac{B_{noise}(s)}{B(s)})\text{d}S$ 步(注意这里应为当前 step 的函数),处理的样本数为 $\text{d}E = B(s) \text{d}S$,所以</p><script type="math/tex; mode=display">\begin{aligned}S &= \int (1 + \frac{B_{noise}(s)}{B(s)})\text{d}S \\E &= \int (B_{noise}(s) + B(s))\text{d}S\end{aligned}</script><p>现在,我们面临一个权衡:</p><ul><li><p>增加 $B$ 可以减少步数 $S$</p></li><li><p>但增加 $B$ 会增加每步的计算量 $E$</p></li></ul><p>我们引入一个 “汇率” $r$ 来量化这个权衡,意为:如果对 batch size 做小改动 $\text{d}B$,为了节省一步 $\text{d}S$ 需要付出多少计算量 $\text{d}E$</p><script type="math/tex; mode=display">r = - \frac{\frac{\partial}{\partial B} \text{d}E}{\frac{\partial}{\partial B} \text{d}S} = - \frac{\frac{\partial}{\partial B} (B_{noise}(s) + B(s))}{\frac{\partial}{\partial B} (1 + \frac{B_{noise}(s)}{B(s)})} = \frac{B^2(s)}{B_{noise}(s)}</script><p><strong>注意到</strong>,一个最优的策略,汇率 $r$ 应当是一个常数。反之,则必然存在一处汇率高于平均值,一处汇率低于平均值,存在将后者的预算分配给前者的这样一个套利空间。因此,存在理论最优的汇率 $r^*$,使得:</p><script type="math/tex; mode=display">r^* = \frac{B^2(s)}{B_{noise}(s)}</script><p>则我们能找到最优的 batch size,随着 $B_{noise}$ 指标动态变化</p><script type="math/tex; mode=display">B(s) = \sqrt{r^* B_{noise}(s)}</script><p>现在,我们将这一式子代回 $S$ 与 $E$ 的表达式,同时约定:</p><script type="math/tex; mode=display">\begin{aligned}S_{min} &= \int 1 \text{d}S \\E_{min} &= \int B_{noise} \text{d}S \\I_{sqrt} &= \int \sqrt{B_{noise(s)}} \text{d} S\end{aligned}</script><p>则代入后有:</p><script type="math/tex; mode=display">\begin{aligned}S &= S_{min} + \frac{1}{\sqrt{r^*}} I_{sqrt} \\E &= E_{min} + \sqrt{r^*} I_{sqrt}\end{aligned}</script><p>所以:</p><script type="math/tex; mode=display">\frac{S - S_{min}}{S_{min}} \cdot \frac{E-E_{min}}{E_{min}} = \frac{I_{sqrt}^2}{S_{min}E_{min}}</script><p>令右式为 $\gamma$,我们能得到一个很漂亮的幂律关系:</p><script type="math/tex; mode=display">(\frac{S}{S_{min}} - 1) = \gamma (\frac{E}{E_{min}}-1)^{-1}</script><h2 id="Tokens-Limited"><a href="#Tokens-Limited" class="headerlink" title="Tokens Limited"></a>Tokens Limited</h2><p>已知单步最优损失下降为 </p><script type="math/tex; mode=display">\Delta L_{opt}(B) = \frac{\Delta L_{max}}{1 + \frac{B_{noise}}{B}}</script><p>实际上,我们关注的并不是单步的 $\Delta L_{opt}$,而是总训练过程的 $\int_0^{S_{final}} \Delta L_{opt}(B(S))dS$</p><p>显然,这里有一个约束关系,我们考虑总tokens数恒定。</p><script type="math/tex; mode=display">\int_0^{S_{final}} B(S)dS = E_{budget}</script><p>应用拉格朗日乘子法</p><script type="math/tex; mode=display">\begin{aligned}\mathcal{L}[B(S), \lambda] &= \int_0^{S_{final}} \Delta L_{opt}(B(S))dS - \lambda (\int_0^{S_{final}} \Delta B(S)dS - E_{budget}) \\&= \int_0^{S_{final}} [\frac{\Delta L_{max}}{1 + \frac{B_{noise}}{B}} - \lambda B(S)]dS + \lambda E_{budget}\end{aligned}</script><p>取极值必要条件是 拉格朗日量 $F = \frac{\Delta L_{max}}{1 + \frac{B_{noise}}{B}} - \lambda B(S)$ 满足 欧拉-拉格朗日 方程</p><script type="math/tex; mode=display">\frac{\partial F}{\partial B} = 0</script><p>其中</p><script type="math/tex; mode=display">\begin{aligned}\frac{\partial}{\partial B} \left( \Delta L_{max} \left(1 + \frac{B_{noise}}{B}\right)^{-1} \right) &= \Delta L_{max} \cdot \left[ -1 \left(1 + \frac{B_{noise}}{B}\right)^{-2} \cdot \left(-\frac{B_{noise}}{B^2}\right) \right] \\&= \frac{\Delta L_{max} \cdot B_{noise}}{B^2 \left(1 + \frac{B_{noise}}{B}\right)^2} \\&= \frac{\Delta L_{max} \cdot B_{noise}}{\left(B \left(1 + \frac{B_{noise}}{B}\right)\right)^2} \\&= \frac{\Delta L_{max} \cdot B_{noise}}{(B + B_{noise})^2}\end{aligned}</script><p>所以 </p><script type="math/tex; mode=display">\frac{\Delta L_{max} \cdot B_{noise}}{(B + B_{noise})^2} = \lambda</script><p>最终的结果是 </p><script type="math/tex; mode=display">\begin{aligned}B_{opt}(S) &= \sqrt{\frac{\Delta L_{max}(S) \cdot B_{noise}(S)}{\lambda}} - B_{noise}(S) \\B_{noise} &= \frac{\text{tr}(H\Sigma)}{G^THG} \\\Delta L_{max}(S) &= \frac{1}{2} \frac{||G||^4}{G^THG}\end{aligned}</script><script type="math/tex; mode=display">\int_0^{S_{final}} \left( \sqrt{\frac{\Delta L_{max}(S) \cdot B_{noise}(S)}{\lambda}} - B_{noise}(S) \right) dS = E_{budget}\lambda = \left( \frac{\int_0^{S_{final}} \sqrt{\Delta L_{max}(S) \cdot B_{noise}(S)} dS}{E_{budget} + \int_0^{S_{final}} B_{noise}(S) dS} \right)^2</script><p>这里的 $\Delta L_{max}$ 可以视为随训练变化的超参。根据经验,我们知道训练后期的梯度是很小的,而cosine learning rate的话,H 会升高,所以 $\Delta L_{max}$ 会很小。这意味着在训练后期,收益已经很小了,与其增大batch size 减少噪声,不如小batch size 多探索。至少我是这么理解的</p><p>H 比较难算,但如果只近似算对角线元素计算复杂度还是和 forward-backward 一样的,可以参考AdaHessian的近似。</p><h1 id="OpenAI-Scaling-Law"><a href="#OpenAI-Scaling-Law" class="headerlink" title="OpenAI Scaling Law"></a>OpenAI Scaling Law</h1><p>原文链接:<a href="https://arxiv.org/pdf/2001.08361" title="OpenAI">Scaling Laws for Neural Language Models</a></p><p>在上述推导中,我们得到了一个稍稍有点复杂的幂律关系,然而 OpenAI 却提出上述公式可以简化为:</p><script type="math/tex; mode=display">(\frac{S}{S_{min}} - 1) (\frac{E}{E_{min}}-1) = 1</script><p>也即</p><script type="math/tex; mode=display">\frac{I_{sqrt}^2}{S_{min}E_{min}} = \frac{(\int \sqrt{B_{noise(s)}} \text{d} S)^2}{(\int 1 \text{d}S)(\int B_{noise} \text{d}S)} = 1</script><p>拟合出来的结果非常准</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAALVqWin8-oXA6SysAgAARTMYq9O5ZuI_QACJ7AxG92pQUXoQzCwP-6YXwEAAwIAA3cAAzYE.png" width="500px" /><p style="font-size: 10px;">搜索并拟合不同参数量不同 Loss 下 Emin 和 Smin 的关系</p></center><p>从数学角度理解,这其实是 Cauchy–Schwarz 不等式的一个特例。其积分形式为:对于定义在积分域 $S$ 上的任意两个实函数 $f(s)$ 和 $g(s)$,以下不等式恒成立:</p><script type="math/tex; mode=display">(\int_S f(s)g(s)\text{d}S)^2 \leq (\int_S f(s)^2 \text{d}S) (\int_S g(s)^2 \text{d}S)</script><p>等号成立当且仅当存在常数 $c$ 使得 $f(s) = c g(s)$</p><p>这意味着,OpenAI 提出的近似要成立,必须满足 $B_{noise}(s) = c$,也即 $\frac{\text{tr}(H\Sigma)}{G^THG}=c$,显然,这是一个较为苛刻的条件。直观上来看,stable 学习率设定下的训练后期可能可以满足这个条件</p><p>同时,我们令 $B_{critic} = \frac{E_{min}}{S_{min}}$,可以视作 $B_{noise}(s)$ 的面积 除以 区间长度,即平均 $B_{noise}$。这是一个 Loss 曲线的指标,代表着该任务的平均噪声强度,越大则所需的batch size越大。</p><p>OpenAI 对该指标进行了拟合,得到结论</p><script type="math/tex; mode=display">B_{critic}(L) \approx \frac{B_*}{L^{\frac{1}{\alpha_B}}}</script><p>其中 $B_* \sim 2\cdot 10^8 $ tokens,$\alpha_B \sim 0.21$ 是常数</p><p>有了平均值的刻画,我们就不需要积分了,直接代入平均值,可以近似定义:</p><script type="math/tex; mode=display">\begin{aligned}S_{min} = \frac{S}{1 + \frac{B_{critic}(L)}{B}}\end{aligned}</script><p>$S_{min}$ 代表着当 $B \rightarrow \infty$ 时,理论最优的 Steps 数,$S$ 可以看作实际 $B$ 对应的</p><p>通过这个公式计算得到的 $S_{min}$,OpenAI 拿去拟合了下述的 Power Law 公式</p><script type="math/tex; mode=display">L(N, S_{min}) = (\frac{N_c}{N})^{\alpha_N} + (\frac{S_c}{S_{min}})^{\alpha_S}</script><p>效果很不错,如图所示</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAALVrmin_mC3wBcwgiX0EKwbHxOBCHiIAAIvsDEb3alBRXuZicfM7KzUAQADAgADeAADNgQ.png" width="500px" /><p style="font-size: 10px;">搜索并拟合不同参数量下 Loss 和 Smin 的关系</p></center><p>题外话</p><p>上述 $B_{critic} = \frac{E_{min}}{S_{min}}$ 的定义是没问题的,但是 OpenAI 拿来拟合 $B_{critic}$ 的值是用公式</p><script type="math/tex; mode=display">(\frac{S}{S_{min}} - 1) (\frac{E}{E_{min}}-1) = 1</script><p>得到的,但这个公式是不准的,导致得到的 $B_{critic}(L) \approx \frac{B_*}{L^{\frac{1}{\alpha_B}}}$ 也不准,虽然实践中效果确实不差。在我看来,这些拟合的公式都是经验性的,量纲都对不齐,所以没事不要想着代来代去,除非很确定其中每一步的约束与近似。</p><p>还有,大部分推导都是在 SGD 上做的,这就已经和 Adam 有很大的区别了,不必奉为圭臬。当然,有个参考还是很好的。</p><h1 id="Multi-Power-Law"><a href="#Multi-Power-Law" class="headerlink" title="Multi-Power Law"></a>Multi-Power Law</h1><p>写不动了,这篇我也只是听了一下讲解粗看了一下,随便记一点(((</p><p><br></p><p>上述公式刻画了 Loss 与 N,S,B,C,但没有刻画与 learning rate 的关系。下面一篇就填补了这个空缺,并且实践效果比较好。</p><p>假设 $t$ 步时,使用 stable 策略的等效步数为 $Z(t) = \frac{S(t)}{\eta_0}, 其中 S(t) = \sum_{\tau=0}^t \eta_{\tau}$</p><p>作者将 loss 拆解为了两个部分:</p><script type="math/tex; mode=display">L(t) = L_{const}(Z(t)) - (L_{const}(Z(t)) - L(t))</script><p>其中后一项差值定义为 $LD(t) := L_{const}(Z(t)) - L(t)$,而前一项则根据过去经验定义为 $L_{const}(Z(t)) = L_0 + A \cdot (S(t) + S_W)^{-\alpha}$,$S_W$ 为 warmup 阶段的学习率之和</p><p>关于第二项,可以认为是学习率衰减造成的偏差,作者经过简单的拟合,发现线性拟合就还不错</p><script type="math/tex; mode=display">LD(t) \approx C(\eta_0 - \eta_t)</script><p>不过,这个还不准,所以考虑了另一种分析框架。我们定义一些辅助轨迹</p><ul><li><p>$L_0(t)$,用恒定学习率 $\eta_0$ 一直训练下去</p></li><li><p>$L_1(t)$,第1步用 $\eta_0$,之后用恒定学习率 $\eta_1$ 一直训练下去</p></li><li><p>$L_2(t)$,用 $[\eta_0, \eta_1]$,随后用恒定学习率 $\eta_2$ 一直训练下去</p></li></ul><p>我们定义 $S_k(t) = \sum_{\tau=k}^t \eta_{\tau}$,则 $L_k$ 的等效学习步数 $t_k$ 满足</p><script type="math/tex; mode=display">t_k = k-1 + \frac{S_k(t)}{\eta_k}</script><p>对于 $k$ 和 $k+1$ 过程,中间 Loss Reduction 为:</p><script type="math/tex; mode=display">LD_k(t_{k+1}) = L_k(t_k) - L_{k+1}(t_{k+1})</script><p>则通过差分得到最终的 $LD$ 项为</p><script type="math/tex; mode=display">LD(t) = \sum_{k=0}^{t-1} LD_k (t_{k+1})</script><p>作者做了实验发现 $LD$ 会先上升然后到达一个界限,所以猜了一个幂律关系去拟合</p><script type="math/tex; mode=display">LD(T_A + x) = \tilde{B}(1-(\tilde{C}\eta_B x + 1)^{-\beta})</script><p>进一步分析上述的 $\tilde{B}$,$\tilde{C}$,有 $\tilde{B} = B(\eta_A - \eta_B)$,$\tilde{C} = C \eta_B^{-\gamma}$,所以</p><script type="math/tex; mode=display">LD(T_A + x) = B(\eta_A - \eta_B)(1-(C \eta_B^{1-\gamma} x + 1)^{-\beta})</script><p>最终拟合得到的一个有关 learning rate 的 scaling law</p><script type="math/tex; mode=display">L(t) \approx L_0 + A (S_1(t) + S_W)^{-\alpha} - \sum_{k=1}^t B (\eta_{k-1} -\eta_k) (1 - (C\eta_k^{-\gamma} S_k(t) + 1)^{-\beta})</script><p>总的来说,MPL的拟合实际试下来比较准的,但它只能跨learning rate,所以应用范围比较有限。</p><p>写不动了,MPL 有机会再补充吧</p>]]></content>
<summary type="html">scale up !!</summary>
<category term="LLM" scheme="https://anti-entrophic.github.io/categories/LLM/"/>
<category term="LLM" scheme="https://anti-entrophic.github.io/tags/LLM/"/>
<category term="Scaling Law" scheme="https://anti-entrophic.github.io/tags/Scaling-Law/"/>
<category term="Pretrain" scheme="https://anti-entrophic.github.io/tags/Pretrain/"/>
</entry>
<entry>
<title>记一次电脑 0xc0000005 宕机的排查</title>
<link href="https://anti-entrophic.github.io/posts/10048.html"/>
<id>https://anti-entrophic.github.io/posts/10048.html</id>
<published>2025-07-29T05:50:33.000Z</published>
<updated>2025-07-29T07:49:11.912Z</updated>
<content type="html"><![CDATA[<h1 id="Introductory-Chapter"><a href="#Introductory-Chapter" class="headerlink" title="Introductory Chapter"></a>Introductory Chapter</h1><p>望周知,鄙人最近斥巨资配置了一台海景房,摆在宿舍里享受单身生活。😋</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK062iIW9skUVLkCKasA2kG5v339FM_AAKkrzEbmXFJRH_WyIuKP2pcAQADAgADdwADNgQ.jpg" width="200px" /><p style="font-size: 10px;">爽</p></center><p>然而,天有不测风云,从某天开始电脑突然变得神神鬼鬼,关键表现就是不定期会出现和 explorer 的交互卡住,包括呼出任务管理器等功能。</p><p>我一开始也没去管,虽然心里有各种各样的猜测,比如最近有安装了虚拟机 docker 用 cpu 炼小丹,是不是一训几个小时给锻炼坏了;又或者是技嘉的某某驱动又发生兼容性问题了;再或者是公司强行安装的入网安全小助手发生了内存泄漏云云。总的来说,还不至于影响日常使用。</p><p>直到某一天,电脑在抛出了一个错误后彻底宕机了。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK07miIXD5j9pzz4OkGhXZzRy98cwbYAAKnrzEbmXFJRB1fWIk2ojNxAQADAgADeQADNgQ.png" width="400px" /><p style="font-size: 10px;">急</p></center><p>这下好了,不得不直面这一问题了。好在强制重启后电脑仍能正常运行,让我有机会排查发生了什么。</p><h1 id="硬件排查"><a href="#硬件排查" class="headerlink" title="硬件排查"></a>硬件排查</h1><p>首先联网搜索一下这个 <code>0xc0000005</code> 是什么问题,可以搜到是内存相关的报错,并且有一系列无头案例,无法确定问题。</p><p>随后,我决定先从硬件查起,最有可能导致错误的是<strong>SSD</strong>或<strong>内存条</strong>出现了故障。不过老实说,我对自己买的东西品控还是比较放心的,最后也是理所当然的通过排查。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK08WiIXQ4oLVVWoJm0ol2bvt4wLz1XAAKrrzEbmXFJRFFzlSHVU4p0AQADAgADeQADNgQ.png" width="500px" /><p style="font-size: 10px;">赢</p></center><p>内存排查则是通过 windows 自带的内存诊断工具 <code>mdsched.exe</code>。教程提到需要选择<strong>重新启动并检查问题</strong>,重启后在 <strong>事件查看器</strong> 中查看详细报告。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK082iIXXDLF84SDgpnjJpxTUAIWvTSAAJGrjEbJU9ARHif5DcpuT9oAQADAgADeAADNgQ.png" width="400px" /><p style="font-size: 10px;">典</p></center><p>此时,我的电脑已经是菠萝菠萝哒了,有点担心重启后发生更怪的事情,所以缓了一下。不过,事件查看器倒是提醒我了如何排查问题的源头。</p><h1 id="系统日志"><a href="#系统日志" class="headerlink" title="系统日志"></a>系统日志</h1><p>呼出windows事件查看器,在 <code>windows日志 - 系统</code> 里看一下宕机前的记录,能发现在宕机的30分钟前开始频繁地报错 <code>等待 GraphicsPerfSvc 服务的连接超时(30000 毫秒)。</code>,并在我最后一次 <code>Alt + Tab</code> 后切换到桌面的时候彻底宕机,报错 <code>弹出应用程序: 任务切换: explorer.exe - 系统错误: Exception Processing Message 0xc0000005 - Unexpected parameters</code></p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK0_GiIYB8ZutYoVBmGNuJQucDcPzGBAAK2rzEbmXFJRD9G8WSPTfQmAQADAgADdwADNgQ.png" width="600px" /><p style="font-size: 10px;">麻</p></center><p>显然,是图形相关的功能报错了,并产生了一系列连锁反应,首当其冲的嫌疑人就是显卡驱动。不过,我还留心了一下之前电脑卡住的表现,与 GCC(技嘉主板控制中心) 和 小红车 交互时都有卡顿的发生。考虑到这个问题是最近发生的,可能是它们中的某一个带来了问题。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK09WiIXiFXdQ_CR0uVPEOPjzAhDJOOAAKvrzEbmXFJREV1mt-eYOW8AQADAgADeQADNgQ.png" width="400px" /><p style="font-size: 10px;">孝</p></center><p>我依次卸载了 GCC,并把 小红车、GameViewer 这些程序的开机自动启动关闭。在注意到小红车开机自启动旁边的<strong>设置管理员权限</strong>的选项时,突然灵光乍现,我之前在设置 Snipaste 的开机自启动选项时给了它管理员权限!</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK092iIXl6E8Z39Z6BZSqfgmFSupjbCAAKxrzEbmXFJRNtjvV4IlSjmAQADAgADeAADNgQ.png" width="400px" /><p style="font-size: 10px;">乐</p></center><p>进到安装目录正准备卸载,顿时绷不住了,原来我把 Snipaste 装成了 x86 版本。装回正确的 x64 版本后,至今无事发生。</p><h1 id="复盘"><a href="#复盘" class="headerlink" title="复盘"></a>复盘</h1><p>回头再看,这个报错实在是太合理了。我意外地错误安装了 Snipaste 的 32位 版本,而 Snipaste 提供的一些功能(例如全局监控、自动锁定不同窗口等)可能需求注册一些底层hook。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAK0-miIX7LYYAEzRSVZTavHu5WwyUMaAAK0rzEbmXFJRLjmGliNu-sHAQADAgADdwADNgQ.jpg" width="400px" /><p style="font-size: 10px;">蚌</p></center><p>特别是我给了 Snipaste 管理员权限,导致这些 32位 的钩子深入到了不该在的地方,引发了兼容性错误,进一步导致 GraphicsPerfSvc 服务宕机。时间一长,影响了太多核或线程导致卡顿、甚至宕机;或者发生了某次致命的寻址错误,直接一击毙命。</p><p>网络上有各种形形色色的 <code>0xc0000005</code> 的报错反馈,希望能给到其他人一点帮助 —— 如果不是驱动、硬件问题,就要特别留心那些给了管理员权限的软件。</p>]]></content>
<summary type="html">Exception Processing Message 0xc0000005 - Unexpected parameters</summary>
<category term="Windows" scheme="https://anti-entrophic.github.io/categories/Windows/"/>
<category term="Windows" scheme="https://anti-entrophic.github.io/tags/Windows/"/>
</entry>
<entry>
<title>Muon</title>
<link href="https://anti-entrophic.github.io/posts/10047.html"/>
<id>https://anti-entrophic.github.io/posts/10047.html</id>
<published>2025-07-21T08:04:41.000Z</published>
<updated>2025-11-25T13:25:53.306Z</updated>
<content type="html"><![CDATA[<h1 id="Muon-与-AdamW-的对比"><a href="#Muon-与-AdamW-的对比" class="headerlink" title="Muon 与 AdamW 的对比"></a>Muon 与 AdamW 的对比</h1><h2 id="AdamW"><a href="#AdamW" class="headerlink" title="AdamW"></a>AdamW</h2><p>AdamW 我们都很熟悉</p><script type="math/tex; mode=display">\begin{aligned}M_t &= \beta_1 M_{t-1} + (1 - \beta_1) G_t \\V_t &= \beta_2 V_{t-1} + (1 - \beta_2) G_t^2 \\\hat{M_t} &= \frac{M_t}{1 - \beta_1^t} \\ \hat{V_t} &= \frac{V_t}{1 - \beta_2^t} \\\theta_t &= \theta_{t-1} -\eta_t (\frac{\hat{M_t}}{\sqrt{\hat{V_t}} + \epsilon} + \lambda_t \theta_{t-1})\end{aligned}</script><p>若我们仅考虑第一步,则 $\theta_1 = \theta_0 -\eta_t (\frac{G_t}{|G_t| + \epsilon} + \lambda_t \theta_0) \approx \theta_0 -\eta_t [\text{sign}(G_t) + \lambda_t \theta_0]$,只不过后续随着动量累积,行为逐渐变得复杂。</p><p>AdamW 的自适应体现在,它用二阶矩估计来自适应地调整更新步长。它不关心梯度的方向,只是对于梯度较大的步骤,它会强制减小更新幅度;反之它会鼓励较小的梯度更新走的更长。虽然有点玄妙,但已经经过千锤百炼证明稳定性了。</p><h2 id="Lion"><a href="#Lion" class="headerlink" title="Lion"></a>Lion</h2><p>Lion 感觉像是 Alpha Evolve 的雏形。给定 AdamW 的过程,指定几种可能的操作(删除、修改、增加),然后遗传算法开搜。</p><p>当然这不是重点。重要的是 Lion 指出,不需要额外的二阶矩估计来自适应更新幅度,只需要sign归一化即可。</p><script type="math/tex; mode=display">\begin{aligned}u_t &= \text{sign} (\beta_1 M_{t-1} + (1-\beta_1) G_t) \\\theta_t &= \theta_{t-1} - \eta_t (u_t + \lambda_t \theta_{t-1}) \\M_t &= \beta_2 M_{t-1} + (1-\beta_2) G_t \\\end{aligned}</script><p>原论文搜出来的结果是 $\beta_1 = 0.9, \beta_2 = 0.99$,在NLP任务上是 $\beta_1 = 0.95, \beta_2 = 0.98$</p><p>经过sign归一化后,$u$ 每个分量的更新绝对值都是1。为了和AdamW计算的尺度对齐(通常比AdamW的更新值大10倍),因此学习率要缩小10倍以上;为了保持权重衰减的幅度不变,权重衰减就要放大相应的倍数。</p><p>AdamW 上,预训练1B模型时的学习率可以取到 $5 \times 10^{-3}$ 左右,在 Lion 上衰减一下就到 $10^{-4}$ 级别;而 weight decay 也需要相应放大10倍。</p><h3 id="Tiger"><a href="#Tiger" class="headerlink" title="Tiger"></a>Tiger</h3><p>苏神同时做了 $\beta_1 = \beta_2$ 的实验,虽然效果不如Lion,但是尚在可接受范围内。同时发现可以通过简单改动,节省下梯度累积时需要存储历史梯度的开销。感觉比较trivial。</p><h3 id="理解-sign"><a href="#理解-sign" class="headerlink" title="理解 sign"></a>理解 sign</h3><blockquote><p>Lion通过sign 操作平等地对待了每一个分量,使得模型充分地发挥了每一个分量的作用,从而有更好的泛化性能。如果是SGD,那么更新的大小正比于它的梯度,然而有些分量梯度小,可能仅仅是因为它没初始化好,而并非它不重要,所以 Lion 的sign 操作算是为每个参数都提供了“恢复活力”甚至“再创辉煌”的机会。</p></blockquote><p>这里苏神指出,虽然在训练开始阶段考虑泛化比较合理,但如果一个参数的梯度长期较小,那似乎确实可能说明这个参数作用不大。可能可以用别的策略来更多地鼓励重要参数的更新,适当增加其更新幅度。</p><h2 id="Muon"><a href="#Muon" class="headerlink" title="Muon"></a>Muon</h2><p>我读下来,感觉 Muon 最大的特点是,它的归一化不再是element wise的。这样,Lion中遇到的“躺平”参数问题就可以自然而然地解决了,天然地会把更新量分给别的重要参数。</p><p>对于矩阵参数 $W \in \mathbb{R}^{n \times m}$,更新公式如下</p><script type="math/tex; mode=display">\begin{aligned}M_t &= \beta M_{t-1} +(1-\beta) G_t \\W_t &= W_{t-1} - \eta_t [\text{msign}(M_t) + \lambda W_{t-1}]\end{aligned}</script><p>其中,我们设动量矩阵的SVD分解为 $\text{SVD}(M) = U, \Sigma, V^T$,则 $\text{msign}(M) = U_{[:, :r]}V_{[:, :r]}^T$</p><p>易知 $U \in \mathbb{R}^{n \times n}, \Sigma \in \mathbb{R}^{n \times m}, V \in \mathbb{R}^{m \times m}$,$r$ 是矩阵的秩</p><p>也就是说,我们是对动量矩阵 $M_t$ 做奇异值分解,用sign函数来归一化 $M_t$ 的奇异值,在矩阵层面做 sign 归一化。</p><h1 id="Muon的计算"><a href="#Muon的计算" class="headerlink" title="Muon的计算"></a>Muon的计算</h1><h2 id="Newton-schulz迭代"><a href="#Newton-schulz迭代" class="headerlink" title="Newton-schulz迭代"></a>Newton-schulz迭代</h2><p>SVD太复杂,我们需要推导其它的 Muon 的计算形式</p><script type="math/tex; mode=display">\begin{aligned}M &= U \Sigma V^T \\M^TM &= (U \Sigma V^T)^T (U \Sigma V^T) \\&= V \Sigma^T U^T U \Sigma V^T &\cdots \quad U^TU=I\\&= V \Sigma^T \Sigma V^T \\(M^TM)^{\frac{1}{2}} &= V |\Sigma| V^T \\M (M^TM)^{-\frac{1}{2}} &= U \Sigma V^T V |\Sigma|^{-1} V^T \\&= U \text{sign}(\Sigma) V^T \\&= \text{msign}(M)\end{aligned}</script><p>计算复杂度偏高的一步就是求矩阵指数 $(M^TM)^{-\frac{1}{2}}$,我们考虑在 $M^TM = I$ 处泰勒展开 $(M^TM)^{-\frac{1}{2}}$。考虑标量函数 $t^{-\frac{1}{2}}$</p><script type="math/tex; mode=display">t^{-\frac{1}{2}} = 1 - \frac{1}{2}(t-1) + \frac{3}{8}(t-1)^2 - \frac{5}{16}(t-1)^3 + \cdots</script><p>我们保留到二阶,结果是 $t^{-\frac{1}{2}} \approx \frac{15}{8} - \frac{5}{4}t + \frac{3}{8} t^2$,代入矩阵 $M^TM$</p><script type="math/tex; mode=display">\text{msign}(M) = M((M^TM)^{-\frac{1}{2}}) \approx \frac{15}{8} M - \frac{5}{4}M(M^TM) + \frac{3}{8}M(M^TM)^2</script><p>这里我们可以不断把计算得到的 $\text{msign}(M)$ 代入计算,得到更好的 $\text{msign}(M)$</p><div class="note success flat"><p>为什么可以迭代呢?因为 $\text{msign}(M)$ 是一个不动点!已知 $M^TM = I$,代入右式后结果就是 $M$</p><p>稍微深入想一下,对于任意的 $F(M) \rightarrow I$ 的问题,我们都可以通过在 $F(M) = I$ 处泰勒展开来得到一个 $n$ 阶的迭代式,只要起始点不太远就大概率能收敛。真无敌了。 </p></div><h2 id="网络搜索"><a href="#网络搜索" class="headerlink" title="网络搜索"></a>网络搜索</h2><p>我们不从泰勒展开出发,而是直接将 $M_{t+1} = aM_t + bM_t(M_t^TM_t) + cM_t(M_t^TM_t)^2$ 作为一个优化问题,去求解 $a,b,c$</p><p>令 $M_0 = \frac{M}{||M||_F}$,这样不改变 SVD 后得到的 $U,V$,同时可以让 $X_0$ 的所有奇异值在 $[0, 1]$ 之间,更加稳定。 </p><p>更新公式可以表示为:</p><script type="math/tex; mode=display">M_{t+1} = U (a \Sigma + b \Sigma^3 + c\Sigma^5) V^T</script><p>显然,中间是对角阵,我们其实只是在迭代这个对角阵。并且对角阵的幂只是各个对角线元素各自求幂,因此问题可以简化成单个奇异值的迭代。我们只需要输入 初始奇异值 $\sigma$,迭代次数 $T$,在迭代 $g(x) = ax + bx^3 + cx^5$ $T$ 次后拟合最终结果到1即可</p><h3 id="重参数化"><a href="#重参数化" class="headerlink" title="重参数化"></a>重参数化</h3><p>在 $a,b,c$ 的初始值选择上有一个小技巧,重参数化 $g(x) = ax + bx^3 + cx^5 = x + kx(x^2-x_1^2)(x^2-x_2^2)$ </p><p>这样的好处是,可以直观的表示出了迭代的5个不动点 $0, \pm x_1 \pm x_2$,选择 $x_1 < 1, x_2 > 1$,可以让迭代保持在 1 附近。用MSE作损失函数训练即可。</p><div class="note success flat"><p>用随机矩阵作为训练集,是否合理呢?</p><p>2025.11.25 update: 关于这一点,最近训了下NS,有了更深的理解,看看后续有没有机会写篇文章share一下。欢迎催更。</p></div><h1 id="从范数的视角理解-Muon"><a href="#从范数的视角理解-Muon" class="headerlink" title="从范数的视角理解 Muon"></a>从范数的视角理解 Muon</h1><h2 id="理解-SGD"><a href="#理解-SGD" class="headerlink" title="理解 SGD"></a>理解 SGD</h2><p>接下来我们从邻近梯度下降(Proximal Gradient Descent)出发,从更广义的角度理解梯度下降法。为了简便我们以向量为例。</p><p>邻近梯度下降的公式是:</p><script type="math/tex; mode=display">w_{t+1} = \arg \min_w \frac{||w-w_t||^2}{2\eta_t} + \mathcal{L}(w)</script><p>可以直观地理解,一方面我们希望损失 $\mathcal{L}(w)$ 最低;另一方面,我们又不希望 $w_{t+1}$ 离 $w_t$ 太远,以免崩溃。$\eta$ 就是一个调节这个探索范围大小的参数。</p><p>如果 $\eta$ 足够小,可以认为 $\Delta w = w_{t+1} - w_t$ 是很小的,因此我们可以对 $\mathcal{L}(w)$ 做泰勒展开:</p><script type="math/tex; mode=display">w_{t+1} = \arg \min_w \frac{||w-w_t||^2}{2\eta_t} + \mathcal{L}(w_t) + \nabla_{w_t}\mathcal{L}(w_t)^T(w-w_t)</script><p>其中,$\mathcal{L}(w_t)$ 是常量,不影响结果,可以省去。令 $g_t = \nabla_{w_t}\mathcal{L}(w_t)$,我们有:</p><script type="math/tex; mode=display">\Delta w_{t+1} = \arg \min_{\Delta w} \frac{||\Delta w||^2}{2\eta_t} + g_t^T\Delta w</script><p>为了求出这个argmin,我们对 $\Delta w$ 求导。如果 $||\cdot||$ 是 L2 范数,则 $||w||^2 = w^Tw$,有</p><script type="math/tex; mode=display">\nabla f(\Delta w) = \frac{\Delta w}{\eta_t} + g_t = 0</script><p>解得最终的结果竟然就是梯度下降</p><script type="math/tex; mode=display">\begin{aligned}\Delta w &= - \eta_t g_t \\w_{t+1} - w_t &= -\eta_t \nabla_{w_t}\mathcal{L}(w_t)\end{aligned}</script><p>这一推导告诉我们,梯度下降就可以视为 <strong>学习率加权的L2-范数约束下</strong> 的 <strong>将损失函数近似为简单函数</strong> 的邻近梯度下降</p><h2 id="Muon-的范数本质"><a href="#Muon-的范数本质" class="headerlink" title="Muon 的范数本质"></a>Muon 的范数本质</h2><p>回到矩阵 $W \in \mathbb{R}^{m \times n}$,更新规则应为:</p><script type="math/tex; mode=display">\Delta W_{t+1} = \arg \min_{\Delta W} \frac{||\Delta W||^2}{2 \eta_t} + \langle G_t, \Delta W \rangle</script><p>我们选择 Frobenius 内积,$\langle G_t, \Delta W \rangle_F = \text{Tr}(G_t^T\Delta W)$</p><p>将 $\Delta W$ 解耦为范数和方向,设 $\gamma = ||\Delta W||$, $\Phi = - \frac{\Delta W}{||\Delta W||}$,我们得到:</p><script type="math/tex; mode=display">\min_{\Delta W} \frac{||\Delta W||^2}{2 \eta_t} + \text{Tr}(G_t^T\Delta W) = \min_{\gamma \geq 0} \frac{\gamma^2}{2 \eta_t} - \gamma (\max_{||\Phi||=1} \text{Tr}(G_t^T \Phi))</script><p>如果选择 F 范数,则和把矩阵展平后的 L2 范数一样,也就是我们在向量形式中推导的梯度下降</p><p>如果选择谱范数,则有 </p><script type="math/tex; mode=display">||\Phi||_2 = \max_{||x||_2=1}||\Phi x||_2</script><p>设 $G$ 的SVD分解为 $ U \Sigma V^T = \sum_{i=1}^r \sigma_i u_i v_i^T$,我们有</p><script type="math/tex; mode=display">\text{Tr}(G^T\Phi) = \text{Tr}(\sum_{i=1}^r \sigma_i v_i u_i^T\Phi) = \sum_{i=1}^r \sigma_i u_i^T \Phi v_i</script><p>由于约束 $||\Phi||_2=1$,有 $||\Phi v_i||_2 \leq ||v_i||_2 = 1$,于是 $u_i^T \Phi v_i \leq 1$,所以</p><script type="math/tex; mode=display">\text{Tr}(G^T\Phi) \leq \sum_{i=1}^r \sigma_i</script><p>等号在所有 $u_i^T \Phi v_i$ 都等于1时取到,此时</p><script type="math/tex; mode=display">\Phi = \sum_{i=1}^r u_i v_i^T = U_{[:, :r]} V_{[:, :r]}^T = \text{msign}(G)</script><p>简单代回,我们知道:</p><script type="math/tex; mode=display">W_{t+1}=W_t - ||\Delta W|| \text{msign}(G)</script><p>注意这不是原来 求min 的解,只是由约束条件推出的,它能告诉我们的是,Muon就是谱范数约束下的最速下降方向。具体的步长由学习率调控。</p><p>而 谱范数 的约束强于 F范数,恒成立 $||\Phi||_2 \leq ||\Phi||_F$,这可能是暗示 Muon 优越性的一个点。</p><h1 id="Adam参数迁移"><a href="#Adam参数迁移" class="headerlink" title="Adam参数迁移"></a>Adam参数迁移</h1><p>在搜寻 Muon 超参数时,苏神介绍了一种直接从已经调好的AdamW参数迁移的方法。(Kimi应该还是搜了一遍的)</p><p>具体而言,他们观察到 Adam 更新量的 RMS (Root Mean Square) 较为稳定,通常在 0.2 ~ 0.4 之间。RMS 定义为:</p><script type="math/tex; mode=display">\text{RMS}(W) = \frac{||W||_F}{\sqrt{nm}} = \sqrt{\frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^m W_{i,j}^2 }</script><p>因此,希望将Muon的更新量也对齐到这个范围,最终取为0.2</p><script type="math/tex; mode=display">W_t = W_{t-1} - \eta_t (\frac{0.2 M_t}{\text{RMS}(M_t)} + \lambda W_{t-1})</script><p>这样,可以复用 Adam 的参数,令 $\eta_t : = \frac{0.2\eta_t}{\text{RMS}(M_t)}$ 即可</p><p>更进一步,Muon的动量的RMS值是可以显式算出来的:</p><script type="math/tex; mode=display">\begin{aligned}\text{RMS}(M) &= \sqrt{\frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^m M_{i,j}^2 } \\&= \sqrt{\frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^m \sum_{k=1}^r U_{i,k}^2 V_{k,j}^2} \\&= \sqrt{\frac{1}{nm} \sum_{k=1}^r (\sum_{i=1}^n U_{i,k}^2) (\sum_{j=1}^m V_{k,j}^2) } \\&= \sqrt{\frac{1}{nm} \sum_{k=1}^r 1} \\&= \sqrt{\frac{r}{nm}}\end{aligned}</script><p>考虑到随机矩阵严格低秩的概率比较小,这里可以认为 $r = \min(n, m)$,从而有 $\text{RMS}(M) = \sqrt{\frac{1}{\max(n,m)}}$,最终的更新公式就是:</p><script type="math/tex; mode=display">W_t = W_{t-1} - \eta_t (0.2 M_t \sqrt{\max(n,m)} + \lambda W_{t-1})</script><div class="note success flat"><p>从稳定秩的角度考虑,$M = UV^T$ 天然就是满秩,所以取 $r = \min(n, m)$ 完全没问题</p></div><p>比较有意思的一个点就是,这个更新公式也指出了,不同形状的参数需要有不同的学习率。</p><h1 id="参考文献"><a href="#参考文献" class="headerlink" title="参考文献"></a>参考文献</h1><p><a href="https://spaces.ac.cn/archives/10592" title="Muon优化器赏析:从向量到矩阵的本质跨越 by 苏剑林">Muon优化器赏析:从向量到矩阵的本质跨越</a></p><p><a href="https://spaces.ac.cn/archives/10739" title="Muon续集:为什么我们选择尝试Muon? by 苏剑林">Muon续集:为什么我们选择尝试Muon?</a></p>]]></content>
<summary type="html">Muon优化器赏析</summary>
<category term="Optimizer" scheme="https://anti-entrophic.github.io/categories/Optimizer/"/>
<category term="Optimizer" scheme="https://anti-entrophic.github.io/tags/Optimizer/"/>
<category term="Linear Algebra" scheme="https://anti-entrophic.github.io/tags/Linear-Algebra/"/>
</entry>
<entry>
<title>GaLore</title>
<link href="https://anti-entrophic.github.io/posts/10046.html"/>
<id>https://anti-entrophic.github.io/posts/10046.html</id>
<published>2025-06-13T08:26:42.000Z</published>
<updated>2025-07-04T05:22:25.484Z</updated>
<content type="html"><![CDATA[<h1 id="The-problem-with-Lora"><a href="#The-problem-with-Lora" class="headerlink" title="The problem with Lora"></a>The problem with Lora</h1><p>Lora给出了一种非常 Parameter Efficient 的方法,通过更新两个额外的低秩矩阵来sft。</p><script type="math/tex; mode=display">W' = W + BAx</script><p>其中,$B \in \mathcal{R}^{m \times r}$,$A \in \mathcal{R}^{r \times n}$,$r \ll min(m, n)$</p><p>但是,参数矩阵并不满足低秩假设。根据线性代数的知识,我们知道 $rank(AB) \leq min(rank(A), rank(B))$,因此有:</p><script type="math/tex; mode=display">rank(BA) \leq r</script><p>倘若原矩阵 $rank(W) >r$,则 lora 无论如何也无法很好地近似 full parameter 的更新。</p><blockquote><p>对 $rank(AB) \leq min(rank(A), rank(B))$ 的一种证明:</p><p>首先,$rank(A) = dim(Col(A)) = dim(Row(A)) = number(\text{pivots})$</p><p>考虑 $AB$ 的每一列 $AB_{[:, j]}$,我们能发现 $AB_{[:, j]} = \sum_{i=0}^{m-1} B_{[i, j]}A_{[:, i]}$</p><p>即 $AB$ 的每一列都是 $A$ 的列线性组合</p><p>所以 $rank(AB) = dim(Col(AB)) \leq dim(Col(A)) = rank(A)$, 直接转置再应用一下就可以得到完整的结论了。</p></blockquote><p>However,Galore 这篇文章指出,虽然参数矩阵不见得是 low rank 的,不过可以证明,对于一类被称为 Reversible network 的网络结构,它们的梯度矩阵是低秩的。</p><h1 id="Reversible-network"><a href="#Reversible-network" class="headerlink" title="Reversible network"></a>Reversible network</h1><h2 id="Definition"><a href="#Definition" class="headerlink" title="Definition"></a>Definition</h2><p>一个网络 $\mathcal{N}$ 执行映射 $y = \mathcal{N}(x)$ 被称为 Reversible network 当:</p><ul><li>存在 $L(x, W)$ 使得 $y = L(x, W)x$,并且 $g_x = L^\top(x,W)g_y$</li></ul><p>其中 $g_x = \frac{\partial L}{\partial x}$,$g_y = \frac{\partial L}{\partial y}$</p><p>显然,这里的 $L^\top(x, W)$ 就应该是 Jacobian 矩阵,我们只需要算出雅可比矩阵,再代入前向传播中就可以验证它是不是一个reversible network了。</p><p>常见的不带偏置项bias的线性层,ReLU,leaky ReLU等激活函数,是Reversible network,而softmax,self-attention等不是。</p><p>定义这类网络是为了能够更加定量地分析各层梯度的数值表示,而不至于始终把梯度当作黑箱的形式。</p><h2 id="性质"><a href="#性质" class="headerlink" title="性质"></a>性质</h2><ul><li><p>线性性质:$\alpha_1 \mathcal{N}_1(x) + \alpha_2 \mathcal{N}_2(x)$ 仍然是 Reversible network</p></li><li><p>$\mathcal{N}_2(\mathcal{N}_1(x))$ 仍然是 Reversible network</p></li><li><p>$\mathcal{N}(x) = \frac{\partial \mathcal{N}(x)}{\partial x} x$</p></li></ul><p>第三个性质可以展开一点说,我们先说一下它的证明:</p><p>根据链式法则,我们知道 $g_x = (\frac{\partial y}{\partial x})^\top g_y$</p><p>与 $g_x = L^\top(x,W)g_y$ 对比后,我们知道 $(\frac{\partial y}{\partial x})^\top = L^\top(x,W)$,所以由 $y=L(x, W)x$,可得</p><script type="math/tex; mode=display">y = \frac{\partial y}{\partial x}x,\text{即} \mathcal{N}(x) = \frac{\partial \mathcal{N}(x)}{\partial x} x</script><p>这在数学上也是欧拉齐次函数定理的一种形式。欧拉齐次函数定理指出,一个函数 $f(x)$ 是 $m$ 阶齐次函数(即满足 $f(tx) = t^mf(x)$)的充要条件是 $\nabla f(x) \cdot x = mf(x) $。显然,Reversible Network 即是满足 $m=1$ 的一阶齐次函数。</p><p>既然 $\mathcal{N}(x)$ 是一阶齐次函数,那么 $\mathcal{N}(x)$ 满足 $\mathcal{N}(tx) = tN(x), \forall t > 0$,对其两边求导,得到:</p><script type="math/tex; mode=display">\begin{aligned}t \frac{\text{d}\mathcal{N}(tx)}{\text{d} x} &= t \frac{\text{d}\mathcal{N}(x)}{\text{d} x} \\K(tx) &= K(x)\end{aligned}</script><p>所以一阶齐次函数对应的雅可比矩阵 $K(x) = \frac{\partial \mathcal{N}(x)}{\partial x}$ 必然是一个零阶齐次函数。</p><p>零阶齐次性并不意味着 $K(x)$ 与 $x$ 无关,有可能存在导函数阶跃的情况。</p><p>比如说Relu是一个符合要求的 Reversible Network,它的雅可比矩阵是:</p><script type="math/tex; mode=display">H(x)=\left\{\begin{aligned}1, \quad & x>0 \\0, \quad & x<0 \\\end{aligned}\right.</script><p>显然与 $x$ 有关</p><h2 id="Reversible-Network-的梯度"><a href="#Reversible-Network-的梯度" class="headerlink" title="Reversible Network 的梯度"></a>Reversible Network 的梯度</h2><p>设 $K_i$ 是第 $i$ 层的雅可比矩阵,也即 $K_l(x) = \frac{\partial N_l(f_{l-1})}{\partial f_{l-1}}$,所以</p><script type="math/tex; mode=display">\partial \mathcal{N}(x) = \partial \mathcal{N_L}(\mathcal{N_{L-1}(\cdots \mathcal{N_1}(x))}) = K_L(x)K_{L-1}(x)\cdots K_1(x) \partial x</script><p>我们可以展示一下,当损失函数为MSE loss时 $\varphi := \frac{1}{2}||y-f_L||_2^2$</p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= (\mathcal{N}(x)-y)\top\text{d}\mathcal{N}(x) \\&= (\mathcal{N}(x)-y)^\top K_L(x)K_{L-1}(x)\cdots K_{l+1}(x) \text{d}f_l \\\end{aligned}</script><p>因为 $f_l = W_lf_{l-1}$,所以 $\text{d}f_l = \text{d}W_lf_{l-1} + W_l\text{d}f_{l-1}$</p><p>我们的目标是分析到 $W_l$ 的梯度,因此可以忽略后一部分的微分</p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= (\mathcal{N}(x)-y)^\top K_L(x)K_{L-1}(x)\cdots K_L(x) \text{d}W_lf_{l-1} + \text{与d}W_l\text{无关的部分} \\\end{aligned}</script><p>令 $J_l := K_L(x) \cdots K_{l+1}(x)$</p><p>由 $\text{d} \mathcal{N}(x) = K_L(x)K_{L-1}(x)\cdots K_{l+1}(x) \text{d}f_l$ 与 $\mathcal{N}(x) = \frac{\partial \mathcal{N}(x)}{\partial x} x$,可得</p><script type="math/tex; mode=display">\begin{aligned}\mathcal{N}(x) &= K_L(x)K_{L-1}(x)\cdots K_{l+1}(x) f_l \\&= J_l W_l f_{l-1}\end{aligned}</script><p>所以 </p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= (J_l W_l f_{l-1}-y)^\top J_l\text{d}W_lf_{l-1}\end{aligned}</script><p>因为 $\text{d}\varphi$ 是标量,因此我们可以用迹来表示结果。同时利用迹的循环不变性($\text{tr}(ABC) = \text{tr}(BCA) = \text{tr}(CAB)$),我们可以更方便的调整结果的形式。</p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= \text{tr}((J_l W_l f_{l-1}-y)^\top J_l\text{d}W_lf_{l-1}) \\&= \text{tr}(f_{l-1}(J_l W_l f_{l-1}-y)^\top J_l\text{d}W_l) \\\end{aligned}</script><p>而由矩阵梯度的定义,$(G_l)_{ij} = \frac{\partial \varphi}{\partial (W_l)_{ij}}$,有</p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= \sum_i \sum_j (G_l)_{ij} (\text{d}W_l)_{ij} \\\end{aligned}</script><p>这就是 Frobenius 内积,它有一个重要的性质,可以用迹来代替计算</p><script type="math/tex; mode=display">\begin{aligned}\text{d}\varphi &= \langle G_l, \text{d}W_l \rangle_F \\&= \text{tr}(G_l^\top \text{d}W_l)\end{aligned}</script><p>进行比较,我们可以得到:</p><script type="math/tex; mode=display">\begin{aligned}G_l^\top &= f_{l-1}(J_l W_l f_{l-1}-y)^\top J_l \\G_l &= J_l^\top (J_l W_l f_{l-1}-y) f_{l-1}^\top \\&= J_l^\top J_l W_l f_{l-1} f_{l-1}^\top - J_l^\top y f_{l-1}^\top\end{aligned}</script><p>考虑到梯度 $G_l$ 的方向是向上的,而常说的梯度下降的方向是相反的,因此我们记 $\hat{G_l} = -G_l = J_l^\top y f_{l-1}^\top - J_l^\top J_l W_l f_{l-1} f_{l-1}^\top$ 表示真实的梯度下降的方向。</p><p>其它损失函数也一样,变化一下最后一层的梯度就可以了。</p><div class="note success flat"><p>突然发现 MSE loss 和 CE loss 的梯度都是 $\mathcal{N}(x)-y$ 啊</p></div><p>作者也推导了softmax层的结果,得到的结果是 $\hat{G_l} = (J_lP_1^{\bot}y-\gamma K^{-1}J_l^\top P_1^{\bot}J_lW_lf_{l-1})f_{l-1}^\top$,其中 $P_{1}^\bot := I - \frac{1}{K} 1 1^\top$,$K$ 是softmax的维度</p><p>总而言之,言而总之,除了 Attention 层不满足 Reverse Network 的定义(Attention层有 $\mathcal{N}_1(x) \cdot \mathcal{N}_2(x)$ 这样的积性操作),若我们暂且不考虑 Attention 层,则剩下的所有层的梯度似乎都可以表示为统一形式:</p><script type="math/tex; mode=display">\hat{G_l} = A - BW_lC</script><h3 id="Positive-Semi-definite"><a href="#Positive-Semi-definite" class="headerlink" title="Positive Semi-definite"></a>Positive Semi-definite</h3><p>对于 $\hat{G_l} = J_l^\top y f_{l-1}^\top - J_l^\top J_l W_l f_{l-1} f_{l-1}^\top$ 的形式</p><script type="math/tex; mode=display">\left\{\begin{aligned}B &= J_l^\top J_l \\C &= f_{l-1} f_{l-1}^\top\end{aligned}\right.</script><p>而根据线性代数的知识,我们知道这其实就是 $B$ 和 $C$ 都是半正定(Positive Semi-definite, PSD)矩阵的充要条件,因为 $x^\top B^\top B x = (Bx)^\top Bx = ||Bx||^2 \geq 0$</p><p>而到了softmax层的结果,我们只需额外证明 $J_l^\top P_1^{\bot}J_l$ 是一个半正定矩阵。根据合同变换,我们知道如果一个矩阵 $A$ 是半正定矩阵,那对任意矩阵 $P$,$P^\top AP$ 也是半正定的。所以我们只要证明 $P_1^{\bot}$ 是一个PSD即可</p><p>我们先来求 $Y = \mathbf{1}\mathbf{1}^T$ 的特征值。因为这是一个全1矩阵,也就意味着秩为1,因此它必然有 $K-1$ 个数值为0的特征值。</p><p>剩下一个我们简单地根据求特征值定义 $Ax = \lambda x$,能求得最后一个特征值为 $K$ </p><p>我们要求 $P = I - \frac{1}{K}Y$ 的特征值,根据定义:</p><script type="math/tex; mode=display">Pv = (I-\frac{1}{K}Y)v = Iv - \frac{1}{K}(Yv) = v - \frac{1}{K}(\lambda_Y v) = (1-\frac{\lambda_Y}{K})v</script><p>所以 $P$ 和 $Y$ 的特征值存在一个简单的关系: $\lambda_P = (1-\frac{\lambda_Y}{K})$,代入可求得 $P$ 的特征值有1个0和 $K-1$ 个1</p><p>所有特征值非负,所以 $P$ 是一个半正定矩阵。</p><p>综上,我们讨论的Reversible Network中的梯度,表达式中的 $B$ 和 $C$ 总是半正定的</p><h1 id="Gradient-becomes-low-rank-during-training"><a href="#Gradient-becomes-low-rank-during-training" class="headerlink" title="Gradient becomes low-rank during training"></a>Gradient becomes low-rank during training</h1><h2 id="引理"><a href="#引理" class="headerlink" title="引理"></a>引理</h2><p>若我们假设矩阵梯度具有如下形式:</p><script type="math/tex; mode=display">G_t = \frac{1}{N} \sum_{i=1}^N(A_i - B_iW_tC_i)</script><p>其中 $N$ 表示batch size,$B$ 和 $C$ 都是半正定矩阵。考虑梯度更新公式为最简单的SGD梯度下降:$W_t = W_{t-1} + \eta G_{t-1}$</p><p>令 $S:=\frac{1}{N}\sum_{i=1}^N C_i^T \otimes B_i$,且 $\lambda_1 < \lambda_2$ 是它最小的两个不同的特征值。则稳定秩 $sr(G_t)$ 满足:</p><script type="math/tex; mode=display">sr(G_t) \leq sr(G_{t_0}^{||}) + (\frac{1-\eta \lambda_2}{1 - \eta \lambda_1})^{2(t-t_0)} \frac{||G_0 - G_{t_0}^|||_F^2}{||G_{t_0}^{||}||_2^2}</script><p>其中 $G_{t_0}^{||}$ 是 $G_{t_0}$ 投影到 $S$ 的最小特征值 $\lambda_1$ 对应的特征空间 $\mathcal{V}_1$</p><h2 id="稳定秩"><a href="#稳定秩" class="headerlink" title="稳定秩"></a>稳定秩</h2><p>稳定秩的意义在于,对于</p><script type="math/tex; mode=display">\left(\begin{array}{l}1 & 0 \\0 & 1e^{-5}\end{array}\right)</script><p>这样的矩阵,它的秩是2,但是实际上,第二个特征值非常小,这个特征几乎没有用,我们希望能有一种这种忽略极小值的计算矩阵秩的方法。</p><p>稳定秩的定义为F-范数与2-范数商的平方:</p><script type="math/tex; mode=display">sr(G) = \frac{||A||_F^2}{||A||_2^2}</script><p>F-范数是矩阵所有元素平方和的平方根,2-范数是矩阵的最大奇异值。可以简单理解为一种加权秩,比较的是其它奇异值与最大奇异值的差距。</p><h2 id="证明"><a href="#证明" class="headerlink" title="证明"></a>证明</h2><p>我们有</p><script type="math/tex; mode=display">G_t = \frac{1}{N} \sum_{i=1}^{N} (A_i -B_iW_t C_i) = \frac{1}{N} \sum_{i=1}^{N} A_i - B_i(W_{t-1} + \eta G_{t-1})C_i = G_{t-1} - \frac{\eta}{N}\sum_{i=1}^{N}B_iG_{t-1}C_i</script><p>令 $S := \frac{1}{N}\sum_{i=1}^{N}C_i^T \otimes B_i$,并且 $g_t := \text{vec}(G_t) \in \mathbb{R}^{mn}$</p><p>根据夹心公式 $\text{vec}(BWC) = (C^T \otimes B)\text{vec}(W)$,我们有</p><script type="math/tex; mode=display">\begin{aligned}\text{vec}(G_t) &= g_t = g_{t-1} - \frac{\eta}{N}\sum_{i=1}^NC_i^T \otimes B_i \text{vec}(G_{t-1}) \\&= g_{t-1} - \eta S g_{t-1} = (1 - \eta S)g_{t-1} \\g_t &= (1-\eta S)^t g_0\end{aligned}</script><p>我们考虑将 $g_0$ 分解为 $g_0 = g_0^{||} + g_0^\bot$ 上。$g_0^{||}$ 在 $S$ 的最小特征值 $\lambda_1$ 对应的 $\mathcal{K}$ 维(几何重数)特征子空间 $\mathcal{V}_1$ 上 ,$g_0^\bot$ 在 $\mathcal{V}_1$ 的正交补空间中。(这里不知道为什么,原文说正交补空间是个不变子空间)</p><script type="math/tex; mode=display">\begin{aligned}||G_t||_F^2 &= ||g_t||_2^2 = ||(I-\eta S)^t g_0||_2^2 = ||(I-\eta S)^t g_0^{||}||_2^2 + ||(I-\eta S)^t g_0^{\bot}||_2^2 \\ &= \end{aligned}</script><p>(未完待续)</p>]]></content>
<summary type="html">终于有空写点东西</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Optimizer" scheme="https://anti-entrophic.github.io/tags/Optimizer/"/>
<category term="Linear Algebra" scheme="https://anti-entrophic.github.io/tags/Linear-Algebra/"/>
</entry>
<entry>
<title>组合数学(一)容斥原理、二项式反演与第二类斯特林数</title>
<link href="https://anti-entrophic.github.io/posts/10045.html"/>
<id>https://anti-entrophic.github.io/posts/10045.html</id>
<published>2025-05-16T14:14:28.000Z</published>
<updated>2025-05-16T14:14:46.344Z</updated>
<content type="html"><![CDATA[<h1 id="二项式反演"><a href="#二项式反演" class="headerlink" title="二项式反演"></a>二项式反演</h1><p>记 $f_n$ 表示恰好使用 $n$ 个不同元素形成特定结构的方案数,$g_n$ 表示从 $n$ 个不同元素中选出 $i\geq 0$ 个元素形成特定结构的总方案数。</p><p>若已知 $f_n$ 求 $g_n$,那么显然有:</p><script type="math/tex; mode=display">g_n = \sum_{i=0}^{n} \binom{n}{i}f_i</script><p>若已知 $g_n$ 求 $f_n$,则被称为 <strong>二项式反演</strong>,公式为:</p><script type="math/tex; mode=display">f_n = \sum_{i=0}^{n} \binom{n}{i} (-1)^{n-i} g_i</script><h2 id="推导"><a href="#推导" class="headerlink" title="推导"></a>推导</h2><p>对右式代入 $g_n = \sum_{i=0}^{n} \binom{n}{i}f_i$,得到</p><script type="math/tex; mode=display">\begin{aligned}\sum_{i=0}^{n} \binom{n}{i} (-1)^{n-i} g_i &= \sum_{i=0}^{n} \binom{n}{i} (-1)^{n-i} [\sum_{j=0}^i \binom{i}{j}f_j] \\&= \sum_{i=0}^{n} \sum_{j=0}^i \binom{n}{i} \binom{i}{j} (-1)^{n-i} f_j\end{aligned}</script><p>交换 $i$ 和 $j$ 的枚举顺序,得到</p><script type="math/tex; mode=display">\begin{aligned}\sum_{i=0}^{n} \sum_{j=0}^i \binom{n}{i} \binom{i}{j} (-1)^{n-i} f_j &= \sum_{j=0}^{n}\sum_{i=j}^n \binom{n}{i} \binom{i}{j} (-1)^{n-i} f_j \\&= \sum_{j=0}^{n} f_j \sum_{i=j}^{n} \binom{n}{i} \binom{i}{j} (-1)^{n-i}\end{aligned}</script><p>由于 $\binom{n}{i} \binom{i}{j} = \binom{n}{j} \binom{n-j}{i-j}$。这个很好理解,相当于是 $n$ 里取 $i$ 个,再 $i$ 里取 $j$ 个。现在是直接取 $j$ 个,然后从剩下的 $n-j$ 个里取剩下的 $i-j$ 个。总之得到</p><script type="math/tex; mode=display">\begin{aligned}\sum_{j=0}^{n} f_j \sum_{i=j}^{n} \binom{n}{i} \binom{i}{j} (-1)^{n-i} &= \sum_{j=0}^{n} f_j \sum_{i=j}^{n} \binom{n}{j} \binom{n-j}{i-j} (-1)^{n-i} \\ &= \sum_{j=0}^{n} \binom{n}{j} f_j \sum_{i=j}^{n} \binom{n-j}{i-j} (-1)^{n-i}\end{aligned}</script><p>令 $k = i - j$,则 $i = k+ j$,上式转换为:</p><script type="math/tex; mode=display">\begin{aligned}\sum_{j=0}^{n} \binom{n}{j} f_j \sum_{i=j}^{n} \binom{n-j}{i-j} (-1)^{n-i} &= \sum_{j=0}^{n} \binom{n}{j} f_j \sum_{k=0}^{n-j} \binom{n-j}{k} (-1)^{n-j-k}1^k \\&= \sum_{j=0}^n \binom{n}{j} f_j (-1+1)^{n-j}\end{aligned}</script><p>当且仅当 $n=j$ 时不为 $0$</p><script type="math/tex; mode=display">\sum_{j=0}^n \binom{n}{j} f_j (-1+1)^{n-j} = f_j</script><p>证毕</p><h1 id="第二类斯特林数(Stirling-Number)"><a href="#第二类斯特林数(Stirling-Number)" class="headerlink" title="第二类斯特林数(Stirling Number)"></a>第二类斯特林数(Stirling Number)</h1><p>第二类斯特林数 $S(n,k)$ 表示将 $n$ 个两两不同的元素,划分为 $k$ 个互不区分的非空子集的方案数。</p><h2 id="递推式"><a href="#递推式" class="headerlink" title="递推式"></a>递推式</h2><script type="math/tex; mode=display">S(n, k) = S(n-1, k-1) + kS(n-1, k)</script><p>考虑用组合意义来证明</p><p>当我们插入一个新元素时,有两种方案:</p><ul><li><p>将新元素单独放入一个子集,有 $S(n-1, k-1)$ 种方案</p></li><li><p>将新元素放入一个现有的非空子集,有 $kS(n-1, k)$ 种方案</p></li></ul><p>相加即得</p><h2 id="通项公式"><a href="#通项公式" class="headerlink" title="通项公式"></a>通项公式</h2><script type="math/tex; mode=display">S(n, m) = \sum_{i=0}^m \frac{(-1)^{m-i}i^n}{i!(m-i)!}</script><p>使用容斥原理证明该公式。设将 $n$ 个两两不同的元素,划分到 $i$ 个两两不同的集合(允许空集)的方案数为 $G_i$, 将 $n$ 个两两不同的元素,划分到 $i$ 个两两不同的非空集合(不允许空集)的方案数为 $F_i$</p><p>根据定义,有</p><script type="math/tex; mode=display">\begin{aligned}G_i &= i^n \\G_i &= \sum_{j=0}^i \binom{i}{j} F_j\end{aligned}</script><p>根据二项式反演,有:</p><script type="math/tex; mode=display">\begin{aligned}F_i &= \sum_{j=0}^i (-1)^{i-j} \binom{i}{j} G_j \\ &= \sum_{j=0}^i (-1)^{i-j} \binom{i}{j} j^n \\&= \sum_{j=0}^i \frac{i!(-1)^{i-j}j^n}{j!(i-j)!}\end{aligned}</script><p>而第二类斯特林数要求的集合之间是互不区分的,因此 $F_i$ 是 $S(n, i)$ 的 $i!$ 倍,所以:</p><script type="math/tex; mode=display">S(n, m) = \frac{F_m}{m!} = \sum_{i=0}^m \frac{(-1)^{m-i}i^n}{i!(m-i)!}</script>]]></content>
<summary type="html">为什么我还在复习一年级的东西,玉玉了</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
</entry>
<entry>
<title>Gamma Function</title>
<link href="https://anti-entrophic.github.io/posts/10044.html"/>
<id>https://anti-entrophic.github.io/posts/10044.html</id>
<published>2025-05-16T07:49:24.000Z</published>
<updated>2025-05-16T14:07:58.609Z</updated>
<content type="html"><![CDATA[<h1 id="伽马函数"><a href="#伽马函数" class="headerlink" title="伽马函数"></a>伽马函数</h1><p>定义为:</p><script type="math/tex; mode=display">\Gamma(x) = \int_0^\infty t^{x-1}e^{-t}dt</script><p>对应正整数 $n$,满足 $\Gamma(n)=(n-1)!$</p><h1 id="递推关系式"><a href="#递推关系式" class="headerlink" title="递推关系式"></a>递推关系式</h1><p>伽马函数存在递推关系式</p><script type="math/tex; mode=display">\Gamma(x+1) = x \Gamma(x)</script><p>可以简单地使用分部积分法证明:</p><script type="math/tex; mode=display">\begin{aligned}\Gamma(x+1) &= \int_0^\infty t^{x}e^{-t}dt = \int_0^\infty t^{x}d(-e^{-t}) \\ &= [-t^xe^{-t}]_0^{\infty} - \int_0^\infty(-e^{-t})dt^x \\ &= 0 + x\int_0^\infty(e^{-t})t^{x-1}dt \\&= x\Gamma(x)\end{aligned}</script><p>而 $\Gamma(1) = 1$,所以 $\Gamma(n) = (n-1)!$</p><h1 id="计算"><a href="#计算" class="headerlink" title="计算"></a>计算</h1><p>可以直接调 <code>math.gamma()</code> 计算</p><p>由于阶乘增长速度太快,所以可以采用 <code>math.lgamma()</code>,计算对数值增强稳定性</p><h1 id="反射公式"><a href="#反射公式" class="headerlink" title="反射公式"></a>反射公式</h1><script type="math/tex; mode=display">\Gamma(x)\Gamma(1-x) = \frac{x}{\sin(\pi x)}</script><p>这个公式允许我们计算负数的伽马函数值。同时我们也能看到,伽马函数在 $0, -1, -2$ 等处附近是发散的。</p>]]></content>
<summary type="html">阶乘在实数域与复数域的扩展</summary>
<category term="Math" scheme="https://anti-entrophic.github.io/categories/Math/"/>
<category term="Math" scheme="https://anti-entrophic.github.io/tags/Math/"/>
</entry>
<entry>
<title>Part IV of Mathematical Structure of Mamba - Mamba&Mamba2</title>
<link href="https://anti-entrophic.github.io/posts/10043.html"/>
<id>https://anti-entrophic.github.io/posts/10043.html</id>
<published>2025-05-15T14:59:43.000Z</published>
<updated>2025-05-15T15:11:27.930Z</updated>
<content type="html"><![CDATA[<div class="note success flat"><p>本篇是mamba系列blog的第四篇文章,系列文章见:</p><ul><li><p><a href="https://anti-entrophic.github.io/posts/10038.html" title="Part I of Mathematical Structure of Mamba - Hippo">Part I of Mathematical Structure of Mamba - Hippo</a></p></li><li><p><a href="https://anti-entrophic.github.io/posts/10039.html" title="Part II of Mathematical Structure of Mamba - S4">Part II of Mathematical Structure of Mamba - S4</a></p></li><li><p><a href="https://anti-entrophic.github.io/posts/10040.html" title="Part III of Mathematical Structure of Mamba - S4D">Part III of Mathematical Structure of Mamba - S4D</a></p></li><li><p>Part IV of Mathematical Structure of Mamba - Mamba&Mamba2</p></li></ul><p>剩余预计还有一篇文章正在生产中~</p></div><h1 id="计算公式"><a href="#计算公式" class="headerlink" title="计算公式"></a>计算公式</h1><p>我们都知道SSM的公式:</p><script type="math/tex; mode=display">\begin{aligned}h_t &= A_th_{t-1} + B_tx_t \\y_t &= C_t^Th_t\end{aligned}</script><p>这里的 $t$ 其实就代表了 <code>seq_len</code> 这一维。我们将公式展开,就可以得到一个卷积的形式</p><script type="math/tex; mode=display">\begin{aligned}h_0 &= B_0x_0 \\h_1 &= A_1h_0 + B_1x_1 = A_1B_0x_0 + B_1x_1 \\h_2 &= A_2A_1B_0x_0 + A_2B_1x_1 +B_2x_2 \\h_t &= \sum_{s=0}^t A_tA_{t-1}...A_{s+1}B_sx_s\end{aligned}</script><p>我们记 $A_{t:s} := A_tA_{t-1}…A_{s+1}$,则最终的输出 $y_t = C_t^Th_t = C_t^T\sum_{s=0}^t A_{t:s}B_sx_s$。</p><p>如果写成向量形式,则有 $y=Mx$,$M$ 有如下格式:</p><script type="math/tex; mode=display">\left[\begin{array}{ll|ll}{C_0^TB_0}&{}&{}&{}\\{C_1^TA_1B_0}&{C_1^TB_1}&{}&{}\\\hline{C_2^TA_2A_1B_0}&{C_2^TA_2B_1}&{C_2^TB_2}&{}\\{C_3^TA_3A_2A_1B_0}&{C_3^TA_3A_2B_1}&{C_3^TA_3B_3}&{C_3^TB_3}\\\hline{\cdots}&{\cdots}&{\cdots}&{\cdots}\end{array}\right]</script><p>其中,分为对角块与方阵。其中的方阵部分还可以进一步简化:</p><script type="math/tex; mode=display">\left[\begin{array}{ll}{C_2^TA_2A_1B_0}&{C_2^TA_2B_1} \\{C_3^TA_3A_2A_1B_0}&{C_3^TA_3A_2B_1}\end{array}\right] = \left[\begin{array}{l}{C_2^T} \\{C_3^TA_3}\end{array}\right] A_2 \left[\begin{array}{ll}{A_1B_0} & {B_1}\end{array}\right]</script><p>更一般的形式为,对于矩阵 $M_{j:j’,i’:i}$,其中 $j’ > j \geq i > i’$:</p><script type="math/tex; mode=display">\left[\begin{array}{ccc}{C_j^TA_{j:i'}B_{i'}}&{\cdots}&{C_j^TA_{j:i-1}B_{i-1}} \\{\vdots} & {\ddots} & {\vdots} \\{C_{j'-1}^TA_{j'-1:i'}B_{i'}}&{\cdots}&{C_{j'-1}^TA_{j'-1:i-1}B_{i-1}}\end{array}\right] = \left[\begin{array}{c}{C_j^TA_{j:j}} \\{\vdots} \\{C_{j'-1}^TA_{j'-1:j}}\end{array}\right] A_{j:i-1} \left[\begin{array}{ccc}{A_{i-1:i'}B_{i'}} & {\cdots} & {A_{i-1:i-1}B_{i-1}}\end{array}\right]</script><p>后续的工作,就是写算子把这些东西全部高效地算出来。整体思路就是,按上述分块,先块内算,再块间算。</p><h1 id="Triton-算子"><a href="#Triton-算子" class="headerlink" title="Triton 算子"></a>Triton 算子</h1><p>有关 Triton 的基本知识欢迎参考这篇<a href="https://anti-entrophic.github.io/posts/10042.html" title="Triton Tutorial">Triton Tutorial</a></p><h2 id="chunk-cumsum-fwd"><a href="#chunk-cumsum-fwd" class="headerlink" title="_chunk_cumsum_fwd"></a>_chunk_cumsum_fwd</h2><p>需要先介绍一下 $A$ 的离散化(记 $\tilde{A} $ 为离散化之前的矩阵),见 <a href="https://anti-entrophic.github.io/posts/10040.html" title="Part III of Mathematical Structure of Mamba - S4D">Part III of Mathematical Structure of Mamba - S4D</a></p><p>我们知道了,为了保证 $A$ 的实部总为负,我们需要用指数形式来处理离散化:$e^{\Delta A}$。而指数形式的乘积实际上就对应了指数的累计和,这也是这个算子在做的事情。</p><p>经过 DSS 与 S4D 这两篇文章的沉淀,作者发现A使用对角线比起NPLR也没差很多。特别是从mamba开始抛弃HiPPO后,就更没有使用NPLR的必要了,直接快进到对角线。</p><p>$dt: (B,S,H)$, $A: (H,)$,可以视作每个头都有一个控制状态衰减的变量。</p><p>这个算子做的事情其实比较简单,就是计算 $e^{\sum \Delta_i A_i}$,不过是分块做的。</p><h3 id="grid"><a href="#grid" class="headerlink" title="grid"></a>grid</h3><p>我们把 $dt$ 按照seqlen这一维切开,切成 chunk_size 大小的块。同时,head这一维也会切分,每份大小是 <code>BLOCK_SIZE_H</code>,这个参数后续会用autotune去搜。最后得到一个三维的grid,代表了 <code>[batch, seqlen, heads]</code></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">_chunk_cumsum_fwd</span>(<span class="params">dt, A, chunk_size, dt_bias=<span class="literal">None</span>, dt_softplus=<span class="literal">False</span>, dt_limit=(<span class="params"><span class="number">0.0</span>, <span class="built_in">float</span>(<span class="params"><span class="string">"inf"</span></span>)</span>)</span>):</span><br><span class="line"> batch, seqlen, nheads = dt.shape</span><br><span class="line"> <span class="keyword">assert</span> A.shape == (nheads,)</span><br><span class="line"> <span class="keyword">if</span> dt_bias <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">assert</span> dt_bias.shape == (nheads,)</span><br><span class="line"> nchunks = math.ceil(seqlen / chunk_size)</span><br><span class="line"> dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)</span><br><span class="line"> dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)</span><br><span class="line"> grid_chunk_cs = <span class="keyword">lambda</span> META: (batch, nchunks, triton.cdiv(nheads, META[<span class="string">'BLOCK_SIZE_H'</span>]))</span><br><span class="line"> <span class="keyword">with</span> torch.cuda.device(dt.device.index):</span><br><span class="line"> _chunk_cumsum_fwd_kernel[grid_chunk_cs](</span><br><span class="line"> dt, A, dt_bias, dt_out, dA_cumsum,</span><br><span class="line"> batch, seqlen, nheads, chunk_size,</span><br><span class="line"> dt_limit[<span class="number">0</span>], dt_limit[<span class="number">1</span>],</span><br><span class="line"> dt.stride(<span class="number">0</span>), dt.stride(<span class="number">1</span>), dt.stride(<span class="number">2</span>),</span><br><span class="line"> A.stride(<span class="number">0</span>),</span><br><span class="line"> dt_bias.stride(<span class="number">0</span>) <span class="keyword">if</span> dt_bias <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span> <span class="keyword">else</span> <span class="number">0</span>,</span><br><span class="line"> dt_out.stride(<span class="number">0</span>), dt_out.stride(<span class="number">2</span>), dt_out.stride(<span class="number">1</span>), dt_out.stride(<span class="number">3</span>),</span><br><span class="line"> dA_cumsum.stride(<span class="number">0</span>), dA_cumsum.stride(<span class="number">2</span>), dA_cumsum.stride(<span class="number">1</span>), dA_cumsum.stride(<span class="number">3</span>),</span><br><span class="line"> dt_softplus,</span><br><span class="line"> HAS_DT_BIAS=dt_bias <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>,</span><br><span class="line"> BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),</span><br><span class="line"> )</span><br><span class="line"> <span class="keyword">return</span> dA_cumsum, dt_out</span><br></pre></td></tr></table></figure><p>之后简单看一下整个triton算子的逻辑,我会以 <code>tl.load()</code> 为基础,看每份数据是怎么读进来的;然后再介绍它们之间如何计算。</p><h3 id="dt"><a href="#dt" class="headerlink" title="dt"></a>dt</h3><p>因为每块数据处理第 <code>pid_b</code> 条的 <code>chunk_size</code> 条数据,所以首先定位到 <code>dt_ptr</code>,再取下大小为 <code>[BLOCK_SIZE_H, BLOCK_SIZE_CHUNK]</code> 的一块。</p><p>这里比较需要注意的就是,在读入数据的时候交换了 $S$ 和 $H$ 这两维,是为了后续计算方便。只要stride是正确的话,读进来的数据是一定正确的。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen</span><br><span class="line">offs_h = pid_h * BLOCK_SIZE_H + tl.arange(<span class="number">0</span>, BLOCK_SIZE_H)</span><br><span class="line">offs_c = tl.arange(<span class="number">0</span>, BLOCK_SIZE_CHUNK)</span><br><span class="line">dt_ptrs = dt_ptr + (offs_h[:, <span class="literal">None</span>] * stride_dt_head + offs_c[<span class="literal">None</span>, :] * stride_dt_seqlen)</span><br><span class="line">dt = tl.load(dt_ptrs, mask=(offs_h[:, <span class="literal">None</span>] < nheads) & (offs_c[<span class="literal">None</span>, :] < chunk_size_limit), other=<span class="number">0.0</span>).to(tl.float32)</span><br></pre></td></tr></table></figure><h3 id="A"><a href="#A" class="headerlink" title="A"></a>A</h3><p><code>A</code> 就没啥好说的,就一个维度,分成 <code>BLOCK_SIZE_H</code> 大小即可</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">A_ptrs = A_ptr + offs_h * stride_A_head</span><br><span class="line">A = tl.load(A_ptrs, mask=offs_h < nheads, other=<span class="number">0.0</span>).to(tl.float32)</span><br></pre></td></tr></table></figure><h3 id="计算"><a href="#计算" class="headerlink" title="计算"></a>计算</h3><p>主要就是计算 $dt * A$,然后求累计和</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">dA = dt * A[:, <span class="literal">None</span>]</span><br><span class="line">dA_cs = tl.cumsum(dA, axis=<span class="number">1</span>)</span><br></pre></td></tr></table></figure><p>值得一提的是,最后存下来的结果的维度是 <code>[batch, nheads, nchunks, chunk_size]</code>,等于是把seqlen这一维给拆开了。并且,块间的交互也还没做,比如说本来累计和应该是 [1,2,3,4],现在的分块结果可能是 [1,2,1,2],后续还需要一步块间RNN传递。</p><h2 id="chunk-state-fwd"><a href="#chunk-state-fwd" class="headerlink" title="_chunk_state_fwd"></a>_chunk_state_fwd</h2><p>这段代码主要是算states的,对应的是原计算公式中的</p><script type="math/tex; mode=display">{\left[\begin{array}{c}{B_{i'}^TA_{i-1:i'}} \\ {\vdots} \\ {B_{i-1}^TA_{i-1:i-1}}\end{array}\right]}^T \left[\begin{array}{c}{x_{i'}} \\ {\vdots} \\ {x_{i-1}}\end{array}\right] = \sum_{t=i'}^{i-1}{A_{i-1:i'}^TB_{i'}x_{i'}}</script><p>先明确一下各输入的维度。</p><p><code>x: [batch, seq_len, nheads, headdim]</code>,其中 <code>nheads * headdim = d_model</code></p><p><code>dt: [batch, nheads, nchunks, chunk_size]</code>,其中 <code>nchunks * chunk_size = seq_len</code></p><p><code>B: [batch, seq_len, ngroups, dstate]</code>,其中 <code>ngroups</code> 是为了tp的参数,一般为1;<code>dstate</code> 就是状态向量的维度</p><p><code>dA_cumsum</code> 和 <code>dt</code> 一样</p><h3 id="grid-1"><a href="#grid-1" class="headerlink" title="grid"></a>grid</h3><p>grid的分法看起来有点奇怪,第0维把 <code>headdim * dstate</code> 分成了一组,这是因为 $ABx$ 的结果就是 <code>headdim * dstate</code>,可以这么理解,表示输入 <code>x</code> 的某个维度,对内部状态 <code>h</code> 的某个维度的影响。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># mamba_ssm/ops/triton/ssd_chunk_state.py</span></span><br><span class="line">grid = <span class="keyword">lambda</span> META: (triton.cdiv(headdim, META[<span class="string">'BLOCK_SIZE_M'</span>]) * triton.cdiv(dstate, META[<span class="string">'BLOCK_SIZE_N'</span>]),</span><br><span class="line"> batch * nchunks, nheads)</span><br><span class="line"></span><br><span class="line">num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)</span><br><span class="line">pid_m = tl.program_id(axis=<span class="number">0</span>) // num_pid_n</span><br><span class="line">pid_n = tl.program_id(axis=<span class="number">0</span>) % num_pid_n</span><br></pre></td></tr></table></figure><p>第1维是把 <code>batch</code> 和 <code>nchunks</code> 组合到一起,然后到kernel里之后又光速分开了。看起来排序方式是 <code>[nchunks, batch]</code>,不懂为啥要拼起来传进去。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">pid_bc = tl.program_id(axis=<span class="number">1</span>)</span><br><span class="line">pid_c = pid_bc // batch</span><br><span class="line">pid_b = pid_bc - pid_c * batch</span><br></pre></td></tr></table></figure><p>第2维就简明一点,直接对 <code>nheads</code> 分块</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">pid_h = tl.program_id(axis=<span class="number">2</span>)</span><br></pre></td></tr></table></figure><h3 id="dA-cs-last"><a href="#dA-cs-last" class="headerlink" title="dA_cs_last"></a>dA_cs_last</h3><p>之后依然是读取数据,我们还是以 <code>tl.load()</code> 为支点读代码。首先是 <code>dA_cumsum</code>,也就是一个 <code>chunk_size</code> 大小中的前缀和 $\sum_{i=0}^t \Delta_iA_i$。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head</span><br><span class="line">dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - <span class="number">1</span>) * stride_dA_cs_csize).to(tl.float32)</span><br></pre></td></tr></table></figure><p><code>dA_cumsum</code> 的维度是 <code>[batch, nheads, nchunks, chunk_size]</code>。我们的读取方式很简单,分别朝第0、1、2维移动对应长度。注意,<code>dA_cs_last</code> 读的是这个chunk中的最后一个元素,因此叫last。</p><h3 id="seq-idx"><a href="#seq-idx" class="headerlink" title="seq_idx"></a>seq_idx</h3><p><code>seq_idx</code> 暂时跳过,我不知道这个是干什么的。</p><h3 id="x"><a href="#x" class="headerlink" title="x"></a>x</h3><p>这里又很神秘的,对于 <code>chunk_size</code> 这一维还要切一下。读取方式是之前说明过的交换维度,只要stride是对的,数据就是对的,只不过是写入顺序的问题罢了。所以读取进来的 <code>x</code> 是 <code>[pid_b, (k:k+1)*BLOCK_SIZE_K, pid_h, (pid_m:pid_m+1) * BLOCK_SIZE_M]</code></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">offs_m = pid_m * BLOCK_SIZE_M + tl.arange(<span class="number">0</span>, BLOCK_SIZE_M)</span><br><span class="line">offs_k = tl.arange(<span class="number">0</span>, BLOCK_SIZE_K)</span><br><span class="line">x_ptrs = x_ptr + (offs_m[:, <span class="literal">None</span>] * stride_x_hdim + offs_k[<span class="literal">None</span>, :] * stride_x_seqlen)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> k <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">0</span>, chunk_size_limit, BLOCK_SIZE_K):</span><br><span class="line"> x = tl.load(x_ptrs, mask=(offs_m[:, <span class="literal">None</span>] < hdim) & (offs_k[<span class="literal">None</span>, :] < chunk_size_limit - k), other=<span class="number">0.0</span>)</span><br><span class="line"> x_ptrs += BLOCK_SIZE_K * stride_x_seqlen</span><br></pre></td></tr></table></figure><h3 id="b"><a href="#b" class="headerlink" title="b"></a>b</h3><p><code>b</code> 的取值是 <code>[pid_b, (k: k + 1) * BLOCK_SIZE_K, pid_group, (pid_n: pid_n+1) * BLOCK_SIZE_N]</code></p><p>注意不要搞错了 <code>b</code> 的意义,这里的 <code>b</code> 可以看作是用来处理 <code>x</code> 的投影矩阵,还是要投到 dstate 上去的。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">offs_n = pid_n * BLOCK_SIZE_N + tl.arange(<span class="number">0</span>, BLOCK_SIZE_N)</span><br><span class="line">b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head</span><br><span class="line">b_ptrs = b_ptr + (offs_n[<span class="literal">None</span>, :] * stride_b_dstate + offs_k[:, <span class="literal">None</span>] * stride_b_seqlen)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> k <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">0</span>, chunk_size_limit, BLOCK_SIZE_K):</span><br><span class="line"> b = tl.load(b_ptrs, mask=(offs_k[:, <span class="literal">None</span>] < chunk_size_limit - k) & (offs_n[<span class="literal">None</span>, :] < dstate), other=<span class="number">0.0</span>).to(tl.float32)</span><br><span class="line"> b_ptrs += BLOCK_SIZE_K * stride_b_seqlen</span><br></pre></td></tr></table></figure><h3 id="dA-cs-k-与-dt-k"><a href="#dA-cs-k-与-dt-k" class="headerlink" title="dA_cs_k 与 dt_k"></a>dA_cs_k 与 dt_k</h3><p>就是读进来一个chunk_size的元素,然后进一步把 chunk_size(seq_len) 这一维切了一下,不太重要</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize</span><br><span class="line"><span class="keyword">for</span> k <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">0</span>, chunk_size_limit, BLOCK_SIZE_K):</span><br><span class="line"> dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=<span class="number">0.0</span>).to(tl.float32)</span><br><span class="line"> dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize</span><br></pre></td></tr></table></figure><h3 id="每一小块的具体计算"><a href="#每一小块的具体计算" class="headerlink" title="每一小块的具体计算"></a>每一小块的具体计算</h3><p>先把前缀和计算这一段的差,比如说 <code>dA_cs_last = A_3 * A_2 * A_1 * A_0</code>, <code>dA_cs_k = [A_0, A_1 * A_0]</code></p><p>这里的小块 <code>b</code> 是在 <code>[seq_len, dstate]</code> 维度上取出来的大小为 <code>[BLOCK_SIZE_K, BLOCK_SIZE_N]</code> 的块。这里的scale乘到了 <code>seq_len</code> 维度上,因为不同的时间步需要乘的A不同嘛。</p><p>后面和大小为 <code>[BLOCK_SIZE_M, BLOCK_SIZE_K]</code> 的 <code>x</code> 做点积,得到是一个 <code>[headdim, dstate]</code> 的矩阵。然后按照开头公式所说的加起来。最后存下来的states的维度是 <code>[batch, nchunks, nheads, headdim, dstate]</code></p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">scale = tl.exp(tl.minimum((dA_cs_last - dA_cs_k), <span class="number">0.0</span>)) * dt_k</span><br><span class="line">b *= scale[:, <span class="literal">None</span>]</span><br><span class="line">acc += tl.dot(x, b)</span><br></pre></td></tr></table></figure><h2 id="state-passing-fwd"><a href="#state-passing-fwd" class="headerlink" title="_state_passing_fwd"></a>_state_passing_fwd</h2><p>上面一步的计算其实是算少了的,比如说你第一步有的块是只有 $A_7A_6A_5A_4$的,那你本来想算 $A_{7:0}B_0x_0$ 的就少算了。方法就是从算过的 $A_{3:0}B_0x_0$ 往下传。函数名字也很形象。</p><p>代码上有一点点改变,首先就是把states的后两维聚到一起了,方便计算,然后 <code>dA_cumsum</code> 只需要取最后一个就可以。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">states, final_states = _state_passing_fwd(rearrange(states, <span class="string">"... p n -> ... (p n)"</span>), dA_cumsum[:, :, :, -<span class="number">1</span>],</span><br><span class="line"> initial_states=rearrange(initial_states, <span class="string">"... p n -> ... (p n)"</span>) <span class="keyword">if</span> initial_states <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span> <span class="keyword">else</span> <span class="literal">None</span>,</span><br><span class="line"> seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=C.dtype)</span><br></pre></td></tr></table></figure><h3 id="grid-2"><a href="#grid-2" class="headerlink" title="grid"></a>grid</h3><p>上来grid又是乱序,我真的会谢。<code>dim</code> 这一维指的是states的后两维,可能是这个算子里不涉及这个的操作,所以抛到最外侧了吧(但是内存空间又没变,早知如此为什么不之前存的时候就换顺序呢?)。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">grid = <span class="keyword">lambda</span> META: (triton.cdiv(dim, META[<span class="string">'BLOCK_SIZE'</span>]), batch, nheads)</span><br></pre></td></tr></table></figure><h3 id="states-amp-dA"><a href="#states-amp-dA" class="headerlink" title="states & dA"></a>states & dA</h3><p>函数的具体内容倒是简单,读入 <code>states</code> 和 <code>dA</code>。</p><p><code>dA</code> 的维度是 <code>[batch, nheads, nchunks]</code>;<code>state</code> 的维度是 <code>[batch, nchunks, nheads, headdim*dstate]</code></p><p>这里 <code>dA</code> 好像就是把一整个 <code>nchunks</code> 取出来,这样就是 <code>nchunks</code> 段乘积。</p><p><code>states</code> 则直接在除了nchunks的方向上移动到对应位置,最后一维取了一个 <code>BLOCK_SIZE</code> 做并行</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head</span><br><span class="line">states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head</span><br><span class="line">offs_m = pid_m * BLOCK_SIZE + tl.arange(<span class="number">0</span>, BLOCK_SIZE)</span><br><span class="line">states_ptrs = states_ptr + offs_m * stride_states_dim</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> c <span class="keyword">in</span> <span class="built_in">range</span>(nchunks):</span><br><span class="line"> new_states = tl.load(states_ptrs, mask=offs_m < dim, other=<span class="number">0.0</span>).to(tl.float32)</span><br><span class="line"> dA_cs = tl.load(dA_cs_ptr).to(tl.float32)</span><br></pre></td></tr></table></figure><p>一个比较朴素的问题是,怎么保证读的正确对应呢?这里用了一个循环来更新:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> c <span class="keyword">in</span> <span class="built_in">range</span>(nchunks):</span><br><span class="line"> new_states = tl.load(states_ptrs, mask=offs_m < dim, other=<span class="number">0.0</span>).to(tl.float32)</span><br><span class="line"> dA_cs = tl.load(dA_cs_ptr).to(tl.float32)</span><br><span class="line"> scale = tl.exp(dA_cs)</span><br><span class="line"></span><br><span class="line"> states = scale * states + new_states <span class="comment"># 循环更新 </span></span><br><span class="line"> <span class="keyword">if</span> c < nchunks - <span class="number">1</span>:</span><br><span class="line"> tl.store(out_ptrs, states, mask=offs_m < dim)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> tl.store(final_states_ptrs, states, mask=offs_m < dim)</span><br><span class="line"> states_ptrs += stride_states_chunk</span><br><span class="line"> dA_cs_ptr += stride_dA_cs_chunk</span><br><span class="line"> out_ptrs += stride_out_chunk</span><br></pre></td></tr></table></figure><p>但是说实话,具体怎么并行的我也有点没想清楚,为啥这么取就能把所有块的算好呢。后续又该怎么取呢?有点神秘。</p><h2 id="bmm-chunk-fwd"><a href="#bmm-chunk-fwd" class="headerlink" title="_bmm_chunk_fwd"></a>_bmm_chunk_fwd</h2><p>并非善类</p><h3 id="grid-3"><a href="#grid-3" class="headerlink" title="grid"></a>grid</h3><p>分的维度是 <code>[chunk_size, chunk_size, batch, nchunks]</code></p><p>Tridao瞎起名字,给我整笑了。外面传进去 <code>C</code> 和 <code>dstate</code>,在里面形参叫 <code>A</code> 和 <code>k</code>,</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">grid = <span class="keyword">lambda</span> META: (triton.cdiv(chunk_size, META[<span class="string">'BLOCK_SIZE_M'</span>]) * triton.cdiv(chunk_size, META[<span class="string">'BLOCK_SIZE_N'</span>]), batch, nchunks <span class="keyword">if</span> <span class="keyword">not</span> has_groups <span class="keyword">else</span> nchunks * ngroups)</span><br></pre></td></tr></table></figure>]]></content>
<summary type="html">Tridao的tri是Triton的tri</summary>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/categories/Model-Structure/"/>
<category term="Mamba" scheme="https://anti-entrophic.github.io/tags/Mamba/"/>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/tags/Model-Structure/"/>
</entry>
<entry>
<title>Triton Tutorial</title>
<link href="https://anti-entrophic.github.io/posts/10042.html"/>
<id>https://anti-entrophic.github.io/posts/10042.html</id>
<published>2025-05-14T06:06:06.000Z</published>
<updated>2025-05-14T06:19:01.200Z</updated>
<content type="html"><![CDATA[<h1 id="绪论"><a href="#绪论" class="headerlink" title="绪论"></a>绪论</h1><p>Triton是一门适配python的高性能GPU编程语言(暂时只认为是语言),学习路线可以从完成官方的<a href="https://triton-lang.org/main/index.html," title="Triton Tutorial">tutorial</a>开始。我的博客里主要想讲一些不一样的。</p><p>CUDA Version: 12.2</p><p>Triton Version: 3.1.0</p><h1 id="GPU相关知识"><a href="#GPU相关知识" class="headerlink" title="GPU相关知识"></a>GPU相关知识</h1><p>想必大家上来跑tutorial遇到的第一个问题是,获取DEVICE的接口报错了!</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> triton.runtime <span class="keyword">import</span> driver</span><br><span class="line">DEVICE = driver.active.get_active_torch_device()</span><br><span class="line"></span><br><span class="line"><span class="comment"># >>> AttributeError: 'CudaDriver' object has no attribute 'get_active_torch_device'</span></span><br></pre></td></tr></table></figure><p>查阅源码后发现,应该是nvidia那边的接口变掉了,导致triton中无法重载:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># triton/python/triton/backends/driver.py</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">DriverBase</span>(metaclass=ABCMeta):</span><br><span class="line"></span><br><span class="line"><span class="meta"> @classmethod</span></span><br><span class="line"><span class="meta"> @abstractmethod</span></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">is_active</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="meta"> @abstractmethod</span></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">get_current_target</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="meta"> @abstractmethod</span></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">get_active_torch_device</span>(<span class="params">self</span>):</span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="meta"> @abstractmethod</span></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">get_benchmarker</span>(<span class="params">self</span>) -> Benchmarker:</span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Return the benchmarking function that this backend should use by default.</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="keyword">raise</span> NotImplementedError</span><br><span class="line"></span><br><span class="line"> <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>) -> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line">driver.active</span><br><span class="line"><span class="comment"># >>> <nvi.CudaDriver object at 0x7ff294e43d30></span></span><br></pre></td></tr></table></figure><p>因此,我们可以用别的API来代替:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">driver.active.get_current_target()</span><br><span class="line"><span class="comment"># >>> GPUTarget(backend='cuda', arch=90, warp_size=32)</span></span><br><span class="line">DEVICE = driver.active.get_current_target().backend</span><br><span class="line"><span class="comment"># >>> 'cuda'</span></span><br><span class="line">DEVICE_ID = driver.active.get_current_device()</span><br><span class="line"><span class="comment"># >>> 0</span></span><br></pre></td></tr></table></figure><p>在有些时候,我们还需要更多GPU的信息来辅助并行编程:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">properties = driver.active.utils.get_device_properties(DEVICE_ID)</span><br><span class="line"><span class="comment"># >>> {'max_shared_mem': 232448, 'max_num_regs': 65536, 'multiprocessor_count': 132, 'warpSize': 32, 'sm_clock_rate': 1980000, 'mem_clock_rate': 2619000, 'mem_bus_width': 5120}</span></span><br><span class="line">NUM_SM = properties[<span class="string">"multiprocessor_count"</span>]</span><br><span class="line">NUM_REGS = properties[<span class="string">"max_num_regs"</span>]</span><br><span class="line">SIZE_SMEM = properties[<span class="string">"max_shared_mem"</span>]</span><br><span class="line">WARP_SIZE = properties[<span class="string">"warpSize"</span>]</span><br></pre></td></tr></table></figure><p>解释一下这四个主要的GPU参数:</p><ul><li><p><code>NUM_SM</code> 是指的GPU中Streaming Multiprocessor(SM)的数量。SM是GPU上的核心处理单元,包含完整的内存、寄存器等。整个CUDA编程的核心就是将任务分成多个BLOCKs,然后这些BLOCKs会被均匀地分给所有SM执行。我使用的GPU型号是H800,总共有132个SM。</p></li><li><p><code>NUM_REGS</code> 是指每个SM中可用的寄存器(registers)的最大数量。寄存器每个线程不共享,如果单个线程所需要使用的寄存器很多,则同时在一个SM上运行的线程数量就会减少。</p></li><li><p><code>SIZE_SMEM</code> 是指每个SM可用的共享内存(Shared Memory)的大小,通常以字节(Bytes)为单位。SMEM访问速度远快于显存,同一个SM中的所有线程均可共享,可用于数据交换等。linux系统中也有共享内存的概念,可以看作是“基于内存的文件系统”,通常指的是 <code>/dev/shm</code> 这块区域,用于进程间通信等操作,就不用把数据写到硬盘里去了。</p></li></ul><figure class="highlight cmd"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">df -h /dev/shm</span><br><span class="line">>>> 文件系统 大小 已用 可用 已用% 挂载点</span><br><span class="line">>>> tmpfs <span class="number">64</span>G <span class="number">0</span> <span class="number">64</span>G <span class="number">0</span>% /dev/shm</span><br></pre></td></tr></table></figure><ul><li><code>WARP_SIZE</code> 比较复杂一点,我们首先需要理解Warp的概念。Warp是GPU上线程调度的基本单元,一个Warp中的所有线程会执行相同的命令。这并不代表一个Warp中所有线程是完全一样的,而是说,如果Warp中有一半的指令做的是A,而另一半的指令做的是A->B,则在第一阶段所有线程会同时处理,而在第二阶段有一半的线程会陪着另一半空转。因此,避免Warp分歧也是一个很重要的优化点。<code>WARP_SIZE</code> 则表示一个Warp中包含的线程数量,基本上是32。</li></ul><h1 id="Vector-Addition"><a href="#Vector-Addition" class="headerlink" title="Vector Addition"></a>Vector Addition</h1><p>源代码很简单就不解释了</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> triton</span><br><span class="line"><span class="keyword">import</span> triton.language <span class="keyword">as</span> tl</span><br><span class="line"><span class="keyword">from</span> triton.runtime <span class="keyword">import</span> autotune</span><br><span class="line"></span><br><span class="line">DEVICE = triton.runtime.driver.active.get_current_target().backend</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@triton.jit</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">add_kernel</span>(<span class="params"></span></span><br><span class="line"><span class="params"> x_ptr,</span></span><br><span class="line"><span class="params"> y_ptr,</span></span><br><span class="line"><span class="params"> output_ptr,</span></span><br><span class="line"><span class="params"> n_elements,</span></span><br><span class="line"><span class="params"> BLOCK_SIZE: tl.constexpr</span></span><br><span class="line"><span class="params"></span>):</span><br><span class="line"> pid = tl.program_id(axis=<span class="number">0</span>)</span><br><span class="line"> block_start = pid * BLOCK_SIZE</span><br><span class="line"> offsets = block_start + tl.arange(<span class="number">0</span>, BLOCK_SIZE)</span><br><span class="line"> mask = offsets < n_elements</span><br><span class="line"> x = tl.load(x_ptr + offsets, mask=mask)</span><br><span class="line"> y = tl.load(y_ptr + offsets, mask=mask)</span><br><span class="line"> output = x + y</span><br><span class="line"> tl.store(output_ptr + offsets, output, mask=mask)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@autotune(<span class="params"></span></span></span><br><span class="line"><span class="params"><span class="meta"> configs=[</span></span></span><br><span class="line"><span class="params"><span class="meta"> triton.Config(<span class="params">{<span class="string">'BLOCK_SIZE'</span>: <span class="number">256</span>}, num_warps=<span class="number">4</span></span>),</span></span></span><br><span class="line"><span class="params"><span class="meta"> triton.Config(<span class="params">{<span class="string">'BLOCK_SIZE'</span>: <span class="number">512</span>}, num_warps=<span class="number">4</span></span>),</span></span></span><br><span class="line"><span class="params"><span class="meta"> triton.Config(<span class="params">{<span class="string">'BLOCK_SIZE'</span>: <span class="number">512</span>}, num_warps=<span class="number">8</span></span>),</span></span></span><br><span class="line"><span class="params"><span class="meta"> ],</span></span></span><br><span class="line"><span class="params"><span class="meta"> key=[<span class="string">'n_elements'</span>],</span></span></span><br><span class="line"><span class="params"><span class="meta"></span>)</span></span><br><span class="line"><span class="meta">@triton.jit</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">add_kernel_autotune</span>(<span class="params"></span></span><br><span class="line"><span class="params"> x_ptr,</span></span><br><span class="line"><span class="params"> y_ptr,</span></span><br><span class="line"><span class="params"> output_ptr,</span></span><br><span class="line"><span class="params"> n_elements,</span></span><br><span class="line"><span class="params"> BLOCK_SIZE: tl.constexpr</span></span><br><span class="line"><span class="params"></span>):</span><br><span class="line"> pid = tl.program_id(axis=<span class="number">0</span>)</span><br><span class="line"> block_start = pid * BLOCK_SIZE</span><br><span class="line"> offsets = block_start + tl.arange(<span class="number">0</span>, BLOCK_SIZE)</span><br><span class="line"> mask = offsets < n_elements</span><br><span class="line"> x = tl.load(x_ptr + offsets, mask=mask)</span><br><span class="line"> y = tl.load(y_ptr + offsets, mask=mask)</span><br><span class="line"> output = x + y</span><br><span class="line"> tl.store(output_ptr + offsets, output, mask=mask)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">add</span>(<span class="params"></span></span><br><span class="line"><span class="params"> x: torch.Tensor,</span></span><br><span class="line"><span class="params"> y: torch.Tensor,</span></span><br><span class="line"><span class="params"> block_size: <span class="built_in">int</span> = <span class="literal">None</span>,</span></span><br><span class="line"><span class="params"> num_warps: <span class="built_in">int</span> = <span class="number">4</span>,</span></span><br><span class="line"><span class="params"> autotune: <span class="built_in">bool</span> = <span class="literal">False</span></span></span><br><span class="line"><span class="params"></span>) -> torch.Tensor:</span><br><span class="line"> output = torch.empty_like(x)</span><br><span class="line"> <span class="keyword">assert</span> x.device.<span class="built_in">type</span> == DEVICE <span class="keyword">and</span> y.device.<span class="built_in">type</span> == DEVICE <span class="keyword">and</span> output.device.<span class="built_in">type</span> == DEVICE</span><br><span class="line"> n_elements = output.numel()</span><br><span class="line"></span><br><span class="line"> grid = <span class="keyword">lambda</span> meta: (triton.cdiv(n_elements, meta[<span class="string">'BLOCK_SIZE'</span>]), ) <span class="comment"># noqa: E731</span></span><br><span class="line"> <span class="keyword">if</span> autotune:</span><br><span class="line"> add_kernel_autotune[grid](x, y, output, n_elements)</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size, num_warps=num_warps)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> output</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>():</span><br><span class="line"> torch.manual_seed(<span class="number">0</span>)</span><br><span class="line"> size = <span class="number">98432</span> </span><br><span class="line"> x = torch.randn(size, device=DEVICE)</span><br><span class="line"> y = torch.randn(size, device=DEVICE)</span><br><span class="line"></span><br><span class="line"> output_triton = add(x, y, block_size=<span class="number">256</span>)</span><br><span class="line"> output_pytorch = x + y</span><br><span class="line"></span><br><span class="line"> <span class="built_in">print</span>(output_triton)</span><br><span class="line"> <span class="built_in">print</span>(output_pytorch)</span><br><span class="line"> <span class="built_in">print</span>(<span class="string">f'<span class="subst">{torch.<span class="built_in">max</span>(torch.<span class="built_in">abs</span>(output_triton - output_pytorch))}</span>'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@triton.testing.perf_report(<span class="params"></span></span></span><br><span class="line"><span class="params"><span class="meta"> triton.testing.Benchmark(<span class="params"></span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> x_names=[<span class="string">'size'</span>],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> x_vals=[<span class="number">2</span>**i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="params"><span class="number">12</span>, <span class="number">28</span>, <span class="number">1</span></span>)],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> x_log=<span class="literal">True</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_arg=<span class="string">'provider'</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_vals=[<span class="string">'torch'</span>, <span class="string">'triton_bs128'</span>, <span class="string">'triton_bs256'</span>, <span class="string">'triton_bs512'</span>, <span class="string">'triton_nw4'</span>, <span class="string">'triton_nw8'</span>, <span class="string">'triton_nw16'</span>, <span class="string">'triton_autotune'</span>],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_names=[<span class="string">'Torch'</span>, <span class="string">'Triton BS=128'</span>, <span class="string">'Triton BS=256'</span>, <span class="string">'Triton BS=512'</span>, <span class="string">'Triton NW=4'</span>, <span class="string">'Triton NW=8'</span>, <span class="string">'Triton NW=16'</span>, <span class="string">'Triton Autotune'</span>],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> styles=[(<span class="params"><span class="string">'green'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'blue'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'red'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'purple'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'orange'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'cyan'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'magenta'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'yellow'</span>, <span class="string">'-'</span></span>)],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> ylabel=<span class="string">'GB/s'</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> plot_name=<span class="string">'vector-add-performance'</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> args={},</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> </span>)</span>)</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">benchmark</span>(<span class="params">size, provider</span>):</span><br><span class="line"> x = torch.rand(size, device=DEVICE, dtype=torch.float32)</span><br><span class="line"> y = torch.rand(size, device=DEVICE, dtype=torch.float32)</span><br><span class="line"> quantiles = [<span class="number">0.5</span>, <span class="number">0.2</span>, <span class="number">0.8</span>]</span><br><span class="line"> <span class="keyword">if</span> provider == <span class="string">'torch'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: x + y, quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_bs128'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">128</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_bs256'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">256</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_bs512'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">512</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_nw4'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">512</span>, num_warps=<span class="number">4</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_nw8'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">512</span>, num_warps=<span class="number">8</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_nw16'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, block_size=<span class="number">512</span>, num_warps=<span class="number">16</span>), quantiles=quantiles)</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_autotune'</span>:</span><br><span class="line"> ms, min_ms, max_ms = triton.testing.do_bench(<span class="keyword">lambda</span>: add(x, y, autotune=<span class="literal">True</span>), quantiles=quantiles)</span><br><span class="line"> gbps = <span class="keyword">lambda</span> ms: <span class="number">3</span> * x.numel() * x.element_size() * <span class="number">1e-9</span> / (ms * <span class="number">1e-3</span>) <span class="comment"># noqa: E731</span></span><br><span class="line"> <span class="keyword">return</span> gbps(ms), gbps(max_ms), gbps(min_ms)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">"__main__"</span>:</span><br><span class="line"> <span class="comment"># main()</span></span><br><span class="line"> benchmark.run(print_data=<span class="literal">True</span>, show_plots=<span class="literal">True</span>, save_path=<span class="string">"./results/01_vector_addition"</span>)</span><br></pre></td></tr></table></figure><h2 id="性能调优"><a href="#性能调优" class="headerlink" title="性能调优"></a>性能调优</h2><p>比起源码,我额外增加了 <code>from triton.runtime import autotune</code> ,它的作用就是对于不同size的输入,会在首次执行时搜一遍所有可能的配置,找到其中效率最高的,后续对同样的size就会用固定的配置。</p><p>简单尝试一下的话,就会发现影响程序性能的因素有两个 <code>BLOCK_SIZE</code> 与 <code>num_warps</code>。其实理论上应该是 <code>BLOCK_SIZE</code> 、 <code>num_warps</code> 和 <code>input_size</code> 三者的关系决定了性能。我们用控制变量的方式来测一下它们的影响。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAJ0iWgkKUDeOQt5_bzZKZovwbS5PBBHAALYrjEbdlsgRTJ8-RKWuAOmAQADAgADeAADNgQ.png" width="600px" /><p style="font-size: 10px;">Vector Addition Triton Kernel Performance。测BS时默认NW为4;测NW时默认BS为512</p></center><p><br></p><p>从图中我们能看出两个明显掉点的曲线,一个是 <code>BLOCK_SIZE</code> 为128时,此时的 <code>BLOCK_SIZE</code> 太小,分成的warps太多,导致管理调度 $2^{27} / 128 / 32$ 个块成了开销瓶颈;另一个是 <code>num_warps</code> 为16时,过大的线程块导致了对寄存器等资源的竞争更加剧烈,甚至可能发生寄存器溢出,显著影响效率。其实我也还不太会分析具体的原因,下面是详细的测试表格:</p><div class="table-container"><table><thead><tr><th>size</th><th>Torch</th><th>Triton BS=128</th><th>Triton BS=256</th><th>Triton BS=512</th><th>Triton NW=4</th><th>Triton NW=8</th><th>Triton NW=16</th><th>Triton Autotune</th></tr></thead><tbody><tr><td>4096.000000</td><td>9.035294</td><td>9.197604</td><td>9.142857</td><td>9.142857</td><td>9.142857</td><td>9.142857</td><td>9.088757</td><td>8.982456</td></tr><tr><td>8192.000000</td><td>17.964912</td><td>18.070588</td><td>17.860465</td><td>17.757226</td><td>18.070588</td><td>17.757226</td><td>18.070588</td><td>17.757226</td></tr><tr><td>16384.000000</td><td>35.514452</td><td>35.310345</td><td>35.310345</td><td>35.310345</td><td>35.310345</td><td>35.310345</td><td>35.720930</td><td>35.310345</td></tr><tr><td>32768.000000</td><td>68.266666</td><td>68.266666</td><td>69.423731</td><td>69.033707</td><td>69.033707</td><td>70.217145</td><td>69.818181</td><td>69.818181</td></tr><tr><td>65536.000000</td><td>132.843245</td><td>135.779009</td><td>135.779009</td><td>133.565214</td><td>133.565214</td><td>133.565214</td><td>135.032965</td><td>136.533331</td></tr><tr><td>131072.000000</td><td>253.360834</td><td>252.061538</td><td>255.999991</td><td>258.694729</td><td>260.063494</td><td>252.061538</td><td>253.360834</td><td>258.694729</td></tr><tr><td>262144.000000</td><td>457.227922</td><td>444.814490</td><td>455.111110</td><td>465.895721</td><td>463.698115</td><td>461.521112</td><td>453.013839</td><td>465.895721</td></tr><tr><td>524288.000000</td><td>750.412251</td><td>747.558951</td><td>774.047204</td><td>771.011790</td><td>771.011790</td><td>765.011652</td><td>741.916954</td><td>777.106702</td></tr><tr><td>1048576.000000</td><td>1228.800031</td><td>1159.929234</td><td>1187.963788</td><td>1221.167675</td><td>1228.800031</td><td>1184.385557</td><td>1156.517652</td><td>1217.387051</td></tr><tr><td>2097152.000000</td><td>1687.622326</td><td>1569.724635</td><td>1676.827323</td><td>1698.557221</td><td>1687.622326</td><td>1684.008546</td><td>1569.724635</td><td>1694.896509</td></tr><tr><td>4194304.000000</td><td>2154.608134</td><td>1927.529447</td><td>2148.721353</td><td>2160.527432</td><td>2163.499294</td><td>2145.789924</td><td>1852.607766</td><td>2163.499294</td></tr><tr><td>8388608.000000</td><td>2532.792214</td><td>2204.434417</td><td>2538.924930</td><td>2534.833158</td><td>2532.792214</td><td>2543.029933</td><td>2160.527432</td><td>2545.087416</td></tr><tr><td>16777216.000000</td><td>2755.784589</td><td>2349.311512</td><td>2763.045981</td><td>2756.388411</td><td>2755.784589</td><td>2763.045981</td><td>2394.920460</td><td>2758.200902</td></tr><tr><td>33554432.000000</td><td>2941.994716</td><td>2428.196121</td><td>2949.580950</td><td>2944.059933</td><td>2942.682906</td><td>2950.618366</td><td>2558.022378</td><td>2949.580950</td></tr><tr><td>67108864.000000</td><td>3021.832823</td><td>2487.232999</td><td>3032.757656</td><td>3021.832823</td><td>3022.740106</td><td>3032.392472</td><td>2644.858132</td><td>3032.027035</td></tr><tr><td>134217728.000000</td><td>3067.880595</td><td>2514.570784</td><td>3071.437476</td><td>3066.945668</td><td>3067.506556</td><td>3072.937670</td><td>2676.788108</td><td>3072.562746</td></tr></tbody></table></div><div class="note warning flat"><p>如果采用 <code>BS=4096,NW=16</code> 的配置的话,性能并不会显著下降,看起来似乎不大符合原来“寄存器溢出导致性能下降”的假设</p><p>gemini的回答是:</p><ul><li><p>当 Kernel 需要处理一个包含4096个元素的向量 (如 offsets = tl.arange(0, BLOCK_SIZE)) 时,编译器清楚地知道,不可能一次性将所有4096个元素(对于每个线程来说,是它负责的那部分)都完整地、同时地放在寄存器中进行操作。这远超出了单个线程或 Warp 的实际寄存器容量。因此,编译器很可能会生成一种高度流水线化 (pipelined) 或分块 (tiled/chunked) 的执行代码。它会将这4096个元素的操作分解成更小的、可管理的批次。例如,加载一小批数据到寄存器,计算,存储结果,然后再处理下一小批。<br>这种流水线化的处理方式,其瞬时 (instantaneous) 寄存器需求可能相对较低。也就是说,在任何特定时刻,每个线程为正在处理的小数据块所活跃使用的寄存器数量可能不多。</p></li><li><p>编译器处理512个元素时,它可能尝试一种不同的优化策略。也许它认为可以将更大部分的“向量”同时保持活跃在寄存器中,或者采用一种不那么积极分解成小块的策略,因为它认为这在某些情况下(如有足够寄存器时)可能更有效。<br>这种策略如果本身对寄存器的需求就比较高,那么当 num_warps 强制一个较紧的寄存器预算时,就更容易导致寄存器溢出。</p></li></ul><p>暂时不知道对错,感觉我还需要一些性能调优工具。</p></div><p>我们可以发现一个有趣的现象,即固定 <code>num_warps</code> 时,<code>BLOCK_SIZE</code> 设为256是最佳的;固定 <code>BLOCK_SIZE</code> 时,<code>num_warps</code> 设为8是最佳的。</p><p>由于性能受到多种因素的影响,因此影响程序效率的参数往往存在着被称为 <strong>sweet spot(甜点区)</strong> 的最佳范围,高了也不行,低了也不行。</p><p>一个比较好的方法是采用autotune,试几种可选参数后丢进去。这样程序会在初次运行时自动测试几种配置,后续对于同样的size均采用最佳配置即可。</p><p>值得一提的是,在计算非常简单的 Vector Addition 的场景下,性能瓶颈不在计算而在于带宽,因此最佳性能差距不会太大。</p><h1 id="Fused-Softmax"><a href="#Fused-Softmax" class="headerlink" title="Fused Softmax"></a>Fused Softmax</h1><p>遇到的一个神人问题是,源代码中的 <code>kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)</code> 报错</p><p>报错原因是,代码中已经预先编译过内核了</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">kernel = softmax_kernel.warmup(y, x, x.stride(<span class="number">0</span>), y.stride(<span class="number">0</span>), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(<span class="number">1</span>, ))</span><br></pre></td></tr></table></figure><p>而在 <code>softmax_kernel</code> 的原始定义中,<code>BLOCK_SIZE</code> 和 <code>num_stages</code> 被声明为了 <code>tl.constexpr</code> 类型,表示编译时常量,即在内核中编译的过程中已经按照这个常量编译了,不应该在运行时再次传入。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">@triton.jit</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">softmax_kernel</span>(<span class="params">output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr</span>):</span><br></pre></td></tr></table></figure><p>因此,源代码应该修改为</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">kernel[(num_programs, <span class="number">1</span>, <span class="number">1</span>)](y, x, x.stride(<span class="number">0</span>), y.stride(<span class="number">0</span>), n_rows, n_cols)</span><br></pre></td></tr></table></figure><p>最后,完整的代码如下:</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> triton</span><br><span class="line"><span class="keyword">import</span> triton.language <span class="keyword">as</span> tl</span><br><span class="line"><span class="keyword">from</span> triton.runtime <span class="keyword">import</span> driver</span><br><span class="line"></span><br><span class="line">target = driver.active.get_current_target()</span><br><span class="line">DEVICE = target.backend</span><br><span class="line">DEVICE_ID = driver.active.get_current_device()</span><br><span class="line"></span><br><span class="line">properties = driver.active.utils.get_device_properties(DEVICE_ID)</span><br><span class="line">NUM_SM = properties[<span class="string">"multiprocessor_count"</span>]</span><br><span class="line">NUM_REGS = properties[<span class="string">"max_num_regs"</span>]</span><br><span class="line">SIZE_SMEM = properties[<span class="string">"max_shared_mem"</span>]</span><br><span class="line">WARP_SIZE = properties[<span class="string">"warpSize"</span>]</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">naive_softmax</span>(<span class="params">x</span>):</span><br><span class="line"> <span class="comment"># read MN elements, write M elements</span></span><br><span class="line"> x_max = x.<span class="built_in">max</span>(dim=<span class="number">1</span>)[<span class="number">0</span>]</span><br><span class="line"> <span class="comment"># read MN+M elements, write MN elements</span></span><br><span class="line"> z = x - x_max[:, <span class="literal">None</span>]</span><br><span class="line"> <span class="comment"># read MN elements, write MN elements</span></span><br><span class="line"> numerator = torch.exp(z)</span><br><span class="line"> <span class="comment"># read MN elements, write M elements</span></span><br><span class="line"> denominator = numerator.<span class="built_in">sum</span>(dim=<span class="number">1</span>)</span><br><span class="line"> <span class="comment"># read MN + M elements, write MN elements</span></span><br><span class="line"> ret = numerator / denominator[:, <span class="literal">None</span>]</span><br><span class="line"> <span class="comment"># in total: read 5MN + 2M elements ; wrote 3MN + 2M elements</span></span><br><span class="line"> <span class="keyword">return</span> ret</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@triton.jit</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">softmax_kernel</span>(<span class="params">output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr</span>):</span><br><span class="line"> row_start = tl.program_id(<span class="number">0</span>)</span><br><span class="line"> row_step = tl.num_programs(<span class="number">0</span>)</span><br><span class="line"> <span class="keyword">for</span> row_idx <span class="keyword">in</span> tl.<span class="built_in">range</span>(row_start, n_rows, row_step, num_stages=num_stages):</span><br><span class="line"> row_start_ptr = input_ptr + row_idx * input_row_stride</span><br><span class="line"> col_offsets = tl.arange(<span class="number">0</span>, BLOCK_SIZE)</span><br><span class="line"></span><br><span class="line"> input_ptrs = row_start_ptr + col_offsets</span><br><span class="line"> mask = col_offsets < n_cols</span><br><span class="line"> row = tl.load(input_ptrs, mask=mask, other=-<span class="built_in">float</span>(<span class="string">'inf'</span>))</span><br><span class="line"></span><br><span class="line"> row_minus_max = row - tl.<span class="built_in">max</span>(row, axis=<span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"> numerator = tl.exp(row_minus_max)</span><br><span class="line"> denominator = tl.<span class="built_in">sum</span>(numerator, axis=<span class="number">0</span>)</span><br><span class="line"> softmax_output = numerator / denominator</span><br><span class="line"></span><br><span class="line"> output_row_start_ptr = output_ptr + row_idx * output_row_stride</span><br><span class="line"> output_ptrs = output_row_start_ptr + col_offsets</span><br><span class="line"> tl.store(output_ptrs, softmax_output, mask=mask)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">softmax</span>(<span class="params">x, num_stages_to_use</span>):</span><br><span class="line"> n_rows, n_cols = x.shape</span><br><span class="line"> BLOCK_SIZE = triton.next_power_of_2(n_cols)</span><br><span class="line"></span><br><span class="line"> num_warps = <span class="number">8</span></span><br><span class="line"> <span class="comment"># num_stages = 4 if SIZE_SMEM > 200000 else 2</span></span><br><span class="line"> num_stages = num_stages_to_use</span><br><span class="line"></span><br><span class="line"> y = torch.empty_like(x)</span><br><span class="line"> kernel = softmax_kernel.warmup(y, x, x.stride(<span class="number">0</span>), y.stride(<span class="number">0</span>), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE, num_stages=num_stages, num_warps=num_warps, grid=(<span class="number">1</span>, ))</span><br><span class="line"> kernel._init_handles()</span><br><span class="line"></span><br><span class="line"> n_regs = kernel.n_regs</span><br><span class="line"> size_smem = kernel.metadata.shared</span><br><span class="line"></span><br><span class="line"> occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)</span><br><span class="line"> occupancy = <span class="built_in">min</span>(occupancy, SIZE_SMEM // size_smem)</span><br><span class="line"> num_programs = NUM_SM * occupancy</span><br><span class="line"></span><br><span class="line"> num_programs = <span class="built_in">min</span>(num_programs, n_rows)</span><br><span class="line"></span><br><span class="line"> kernel[(num_programs, <span class="number">1</span>, <span class="number">1</span>)](y, x, x.stride(<span class="number">0</span>), y.stride(<span class="number">0</span>), n_rows, n_cols)</span><br><span class="line"> <span class="keyword">return</span> y</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">main</span>():</span><br><span class="line"> torch.manual_seed(<span class="number">0</span>)</span><br><span class="line"> x = torch.randn(<span class="number">1823</span>, <span class="number">781</span>, device=DEVICE)</span><br><span class="line"> y_triton = softmax(x, <span class="number">4</span>)</span><br><span class="line"> y_torch = torch.softmax(x, axis=<span class="number">1</span>)</span><br><span class="line"> <span class="keyword">assert</span> torch.allclose(y_triton, y_torch), (y_triton, y_torch)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="meta">@triton.testing.perf_report(<span class="params"></span></span></span><br><span class="line"><span class="params"><span class="meta"> triton.testing.Benchmark(<span class="params"></span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> x_names=[<span class="string">'N'</span>], <span class="comment"># argument names to use as an x-axis for the plot</span></span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> x_vals=[<span class="number">128</span> * i <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="params"><span class="number">2</span>, <span class="number">129</span></span>)], <span class="comment"># different possible values for `x_name`</span></span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_arg=<span class="string">'provider'</span>, <span class="comment"># argument name whose value corresponds to a different line in the plot</span></span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_vals=[<span class="string">'torch'</span>, <span class="string">'triton_ns2'</span>, <span class="string">'triton_ns3'</span>, <span class="string">'triton_ns4'</span>],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> line_names=[</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> <span class="string">"Torch"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> <span class="string">"Triton (NS=2)"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> <span class="string">"Triton (NS=3)"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> <span class="string">"Triton (NS=4)"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> ],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> styles=[(<span class="params"><span class="string">'green'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'blue'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'red'</span>, <span class="string">'-'</span></span>), (<span class="params"><span class="string">'purple'</span>, <span class="string">'-'</span></span>)],</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> ylabel=<span class="string">"GB/s"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> plot_name=<span class="string">"softmax-performance-vs-num_stages"</span>,</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> args={<span class="string">'M'</span>: <span class="number">4096</span>},</span></span></span></span><br><span class="line"><span class="params"><span class="params"><span class="meta"> </span>)</span>)</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">benchmark</span>(<span class="params">M, N, provider</span>):</span><br><span class="line"> x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)</span><br><span class="line"> stream = <span class="built_in">getattr</span>(torch, DEVICE).Stream()</span><br><span class="line"> <span class="built_in">getattr</span>(torch, DEVICE).set_stream(stream)</span><br><span class="line"> <span class="keyword">if</span> provider == <span class="string">'torch'</span>:</span><br><span class="line"> ms = triton.testing.do_bench(<span class="keyword">lambda</span>: torch.softmax(x, axis=-<span class="number">1</span>))</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_ns2'</span>:</span><br><span class="line"> ms = triton.testing.do_bench(<span class="keyword">lambda</span>: softmax(x, num_stages_to_use=<span class="number">2</span>))</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_ns3'</span>:</span><br><span class="line"> ms = triton.testing.do_bench(<span class="keyword">lambda</span>: softmax(x, num_stages_to_use=<span class="number">3</span>))</span><br><span class="line"> <span class="keyword">elif</span> provider == <span class="string">'triton_ns4'</span>:</span><br><span class="line"> ms = triton.testing.do_bench(<span class="keyword">lambda</span>: softmax(x, num_stages_to_use=<span class="number">4</span>))</span><br><span class="line"> gbps = <span class="keyword">lambda</span> ms: <span class="number">2</span> * x.numel() * x.element_size() * <span class="number">1e-9</span> / (ms * <span class="number">1e-3</span>) <span class="comment"># noqa: E731</span></span><br><span class="line"> <span class="keyword">return</span> gbps(ms)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">"__main__"</span>:</span><br><span class="line"> <span class="comment"># main()</span></span><br><span class="line"> benchmark.run(show_plots=<span class="literal">True</span>, print_data=<span class="literal">True</span>, save_path=<span class="string">"./results/02_fused_softmax"</span>)</span><br><span class="line"></span><br></pre></td></tr></table></figure><h2 id="代码说明"><a href="#代码说明" class="headerlink" title="代码说明"></a>代码说明</h2><p>这个kernel比 <code>Vector Addition</code> 要稍微复杂一点,因为它涉及一个伪2D的并行。</p><p>说是伪2D是因为,它的 <code>BLOCK_SIZE</code> 取值是 <code>triton.next_power_of_2(n_cols)</code>,因此,每行不需要再单独划分。</p><p>主要的循环 <code>for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):</code> 的功能是:</p><ul><li><p>如果总共有 P 个程序并行,当前程序为 i,则取出第 $[i, P+i, 2P+i, \cdots]$ 行处理</p></li><li><p>后续每行取 <code>BLOCK_SIZE</code> 其实就已经取完了,并没有在行上并行</p></li></ul><p>对于 <code>@triton.jit</code> 装饰的函数,其中的中间变量,会放在寄存器或者共享内存中,直到 <code>tl.store()</code> 才会写回(如果占太多了也有可能自动offload到显存)。共享内存一般是SRAM,硬件特性决定了比一般的显存(HBM)要快很多。</p><p><code>num_stages</code> 是一个流水线处理的参数,表示并行程度。如果 <code>num_stages>1</code>,则在程序处理第一条数据时,会同时开始加载第二条数据,这样可以加速,但是对SRAM的大小又有了要求。</p><p>因此有 <code>num_stages = 4 if SIZE_SMEM > 200000 else 2</code></p><h2 id="benchmark"><a href="#benchmark" class="headerlink" title="benchmark"></a>benchmark</h2><p>从图中可以看出,还是要快不少的,不同 <code>num_stages</code> 之间差距不大。</p><center><img src="https://i.072333.xyz/file/AgACAgEAAyEGAASMaMWHAAJ0j2gkMOUL6w_k6baDcCVISyal1o4NAALjrjEbdlsgRaE4ZMJECk62AQADAgADeAADNgQ.png" width="600px" /><p style="font-size: 10px;">Fused Softmax Triton Kernel Performance</p></center><p><br></p>]]></content>
<summary type="html">都什么年代,还在用传统pytorch</summary>
<category term="cuda" scheme="https://anti-entrophic.github.io/categories/cuda/"/>
<category term="CUDA" scheme="https://anti-entrophic.github.io/tags/CUDA/"/>
<category term="triton" scheme="https://anti-entrophic.github.io/tags/triton/"/>
</entry>
<entry>
<title>Part III of Mathematical Structure of Mamba - S4D</title>
<link href="https://anti-entrophic.github.io/posts/10040.html"/>
<id>https://anti-entrophic.github.io/posts/10040.html</id>
<published>2025-04-25T08:28:02.000Z</published>
<updated>2025-07-04T05:47:46.551Z</updated>
<content type="html"><![CDATA[<div class="note success flat"><p>本篇是mamba系列blog的第三篇文章,系列文章见:</p><ul><li><p><a href="https://anti-entrophic.github.io/posts/10038.html" title="Part I of Mathematical Structure of Mamba - Hippo">Part I of Mathematical Structure of Mamba - Hippo</a></p></li><li><p><a href="https://anti-entrophic.github.io/posts/10039.html" title="Part II of Mathematical Structure of Mamba - S4">Part II of Mathematical Structure of Mamba - S4</a></p></li><li><p>Part III of Mathematical Structure of Mamba - S4D</p></li><li><p><a href="https://anti-entrophic.github.io/posts/10043.html" title="Part IV of Mathematical Structure of Mamba - Mamba&Mamba2">Part IV of Mathematical Structure of Mamba - Mamba&Mamba2</a></p></li></ul><p>剩余预计还有两篇文章正在生产中~</p></div><h1 id="Motivation"><a href="#Motivation" class="headerlink" title="Motivation"></a>Motivation</h1><p>S4的突出贡献就在于给出了HiPPO矩阵及相关卷积核高阶幂的求解办法 <strong>DPLR</strong>,但是从推导就可以看出,仍然是过于复杂了。</p><p>前置工作 DSS 发现,存在使用对角线状态矩阵简化原来的HiPPO矩阵的可能性,不过还是引入了一些有难度的操作(需要复值softmax,并且没有解释为什么对角线矩阵可以work)</p><p>本篇文章就进一步梳理了SSM使用对角线状态方程的最佳表达形式。如果A只有对角线的话那A的高阶幂就太好算了。</p><h1 id="零阶保持离散化-ZOH"><a href="#零阶保持离散化-ZOH" class="headerlink" title="零阶保持离散化(ZOH)"></a>零阶保持离散化(ZOH)</h1><p>我们首先来介绍一下零阶保持离散化方法</p><h2 id="状态方程"><a href="#状态方程" class="headerlink" title="状态方程"></a>状态方程</h2><script type="math/tex; mode=display">\frac{dh(t)}{dt} = Ah(t) + Bx(t)</script><h2 id="求解"><a href="#求解" class="headerlink" title="求解"></a>求解</h2><p>使用<strong>积分因子法</strong>,首先将方程整理为标准形式:</p><script type="math/tex; mode=display">\frac{dh(t)}{dt} - Ah(t) = Bx(t)</script><p>接下来计算积分因子 $\mu(t)$:</p><script type="math/tex; mode=display">\mu(t) = exp(\int -Adt) = e^{-At}</script><p>将积分因子乘以方程两边:</p><script type="math/tex; mode=display">e^{-At}\frac{dh(t)}{dt}-Ae^{-At}h(t) = Be^{-At}x(t)</script><p>左边变为:</p><script type="math/tex; mode=display">\frac{d}{dt}(e^{-At}h(t)) = Be^{-At}x(t)</script><p>对两边积分:</p><script type="math/tex; mode=display">\int_{t_0}^t \frac{d}{d\tau}(e^{-A\tau}h(\tau))d\tau=\int_{t_0}^{t}Be^{-A\tau}x(\tau)d\tau</script><p>积分后得到:</p><script type="math/tex; mode=display">e^{-At}h(t) - e^{-At_0}h(t_0) = \int_{t_0}^{t}Be^{-A\tau}x(\tau)d\tau</script><p>解出 $h(t)$:</p><script type="math/tex; mode=display">h(t) = e^{A(t-t_0)}h(t_0) + \int_{t_0}^tBe^{A(t-\tau)}x(\tau)d\tau</script><h2 id="离散化"><a href="#离散化" class="headerlink" title="离散化"></a>离散化</h2><p>我们考虑 $ t = t_0 + \Delta$,则 $x(\tau)$ 在这段时间内保持恒定,不妨设为 $x_k$, 则</p><script type="math/tex; mode=display">\begin{aligned}h(t+\Delta) &= e^{A\Delta}h(t) + \int_t^{t+\Delta}Be^{A(t-\tau)}x_kd\tau \\&= e^{A\Delta}h(t) + (\int_t^{t+\Delta}e^{A(t-\tau)}d\tau)Bx_k \\&= e^{A\Delta}h(t) + (\int_0^{\Delta}e^{A\tau}d\tau)Bx_k\end{aligned}</script><p>其中,$\int_0^{\Delta}e^{A\tau}d\tau$ 的结果可以如下计算得到:</p><script type="math/tex; mode=display">\frac{d}{dt}e^{At} = Ae^{At}</script><p>两边积分得:</p><script type="math/tex; mode=display">\int Ae^{At}dt = e^{At} + C</script><p>左乘 $A^{-1}$ 得:</p><script type="math/tex; mode=display">\int e^{At}dt = A^{-1}e^{At} + C</script><p>对于定积分 $[0,\Delta]$,结果是:</p><script type="math/tex; mode=display">\int_0^{\Delta}e^{A\tau}d\tau = A^{-1}(e^{A\Delta}-I)</script><p>综上,代入积分结果,原来的离散化方程为:</p><script type="math/tex; mode=display">h(t+\Delta) = e^{A\Delta}h(t) + A^{-1}(e^{A\Delta}-I)Bx_k</script><p>所以</p><script type="math/tex; mode=display">\begin{aligned}\bar{A} &= e^{A\Delta} \\\bar{B} &= A^{-1}(e^{A\Delta}-I)B\end{aligned}</script><p>而 SSM 中的 $C$ 和 $D$ 是不需要离散化的,因为它本来就是在离散空间上更新的。其实也不一定需要是 $x’ = Ch + Dx$ 吧,也可以是 $y = Ch + Dx$,SSM最重要的只是维护一个状态罢了——这个状态是对记忆空间正交基的投影。</p><p>之前S4中用双线性离散化,就是因为不是对角线形式的 $\bar{A} = e^{A\Delta}$ 不好算。</p><h2 id="矩阵指数"><a href="#矩阵指数" class="headerlink" title="矩阵指数"></a>矩阵指数</h2><p>矩阵指数的定义参考标量指数函数的泰勒展开:</p><script type="math/tex; mode=display">e^{At} = I + At + \frac{(At)^2}{2!}+...</script><p>系统的稳定性主要由矩阵指数 $e^{At}$ 决定,分析 $e^{At}$ 的敛散性比较复杂且经典,这里直接给出结论 —— 由矩阵 $A$ 的特征值的实部的正负决定</p><ul><li>当矩阵 $A$ 的所有特征值的实部都小于0时,矩阵指数函数 $e^{At}$ 在 $t \rightarrow \infty$ 时收敛到零矩阵。</li></ul><h3 id="证明"><a href="#证明" class="headerlink" title="证明"></a>证明</h3><div class="note warning flat"><p>意会一下,以后补上完整证明</p></div><p>任何矩阵 $A$ 都可以分解为 Jordan 标准型 $A = PJP^{-1}$,其中 $J$ 是 Jordan 矩阵。因此,$e^{At}=Pe^{Jt}P^{-1}$</p><p>每个 Jordan 块对应一个特征值 $\lambda$,其矩阵指数包含形如 $t^ke^{\lambda t}$的项(其中 $k$ 是多项式项的阶数)。当 $\lambda$ 的实部小于0时,指数衰减 $e^{\lambda t}$ 会压制多项式增长 $t^k$,使得每一项在 $t \rightarrow \infty$ 时趋于0</p><p>因此整个Jordan 矩阵的指数 $e^{Jt}$ 趋向于零矩阵,进而 $e^{At} = Pe^{Jt}P^{-1}$ 也趋向于零矩阵。</p><h2 id="Left-half-plane-condition"><a href="#Left-half-plane-condition" class="headerlink" title="Left-half plane condition"></a>Left-half plane condition</h2><p>因此,为了限制 $A$ 的实部大小,一个常见的方法是用指数函数来创建 $A$ 的实部:$A = -e^{A_{Re}} + i \cdot A_{Im}$</p><p>这样实部就总是负的,天才。</p><p>Albert Gu 在这篇文章中指出,不一定要用指数函数,用一般的激活函数比如说 ReLU, softplus也是可以的。现在transformers库对mamba的实现中就是用的exp的形式。</p><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># transformers/models/mamba2/modeling_mamba2.py</span></span><br><span class="line">A = -torch.exp(self.A_log.<span class="built_in">float</span>())</span><br></pre></td></tr></table></figure>]]></content>
<summary type="html">(NIPS 22) S4D - On the Parameterization and Initialization of Diagonal State Space Models</summary>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/categories/Model-Structure/"/>
<category term="Mamba" scheme="https://anti-entrophic.github.io/tags/Mamba/"/>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/tags/Model-Structure/"/>
<category term="HiPPO" scheme="https://anti-entrophic.github.io/tags/HiPPO/"/>
<category term="S4" scheme="https://anti-entrophic.github.io/tags/S4/"/>
</entry>
<entry>
<title>Sparsemax</title>
<link href="https://anti-entrophic.github.io/posts/10035.html"/>
<id>https://anti-entrophic.github.io/posts/10035.html</id>
<published>2025-04-11T06:20:42.000Z</published>
<updated>2025-09-08T11:55:24.392Z</updated>
<content type="html"><![CDATA[<p>介绍一篇修改softmax的文章,题目是:</p><p><a href="https://arxiv.org/pdf/1602.02068" title="950引">[ICML 2016] From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification</a></p><p>它允许softmax非平滑输出,某些位置的prob直接为0</p><h1 id="method"><a href="#method" class="headerlink" title="method"></a>method</h1><p>原来的softmax公式是通过下面的计算得到概率分布</p><script type="math/tex; mode=display">\text{softmax}_i(z) = \frac{e^{z_i}}{\sum_j e^{z_j}}</script><p>现在,作者考虑直接将隐状态向一个 $K-1$ 维的单纯形投影,求得的 $p$ 即是softmax的结果</p><script type="math/tex; mode=display">\begin{aligned}\Delta^{K-1} &:=\{p\in \mathbb{R}^K | 1^Tp=1, p \geq 0\} \\\text{sparsemax}(z) &:= \mathop{\arg\max}\limits_{p \in \Delta^{K-1}} ||p-z||^2\end{aligned}</script><h2 id="为什么是K-1维"><a href="#为什么是K-1维" class="headerlink" title="为什么是K-1维"></a>为什么是K-1维</h2><p>因为有 $1^Tp=1$ 的限制,导致自由度减1</p><p>和 $x+y+z=1$ 表示的是一个平面而不是三维空间是一个道理</p><h2 id="如何求解"><a href="#如何求解" class="headerlink" title="如何求解"></a>如何求解</h2><p>将原问题转化为一个优化问题:</p><script type="math/tex; mode=display">\begin{aligned}\text{min} \quad & ||p-z||^2 \\s.t. \quad & 1^Tp=1 \\& p \geq 0\end{aligned}</script><p>考虑求解其拉格朗日对偶问题:</p><script type="math/tex; mode=display">L(p, \lambda, \mu) = \frac{1}{2} ||p-z||^2 - \lambda^Tp + \mu(1^Tp-1)</script><p>这里系数 $\frac{1}{2}$ 是为了方便求导后消去系数</p><p>考虑KKT条件:</p><script type="math/tex; mode=display">\left\{\begin{aligned}\nabla_p L(p^*, \lambda^*, \mu^*) &= p^*-z-\lambda^*+\mu^*1 \qquad&(1)\\ 1^Tp^* &= 1 &(2)\\ p^* &\geq 0 &(3)\\\lambda^* &\geq 0 &(4)\\\lambda^*p^* &= 0 &(5)\end{aligned}\right.</script><p>对于 $p_i^{\ast} > 0$,由于式(5),此时必有 $\lambda_i^{\ast}=0$,所以由式(1),得:</p><script type="math/tex; mode=display">p_i^* = z_i - \mu, \quad s.t.\quad p_i^* > 0</script><p>由式(2),得:</p><script type="math/tex; mode=display">\begin{aligned}\sum_{i \in K}p_i^* &= 0 + \sum_{i \in S(z)}p_i^* \\1 &= \sum_{i \in S(z)}(z_i - \mu) \\\mu &= \frac{\sum_{i \in S(z)}z_i - 1}{|S(z)|}\end{aligned}</script><p>其中,$S(z) = \{j \in K \, | \,\, p_j^*>0\}$</p><p>由此,我们知道了如何从 $z$ 得到 $p$,即:</p><script type="math/tex; mode=display">p = \text{sparsemax}(z) = [z - \mu]_+</script><p>求出 $\mu$ 的关键,是求出多大 $|S(z)|$ 才能正好满足:</p><ul><li><p>$z_j^* - \mu > 0, \quad \forall j \in S(z)$</p></li><li><p>$z_j^* - \mu \leq 0, \quad \forall j \notin S(z)$</p></li></ul><p>注意 $\mu$ 不随着 $|S(z)|$ 的增加而单调,所以不能<strong>二分</strong>求解,朴素地用 $O(K)$ 的复杂度遍历寻找 $\mu$</p><div class="note success flat"><p>但是 $\Delta \mu$ 是单调的,所以理应有更快速的搜索方法</p><p>如果只是在 lm_head 位置对词表大小 vocab_size 做 softmax,那可能确实速度不要紧</p><p>但是如果是对attention做softmax就完蛋了,因为它要排序,排序的复杂度是 $O(N \log N)$。这样对于长度为 $N$ 的 attention,复杂度直接到 $O(KN\log N)$ 了,爆了</p><p>所以如果不做修改,不可能用到attention里</p></div><h2 id="具体例子"><a href="#具体例子" class="headerlink" title="具体例子"></a>具体例子</h2><p>假设我们有一个隐状态 $[1.5,2,0.5]^T$,计算过程如下:</p><p>对所有元素大小排序 $2>1.5>0.5$</p><p>遍历 $|S(z)|$,从 $|S(z)|=1$ 开始,求出 $\mu = \frac{2-1}{1} = 1$</p><p>验证是否满足条件:</p><script type="math/tex; mode=display">\left\{\begin{aligned}2 - 1 > 0 \\1.5 - 1 > 0 \\0.5 - 1 < 0\end{aligned}\right.</script><p>不满足,继续搜,$|S(z)|=2$,$\mu = \frac{2+1.5-1}{2} = 1.25$</p><p>验证是否满足条件:</p><script type="math/tex; mode=display">\left\{\begin{aligned}2 - 1.25 > 0 \\1.5 - 1.25 > 0 \\0.5 - 1.25 < 0\end{aligned}\right.</script><p>满足了,所以 $\text{sparsemax}(z) = [0.25, 0.75, 0]$</p>]]></content>
<summary type="html">凸优化没白学</summary>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/categories/Model-Structure/"/>
<category term="Model Structure" scheme="https://anti-entrophic.github.io/tags/Model-Structure/"/>
<category term="softmax" scheme="https://anti-entrophic.github.io/tags/softmax/"/>
</entry>
</feed>