-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathflow_matching.py
More file actions
375 lines (299 loc) · 13.1 KB
/
flow_matching.py
File metadata and controls
375 lines (299 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
import os
from dataclasses import dataclass, field
from random import random, choice
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import tyro
import yaml
from einops import rearrange, repeat
from lightning import Trainer
from lightning.pytorch import LightningModule
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
from torch import nn, Tensor
from torch.nn import functional as F

from diffusion.module.control_embedding import ControlEmbedding
from diffusion.module.temporal_latent_unet import Unet1D as VideoLatentUnet
from diffusion.module.utils.biovid import exists, BioVidDM
from diffusion.module.utils.biovid import bilateral_filter
from diffusion.module.utils.ema import EMA
from diffusion.module.utils.misc import enlarge_as
@dataclass()
class TrainingConfig:
    """Command-line options consumed by ``main`` (parsed with ``tyro``).

    NOTE(review): ``batch_size``, ``num_workers`` and ``test`` are not read by
    ``main`` in this file; the data module is built from ``config_path`` instead.
    """

    # --- paths -------------------------------------------------------------
    # YAML file holding the model/data configuration.
    config_path: str = '/home/tien/paindiffusion/configure/flow_matching.yml'
    # Directory passed to FlowMatchingModule for saved samples.
    sample_output_dir: str = 'samples'

    # --- optimisation ------------------------------------------------------
    max_epochs: int = 30
    learning_rate: float = 0.0001
    # Control-dropout probabilities handed to the control embedding.
    drop_probs: list = field(default_factory=lambda: list((0.3, 0.3, 0.3)))
    batch_size: int = 32
    num_workers: int = 4
    # Number of consecutive frames grouped into one model time step.
    frame_stack: int = 4
    devices: int = 1

    # --- run stages ----------------------------------------------------------
    train: bool = True
    validate: bool = True
    test: bool = False

    # --- checkpoint resume ---------------------------------------------------
    load_from_checkpoint: bool = False
    checkpoint_path: Union[str, None] = None

    # --- debugging / logging -------------------------------------------------
    # Run Lightning's fast_dev_run smoke test (also disables the W&B logger).
    fast_check: bool = False
    use_logger: bool = True


def generate_pyramid_scheduling_matrix(horizon: int, uncertainty_scale: float, sampling_timesteps: int) -> torch.Tensor:
    """Build a "pyramid" schedule of per-frame time-level indices.

    Entry ``[m, t]`` equals ``sampling_timesteps + int(t * uncertainty_scale) - m``
    clipped to ``[0, sampling_timesteps]``. Every column starts at the top level
    and drops by one per row, with frame ``t`` lagging ``int(t * uncertainty_scale)``
    rows behind frame 0. The number of rows is just enough for the last frame to
    reach level 0.

    Args:
        horizon: Number of frames (columns).
        uncertainty_scale: Extra delay, in rows, added per frame index.
        sampling_timesteps: Top level of the schedule.

    Returns:
        An int64 ``torch.Tensor`` of shape ``(height, horizon)`` where
        ``height = sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1``.

    Raises:
        ValueError: If ``horizon`` or the computed height is negative.
    """
    height = sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
    if height < 0 or horizon < 0:
        raise ValueError(
            f"negative schedule dimensions: height={height}, horizon={horizon}"
        )
    rows = np.arange(height, dtype=np.int64)[:, None]
    # astype(int64) truncates toward zero, exactly like int() on each product.
    frame_lag = (np.arange(horizon, dtype=np.int64) * uncertainty_scale).astype(np.int64)[None, :]
    # Broadcast (height, 1) against (1, horizon) instead of a Python double loop.
    scheduling_matrix = sampling_timesteps + frame_lag - rows
    return torch.from_numpy(np.clip(scheduling_matrix, 0, sampling_timesteps))


class FlowMatchingModule(LightningModule):
    """Lightning module that trains and samples a conditional flow-matching model.

    Training follows the linear path ``x_t = (1 - t) * x_0 + t * x_1`` between
    Gaussian noise ``x_0`` and data ``x_1``. The network output is regressed with
    MSE onto the constant velocity ``x_1 - x_0``. Every stacked step (a group of
    ``frame_stack`` frames) gets its own independently sampled ``t``. Sampling
    integrates the learned velocity with the explicit midpoint rule, window by
    window, following a per-frame pyramid schedule.

    Args:
        net: Velocity network. It is called as ``net(x, t, ctrl=...)`` and must
            expose ``classifier_free_guidance(x, t, ctrl=..., guide=...)``.
        ctrl_emb: Control embedding. It provides ``dropout_ctrl`` and is called
            as ``ctrl_emb(ctrl, is_training=...)``.
        lr: Adam learning rate.
        frame_stack: Number of consecutive frames grouped into one time step.
        drop_probs: Probabilities passed to ``ctrl_emb.dropout_ctrl`` while
            training (presumably one per control signal; TODO confirm).
            Defaults to ``[0.3, 0.3, 0.3]``.
        sample_output_dir: Directory where sampled ``.pt`` files are written.
        guide: Guidance weights passed to ``net.classifier_free_guidance``.
            Defaults to ``[1, 1, 1]``.
    """

    def __init__(self, net, ctrl_emb, lr=0.001, frame_stack=4, drop_probs=None, sample_output_dir='samples', guide=None):
        super().__init__()
        self.net: VideoLatentUnet = net
        self.ctrl_emb: ControlEmbedding = ctrl_emb
        self.lr = lr
        self.frame_stack = frame_stack
        # None sentinels replace list defaults so instances never share one object.
        self.drop_probs = [0.3, 0.3, 0.3] if drop_probs is None else drop_probs
        self.criterion = nn.MSELoss(reduction='mean')
        self.sample_output_dir = sample_output_dir
        # Last validation batch; replayed through the sampler at epoch end.
        self.val_outs = None
        self.guide = [1, 1, 1] if guide is None else guide

    def forward(self, x, t, ctrl=None, cfg=False):
        """Predict the velocity at ``(x, t)``.

        Args:
            x: Latents laid out as ``(b, t, fs, d)`` (see ``prepare_batch``).
            t: Per-step flow time of shape ``(b, t)``.
            ctrl: Optional control signals. They are embedded with ``ctrl_emb``,
                after random dropout via ``dropout_ctrl`` while training.
            cfg: If True, call ``net.classifier_free_guidance`` with
                ``self.guide``; otherwise call ``net`` directly.
        """
        if ctrl is not None:
            if self.training:
                ctrl = self.ctrl_emb.dropout_ctrl(ctrl, self.drop_probs)
            ctrl = self.ctrl_emb(ctrl, is_training=self.training)
        if cfg:
            return self.net.classifier_free_guidance(x, t, ctrl=ctrl, guide=self.guide)
        if exists(ctrl):
            # Fold the condition axis into the frame-stack axis for the plain network call.
            ctrl = rearrange(ctrl, "c b l fs d -> b l (c fs) d").contiguous()
        return self.net(x, t, ctrl=ctrl)

    def prepare_batch(self, batch: dict) -> Tuple[Tensor, Optional[List[Tensor]], Tensor]:
        """Reshape a raw batch into frame-stacked steps without modifying ``batch``.

        Args:
            batch: Mapping with ``'x'`` of shape ``(b, L, d)`` and ``'ctrl'``,
                a sequence of control tensors (or ``None``).

        Returns:
            A tuple ``(x, ctrl, mask_weight)``:
            - ``x`` has shape ``(b, L // frame_stack, frame_stack, d)``.
            - ``ctrl`` is a new list in which every 2-D control ``(b, L)`` becomes
              ``(b, L // frame_stack, frame_stack)`` and other entries pass through
              unchanged; it is ``None`` if ``batch['ctrl']`` is ``None``.
            - ``mask_weight`` is a ``(b, L)`` tensor of ones.
        """
        x_0 = batch['x']
        ctrl = batch['ctrl']
        batch_size, num_frames, *_ = x_0.shape
        x_0 = rearrange(x_0, "b (t fs) d -> b t fs d", fs=self.frame_stack).contiguous()
        if exists(ctrl):
            # Build a new list so the caller's batch (e.g. the stored validation
            # batch) keeps its original layout.
            ctrl = [
                rearrange(cond, "b (t fs) -> b t fs", fs=self.frame_stack).contiguous()
                if len(cond.shape) == 2 else cond
                for cond in ctrl
            ]
        mask_weight = torch.ones(batch_size, num_frames, device=x_0.device)
        return x_0, ctrl, mask_weight

    def _flow_matching_loss(self, batch):
        """Return the MSE between predicted and target velocity on a random path point."""
        x_1, ctrl, _ = self.prepare_batch(batch)
        # One flow time per stacked step: shape (b, t).
        t = torch.rand((x_1.shape[0], x_1.shape[1]), device=self.device)
        x_0 = torch.randn_like(x_1)
        t_expanded = enlarge_as(t, x_1)
        x_t = (1 - t_expanded) * x_0 + t_expanded * x_1
        target_velocity = x_1 - x_0
        pred = self(x_t, t, ctrl=ctrl)
        return self.criterion(pred, target_velocity)

    def training_step(self, batch, batch_idx):
        """Compute and log the flow-matching training loss."""
        loss = self._flow_matching_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and remember the batch for sampling."""
        loss = self._flow_matching_loss(batch)
        self.log('val_loss', loss)
        self.val_outs = batch
        return loss

    def flow_step(self, x_t, t_start, t_end, ctrl):
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one explicit midpoint (RK2) step.

        Both velocity evaluations use classifier-free guidance (``cfg=True``).
        ``t_start`` and ``t_end`` are per-step times of shape ``(b, t)``.
        ``enlarge_as`` is assumed to add trailing singleton dims so the time
        deltas broadcast against the latents (TODO confirm).
        """
        dt = t_end - t_start
        half_dt = dt / 2
        # Velocity at the start point, used to reach the midpoint.
        v1 = self(x_t, t_start, ctrl=ctrl, cfg=True)
        x_mid = x_t + v1 * enlarge_as(half_dt, v1)
        # Velocity at the midpoint drives the full step.
        v2 = self(x_mid, t_start + half_dt, ctrl=ctrl, cfg=True)
        return x_t + enlarge_as(dt, v2) * v2

    def sample_imgs(self, batch, save=False, n_steps=4, window_size=16, context=4):
        """Generate a full latent sequence for ``batch`` with a sliding window.

        New steps are produced ``window_size - context`` stacked steps at a time
        (or the whole remainder if that value is not positive). Each new chunk
        starts from Gaussian noise. It is integrated together with up to
        ``window_size - horizon`` previously generated steps, which are kept
        almost fixed as context. Controls are sliced to the same window.

        Args:
            batch: Batch dict with ``'x'``, ``'ctrl'``, ``'start_frame_id'``,
                ``'end_frame_id'`` and ``'video_name'``. Only the shape of
                ``'x'`` is used, not its values.
            save: If True, ``torch.save`` the result to
                ``{sample_output_dir}/{global_step}_with_ctrl.pt``.
            n_steps: Number of time intervals between t = 0 and t = 1.
            window_size: Maximum number of stacked steps in one window.
            context: Number of window steps reserved for context.

        Returns:
            Dict with the generated ``'x'`` of shape ``(b, L, d)``, the reshaped
            ``'ctrl'``, and the batch's ``'start_frame_id'``, ``'end_frame_id'``
            and ``'video_name'``.

        Raises:
            ValueError: If the batch has no frames to generate.
        """
        x_0, ctrl, _ = self.prepare_batch(batch)
        batch_size, video_length, _, latent_dim = x_0.shape
        if video_length == 0:
            raise ValueError("sample_imgs received a batch with no frames to generate")
        time_steps = torch.linspace(0, 1.0, n_steps + 1, device=self.device)
        chunk_size = window_size - context
        x_preds = None
        curr_frame = 0
        while curr_frame < video_length:
            remaining = video_length - curr_frame
            horizon = min(chunk_size, remaining) if chunk_size > 0 else remaining
            # Flipping both axes makes rows run from t index 0 (noise) to n_steps
            # (data), with earlier frames of the chunk moving ahead of later ones.
            window_scheduling_matrix = torch.flip(
                generate_pyramid_scheduling_matrix(
                    horizon=horizon,
                    uncertainty_scale=1,
                    sampling_timesteps=n_steps,
                ),
                [0, 1],
            )
            start_frame = max(0, curr_frame + horizon - window_size)
            end_frame = min(video_length, curr_frame + horizon)
            # Previously generated steps that precede the new chunk in this window.
            context_size = end_frame - start_frame - horizon
            window_x = torch.randn(
                (batch_size, horizon, self.frame_stack, latent_dim),
                device=self.device,
            )
            if exists(x_preds):
                window_x = torch.cat((x_preds[:, start_frame:], window_x), dim=1)
            window_ctrl = [c[:, start_frame:end_frame] for c in ctrl] if exists(ctrl) else None
            window_x = self.sample_a_chunk(
                window_scheduling_matrix,
                window_x,
                window_ctrl,
                time_steps,
                with_context=context_size > 0,
                context_size=context_size,
            )
            new_steps = window_x[:, -horizon:]
            x_preds = torch.cat((x_preds, new_steps), dim=1) if exists(x_preds) else window_x
            curr_frame += horizon
        # NOTE(review): undoes a x100 scaling of the first three latent channels,
        # presumably applied by the dataset; confirm against BioVidDM.
        x_preds[..., :3] /= 100
        x_preds = rearrange(x_preds, "b t fs d -> b (t fs) d")
        saved_object = {
            'x': x_preds,
            'ctrl': ctrl,
            "start_frame_id": batch['start_frame_id'],
            "end_frame_id": batch['end_frame_id'],
            'video_name': batch['video_name'],
        }
        try:
            current_step = self.trainer.global_step
        except (RuntimeError, AttributeError):
            # Not attached to a Trainer (e.g. called from a standalone script).
            current_step = 0
        if save:
            torch.save(saved_object, f"{self.sample_output_dir}/{current_step}_with_ctrl.pt")
        return saved_object

    def sample_a_chunk(self, scheduling_matrix, window_x, window_ctrl, time_steps, with_context, context_size):
        """Integrate one window through every transition of ``scheduling_matrix``.

        Args:
            scheduling_matrix: ``(rows, horizon)`` integer indices into
                ``time_steps``. Row ``m`` gives each new step's time before
                transition ``m`` and row ``m + 1`` gives the time after it.
            window_x: Latents ``(b, context_size + horizon, fs, d)`` with the
                context steps first.
            window_ctrl: Controls sliced to the window, or ``None``.
            time_steps: 1-D tensor of flow times.
            with_context: Whether the first ``context_size`` steps are context.
            context_size: Number of leading context steps.

        Returns:
            The integrated window, with the same shape as ``window_x``.
        """
        sigma = [1e-4, 2e-4]
        batch_size = window_x.shape[0]
        for m in range(scheduling_matrix.shape[0] - 1):
            current_noise = time_steps[scheduling_matrix[m].long()]
            next_noise = time_steps[scheduling_matrix[m + 1].long()]
            if with_context:
                # Context steps sit just below t = 1 (data). Their start and end
                # times differ by at most 1e-4, so flow_step only nudges them.
                context_current_noise = torch.ones(context_size, device=self.device) - choice(sigma)
                context_next_noise = torch.ones(context_size, device=self.device) - choice(sigma)
                current_noise = torch.cat((context_current_noise, current_noise), dim=0).to(window_x.device)
                next_noise = torch.cat((context_next_noise, next_noise), dim=0).to(window_x.device)
            current_noise = repeat(current_noise, "t -> b t", b=batch_size)
            next_noise = repeat(next_noise, "t -> b t", b=batch_size)
            window_x = self.flow_step(x_t=window_x, t_start=current_noise, t_end=next_noise, ctrl=window_ctrl)
        return window_x

    def on_validation_epoch_end(self):
        """Sample from the last validation batch and save the result."""
        if self.val_outs is None:
            # No validation batch was seen, so there is nothing to sample from.
            return
        self.sample_imgs(self.val_outs, save=True)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def main(config: TrainingConfig):
    """Train and/or validate the flow-matching model described by ``config``.

    Steps:
    - Read the YAML model config at ``config.config_path``.
    - Optionally create a W&B logger (skipped when ``fast_check`` is set).
    - Create the output directories.
    - Build the BioVid data module, the U-Net, the control embedding and
      ``FlowMatchingModule``.
    - Run Lightning ``fit`` and/or ``validate``.

    After a real (non-``fast_check``) training run, rank 0 writes the best and
    last checkpoint paths back into the YAML file as ``BEST_CKPT`` and
    ``LAST_CKPT``.
    """
    with open(config.config_path, 'r') as f:
        model_config = yaml.safe_load(f)
    # An empty "DIFFUSION:" section loads as None; treat it like a missing one.
    diffusion_config = model_config.get('DIFFUSION') or {}
    checkpoint_dir = model_config.get('CHECKPOINT', 'checkpoints')

    # Logging. The logger stays None for fast_check runs and when disabled.
    run_name = model_config.get('RUN_NAME', 'flow_matching_run')
    wandb_logger = None
    if config.use_logger and not config.fast_check:
        wandb_logger = WandbLogger(
            project="diffusion_pain_flow_matching",
            name=run_name
        )
        wandb_logger.log_hyperparams(model_config)

    # Output directories. config.sample_output_dir is where FlowMatchingModule
    # actually saves samples, so it must exist as well.
    output_dirs = [
        diffusion_config.get('sample_output_dir', 'samples'),
        config.sample_output_dir,
        checkpoint_dir,
        model_config.get('CODEBACKUP', 'codebackup'),
    ]
    for output_dir in output_dirs:
        os.makedirs(output_dir, exist_ok=True)

    # Data
    biovid = BioVidDM.from_conf(config.config_path)
    biovid.setup('fit')

    # Models
    net = VideoLatentUnet(**model_config['MODEL'])
    ctrl_emb = ControlEmbedding(
        emb_dim=128,
        in_dim=1,
        conf=model_config['DATA_STATS']
    )
    fm_model = FlowMatchingModule(
        net=net,
        ctrl_emb=ctrl_emb,
        lr=config.learning_rate,
        frame_stack=config.frame_stack,
        drop_probs=config.drop_probs,
        sample_output_dir=config.sample_output_dir,
        guide=diffusion_config.get('guide', [1, 1, 1])
    )

    # Callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='train_loss',
        dirpath=checkpoint_dir,
        filename='flow_matching-{epoch:02d}-{train_loss:.2f}',
        save_top_k=2,
        mode='min',
        save_last=True
    )
    ema = EMA(
        decay=0.9999,
        evaluate_ema_weights_instead=True,
        save_ema_weights_in_callback_state=True
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, ema, lr_monitor]

    torch.set_float32_matmul_precision("highest")

    # Choose the DDP strategy from the device count the Trainer will actually use.
    num_devices = config.devices if config.train else 1
    trainer = Trainer(
        max_epochs=config.max_epochs,
        accelerator='gpu',
        devices=num_devices,
        strategy='ddp_find_unused_parameters_true' if num_devices > 1 else 'auto',
        logger=wandb_logger,
        enable_checkpointing=True,
        callbacks=callbacks,
        check_val_every_n_epoch=1,
        fast_dev_run=5 if config.fast_check else False,
    )

    resume_ckpt = config.checkpoint_path if config.load_from_checkpoint else None
    if config.train:
        trainer.fit(
            fm_model,
            datamodule=biovid,
            ckpt_path=resume_ckpt
        )
        best_ckpt = checkpoint_callback.best_model_path
        last_ckpt = checkpoint_callback.last_model_path
        if wandb_logger is not None:
            wandb_logger.log_hyperparams({
                "best_ckpt": best_ckpt,
                'last_ckpt': last_ckpt
            })
        # fast_dev_run disables checkpointing, so the paths would be empty. Under
        # DDP, only rank 0 should rewrite the shared config file.
        if trainer.is_global_zero and not config.fast_check:
            with open(config.config_path, 'r') as f:
                conf = yaml.safe_load(f)
            conf['BEST_CKPT'] = best_ckpt
            conf['LAST_CKPT'] = last_ckpt
            with open(config.config_path, 'w') as f:
                yaml.safe_dump(conf, f)

    if config.validate:
        # After training, validate the freshly trained weights rather than
        # reloading the checkpoint training was resumed from.
        trainer.validate(
            fm_model,
            datamodule=biovid,
            ckpt_path=None if config.train else resume_ckpt
        )


if __name__ == "__main__":
    tyro.cli(main)