-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsplus.py
More file actions
130 lines (117 loc) · 5.64 KB
/
splus.py
File metadata and controls
130 lines (117 loc) · 5.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from torch.optim.optimizer import Optimizer
from typing import Any
class SPlus(Optimizer):
def __init__(
self,
params: Any,
lr: float = 1e-1,
b1: float = 0.9,
b2: float = 0.999,
ema_rate: float = 0.999,
inverse_every: int = 100,
eps: float = 1e-30,
weight_decay: float = 1e-2,
max_dim: int = 10000,
nonstandard_constant: float = 0.001,
):
defaults = dict(
lr=lr,
b1=b1,
b2=b2,
ema_rate=ema_rate,
inverse_every=inverse_every,
weight_decay=weight_decay,
eps=eps,
max_dim=max_dim,
nonstandard_constant=nonstandard_constant,
)
super(SPlus, self).__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
for p in group["params"]:
p_state = self.state.get(p, [])
step_val = float(p_state["step"])
p_state["step"] = torch.tensor(step_val)
@torch.no_grad()
def eval(self): # Sets parameters to EMA values for evaluation.
for group in self.param_groups:
if 'train_mode' in group:
train_mode = group['train_mode']
ema_rate = group['ema_rate']
if train_mode:
for p in group['params']:
state = self.state[p]
if len(state) == 0 or 'ema' not in state:
continue
state['param_buffer'] = p.clone()
p.lerp_(state['ema'], 1)
p.mul_(1 / (1 - ema_rate ** state['step']))
group['train_mode'] = False
@torch.no_grad()
def train(self): # Resets parameters back from buffer.
for group in self.param_groups:
if 'train_mode' in group:
train_mode = group['train_mode']
if not train_mode:
for p in group['params']:
state = self.state[p]
if 'param_buffer' in state:
p.lerp_(state['param_buffer'], 1) # p.copy_(state['param_buffer'])
del state['param_buffer']
group['train_mode'] = True
@torch.no_grad()
def step(self, closure = None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[p]
if len(state) == 0: # Initialization
state["step"] = torch.tensor(0.0)
state['momentum'] = torch.zeros_like(p)
state['ema'] = torch.zeros_like(p)
group['train_mode'] = True
if len(p.shape) == 2:
state['sides'] = [torch.zeros((d, d), device=p.device) if d < group['max_dim'] else None for d in p.shape]
state['q_sides'] = [torch.eye(d, device=p.device) if d < group['max_dim'] else None for d in p.shape]
# Shape-dependent scaling
if len(p.shape) != 2 or p.shape[0] > group['max_dim'] or p.shape[1] > group['max_dim']:
scaled_lr = group['lr'] * group['nonstandard_constant']
else:
scaled_lr = group['lr'] * (2 / (p.shape[0] + p.shape[1]))
# Main splus update
state['step'] += 1
m = state['momentum']
m.lerp_(grad, 1-group["b1"])
if len(p.shape) == 2:
m = state['q_sides'][0].T @ m if state['q_sides'][0] is not None else m
m = m @ state['q_sides'][1] if state['q_sides'][1] is not None else m
state['sides'][0] = torch.lerp(state['sides'][0], grad @ grad.T, 1 - group['b2']) if state['sides'][0] is not None else None
state['sides'][1] = torch.lerp(state['sides'][1], grad.T @ grad, 1 - group['b2']) if state['sides'][1] is not None else None
u = torch.sgn(m)
u = state['q_sides'][0] @ u if state['q_sides'][0] is not None else u
u = u @ state['q_sides'][1].T if state['q_sides'][1] is not None else u
# Every `inverse_every` steps, we update the inverse eigendecomposition.
try:
if state['step'] == 1 or state['step'] % group['inverse_every'] == 0:
if state['sides'][0] is not None:
_, eigvecs = torch.linalg.eigh(state['sides'][0] + group['eps'] * torch.eye(state['sides'][0].shape[0], device=p.device))
state['q_sides'][0] = eigvecs
if state['sides'][1] is not None:
_, eigvecs = torch.linalg.eigh(state['sides'][1] + group['eps'] * torch.eye(state['sides'][1].shape[0], device=p.device))
state['q_sides'][1] = eigvecs
except Exception as e:
# If the eigendecomposition fails, return infinite loss
raise RuntimeError(f"Failed to compute eigendecomposition: {e}")
else:
u = torch.sgn(m)
p.add_(u, alpha=-scaled_lr)
state['ema'].lerp_(p, 1 - group['ema_rate'])
p.mul_(1 - scaled_lr * group["weight_decay"])
return loss