-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathtest_rst.py
More file actions
302 lines (234 loc) · 10.1 KB
/
test_rst.py
File metadata and controls
302 lines (234 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
This file contains the tests corresponding to the extra usage examples contained in the `.rst` files
of the documentation. When there are multiple examples within a single `.rst` file, we use nested
functions here to test them.
"""
def test_basic_usage():
    """Run the basic usage example: one Jacobian descent step on two regression losses.

    A small MLP with two output columns is built. One MSE loss is computed per column, and a
    single SGD step is taken after ``backward`` has combined both losses with ``UPGrad``.
    """
    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    # Import ``backward`` explicitly from its module, like the other tests in this file do.
    # Writing ``torchjd._autojac.backward`` after a bare ``import torchjd`` only works if some
    # earlier import has already loaded the ``_autojac`` submodule and bound it onto the package.
    from torchjd._autojac import backward
    from torchjd.aggregation import UPGrad

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    optimizer = SGD(model.parameters(), lr=0.1)
    aggregator = UPGrad()

    input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
    target1 = torch.randn(16)  # First batch of 16 targets
    target2 = torch.randn(16)  # Second batch of 16 targets

    loss_fn = MSELoss()
    output = model(input)
    # One loss per output column of the model.
    loss1 = loss_fn(output[:, 0], target1)
    loss2 = loss_fn(output[:, 1], target2)

    optimizer.zero_grad()
    backward([loss1, loss2], aggregator)
    optimizer.step()
def test_iwrm():
    """Run both examples of the IWRM page: a plain SGD baseline, then the UPGrad variant."""

    def test_erm_with_sgd():
        """Fit a small regressor for 8 steps, back-propagating the averaged MSE with SGD."""
        import torch
        from torch.nn import Linear, MSELoss, ReLU, Sequential
        from torch.optim import SGD

        X = torch.randn(8, 16, 10)
        Y = torch.randn(8, 16, 1)

        model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
        criterion = MSELoss()
        optimizer = SGD(model.parameters(), lr=0.1)

        for batch_x, batch_y in zip(X, Y):
            prediction = model(batch_x)
            mean_loss = criterion(prediction, batch_y)
            optimizer.zero_grad()
            mean_loss.backward()
            optimizer.step()

    def test_iwrm_with_ssjd():
        """Fit the same regressor, but hand every unreduced loss to ``backward`` with UPGrad."""
        import torch
        from torch.nn import Linear, MSELoss, ReLU, Sequential
        from torch.optim import SGD

        from torchjd._autojac import backward
        from torchjd.aggregation import UPGrad

        X = torch.randn(8, 16, 10)
        Y = torch.randn(8, 16, 1)

        model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
        # reduction="none" keeps one loss value per element instead of averaging them.
        criterion = MSELoss(reduction="none")
        optimizer = SGD(model.parameters(), lr=0.1)
        aggregator = UPGrad()

        for batch_x, batch_y in zip(X, Y):
            unreduced_losses = criterion(model(batch_x), batch_y)
            optimizer.zero_grad()
            backward(unreduced_losses, aggregator)
            optimizer.step()

    # Run the two examples in the order they appear on the page.
    for example in (test_erm_with_sgd, test_iwrm_with_ssjd):
        example()
def test_mtl():
    """Run the multi-task example: a shared trunk and two heads trained via ``mtl_backward``."""
    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd._autojac import mtl_backward
    from torchjd.aggregation import UPGrad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    # Trunk parameters first, then those of each head.
    modules = (shared_module, task1_module, task2_module)
    params = [param for module in modules for param in module.parameters()]

    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for batch, first_target, second_target in zip(inputs, task1_targets, task2_targets):
        features = shared_module(batch)
        first_output = task1_module(features)
        second_output = task2_module(features)
        task_losses = [
            loss_fn(first_output, first_target),
            loss_fn(second_output, second_target),
        ]

        optimizer.zero_grad()
        mtl_backward(losses=task_losses, features=features, aggregator=aggregator)
        optimizer.step()
def test_lightning_integration():
    """Run the Lightning integration example: a LightningModule that calls ``mtl_backward`` itself.

    Warnings and INFO-level logging are silenced while the example runs. Both overrides are undone
    afterwards, so they do not leak into the tests that execute later in the same session.
    """
    # Extra ----------------------------------------------------------------------------------------
    import logging
    import warnings

    # ----------------------------------------------------------------------------------------------

    def run_example() -> None:
        """The documented usage example, kept apart from the test-only noise suppression."""
        import torch
        from lightning import LightningModule, Trainer
        from lightning.pytorch.utilities.types import OptimizerLRScheduler
        from torch.nn import Linear, ReLU, Sequential
        from torch.nn.functional import mse_loss
        from torch.optim import Adam
        from torch.utils.data import DataLoader, TensorDataset

        from torchjd._autojac import mtl_backward
        from torchjd.aggregation import UPGrad

        class Model(LightningModule):
            def __init__(self):
                super().__init__()
                self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
                self.task1_head = Linear(3, 1)
                self.task2_head = Linear(3, 1)
                # Manual optimization: training_step calls zero_grad / mtl_backward / step itself.
                self.automatic_optimization = False

            def training_step(self, batch, batch_idx) -> None:
                input, target1, target2 = batch

                features = self.feature_extractor(input)
                output1 = self.task1_head(features)
                output2 = self.task2_head(features)

                loss1 = mse_loss(output1, target1)
                loss2 = mse_loss(output2, target2)

                opt = self.optimizers()
                opt.zero_grad()
                mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
                opt.step()

            def configure_optimizers(self) -> OptimizerLRScheduler:
                optimizer = Adam(self.parameters(), lr=1e-3)
                return optimizer

        model = Model()

        inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
        task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
        task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

        dataset = TensorDataset(inputs, task1_targets, task2_targets)
        train_loader = DataLoader(dataset)
        trainer = Trainer(
            accelerator="cpu",
            max_epochs=1,
            enable_checkpointing=False,
            logger=False,
            enable_progress_bar=False,
        )

        trainer.fit(model=model, train_dataloaders=train_loader)

    # Extra: confine the silencing to this test. catch_warnings restores the warning filters when
    # it exits, and logging.disable(logging.NOTSET) lifts the logging override even if the example
    # raises. Calling filterwarnings / logging.disable bare would change them for the whole session.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        logging.disable(logging.INFO)
        try:
            run_example()
        finally:
            logging.disable(logging.NOTSET)
def test_rnn():
    """Run the RNN example: one loss per sequence element, combined with UPGrad."""
    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd._autojac import backward
    from torchjd.aggregation import UPGrad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
    aggregator = UPGrad()

    # 8 batches of 3 sequences of length 5; inputs are of dim 10, targets of dim 20.
    inputs = torch.randn(8, 5, 3, 10)
    targets = torch.randn(8, 5, 3, 20)

    for sequence_batch, target_batch in zip(inputs, targets):
        output = rnn(sequence_batch)[0]  # output is of shape [5, 3, 20].
        squared_error = (output - target_batch) ** 2
        # Average over dims 1 and 2 only, leaving 1 loss per sequence element.
        per_step_losses = squared_error.mean(dim=[1, 2])

        optimizer.zero_grad()
        backward(per_step_losses, aggregator, parallel_chunk_size=1)
        optimizer.step()
def test_monitoring():
    """Run the monitoring example: forward hooks print the weighting output and a similarity."""
    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.nn.functional import cosine_similarity
    from torch.optim import SGD

    from torchjd._autojac import mtl_backward
    from torchjd.aggregation import UPGrad

    def print_weights(_, __, weights: torch.Tensor) -> None:
        """Forward hook: print the output of the weighting module."""
        print(f"Weights: {weights}")

    def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
        """Forward hook: print the cosine similarity of the aggregation with the mean row."""
        average_row = inputs[0].mean(dim=0)
        similarity = cosine_similarity(aggregation, average_row, dim=0)
        print(f"Cosine similarity: {similarity.item():.4f}")

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = (
        list(shared_module.parameters())
        + list(task1_module.parameters())
        + list(task2_module.parameters())
    )

    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)

    aggregator = UPGrad()
    # Fires after each forward pass of the weighting; its output is printed as the weights.
    aggregator.weighting.register_forward_hook(print_weights)
    # Fires after each aggregation; the hook's first input is the matrix being aggregated.
    aggregator.register_forward_hook(print_gd_similarity)

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for x, y1, y2 in zip(inputs, task1_targets, task2_targets):
        features = shared_module(x)
        prediction1 = task1_module(features)
        prediction2 = task2_module(features)
        loss1 = loss_fn(prediction1, y1)
        loss2 = loss_fn(prediction2, y2)

        optimizer.zero_grad()
        mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
        optimizer.step()
def test_amp():
    """Run the mixed-precision example: float16 autocast on CPU together with a GradScaler."""
    import torch
    from torch.amp import GradScaler
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd._autojac import mtl_backward
    from torchjd.aggregation import UPGrad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = []
    for module in (shared_module, task1_module, task2_module):
        params.extend(module.parameters())

    scaler = GradScaler(device="cpu")
    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for batch, target1, target2 in zip(inputs, task1_targets, task2_targets):
        # Only the forward pass and the loss computation run under autocast.
        with torch.autocast(device_type="cpu", dtype=torch.float16):
            features = shared_module(batch)
            head1_out = task1_module(features)
            head2_out = task2_module(features)
            loss1 = loss_fn(head1_out, target1)
            loss2 = loss_fn(head2_out, target2)

        # The scaler scales the losses, and the optimizer step goes through scaler.step.
        scaled_losses = scaler.scale([loss1, loss2])
        optimizer.zero_grad()
        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
        scaler.step(optimizer)
        scaler.update()