-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathcdmf_trainer.py
More file actions
1366 lines (1184 loc) · 50.3 KB
/
cdmf_trainer.py
File metadata and controls
1366 lines (1184 loc) · 50.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# C:\AceForge\cdmf_trainer.py
# Customized version of the ACE-Step trainer.py script.
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from datetime import datetime
from pathlib import Path
import argparse
import torch
import json
import matplotlib
import torch.nn.functional as F
import torch.utils.data
from pytorch_lightning.core import LightningModule
from torch.utils.data import DataLoader
from acestep.schedulers.scheduling_flow_match_euler_discrete import (
FlowMatchEulerDiscreteScheduler,
)
from cdmf_text2music_dataset import Text2MusicDataset
from loguru import logger
from transformers import AutoModel, Wav2Vec2FeatureExtractor
import torchaudio
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
retrieve_timesteps,
)
from diffusers.utils.torch_utils import randn_tensor
from acestep.apg_guidance import apg_forward, MomentumBuffer
from tqdm import tqdm
import random
import os
from cdmf_pipeline_ace_step import ACEStepPipeline
from cdmf_paths import CUSTOM_LORA_ROOT
# Select matplotlib's non-interactive "Agg" backend so plotting never needs
# a display.
matplotlib.use("Agg")

# Backend tuning that only applies when a CUDA device is present.
if torch.cuda.is_available():
    # Let float32 matmuls use faster reduced-internal-precision kernels.
    torch.set_float32_matmul_precision("high")
    # Disable cuDNN autotuning (presumably because input lengths vary
    # between batches -- TODO confirm).
    torch.backends.cudnn.benchmark = False
class Pipeline(LightningModule):
def __init__(
    self,
    learning_rate: float = 1e-4,
    num_workers: int = 4,
    train: bool = True,
    T: int = 1000,
    weight_decay: float = 1e-2,
    every_plot_step: int = 2000,
    shift: float = 3.0,
    logit_mean: float = 0.0,
    logit_std: float = 1.0,
    timestep_densities_type: str = "logit_normal",
    ssl_coeff: float = 1.0,
    instrumental_only: bool = False,
    checkpoint_dir=None,
    max_steps: int = 200000,
    warmup_steps: int = 10,
    dataset_path: str = "./data/your_dataset_path",
    lora_config_path: str = None,
    adapter_name: str = "lora_adapter",
    max_audio_seconds: float = 60.0,
    lora_save_every: int = 0,
):
    """
    Build the LoRA fine-tuning module around a pretrained ACE-Step pipeline.

    Loads the ACE-Step checkpoint, attaches the PEFT LoRA adapter described
    by the JSON file at ``lora_config_path`` to the transformer, and freezes
    every non-LoRA weight. When ``instrumental_only`` is set, it also freezes
    the LoRA weights on lyric/speaker/vocal/speech modules. It keeps frozen
    copies of the DCAE and the text encoder. In training mode it loads the
    frozen MERT and mHuBERT models used for the SSL auxiliary losses.

    All constructor arguments are recorded with ``save_hyperparameters()``
    and read back later through ``self.hparams``.

    Raises:
        ValueError: if ``lora_config_path`` is None.
        ImportError: if the ``peft`` package is not installed.
    """
    super().__init__()
    self.save_hyperparameters()
    self.is_train = train
    self.T = T
    # get_scheduler() reads self.hparams.shift, so this must follow
    # save_hyperparameters().
    self.scheduler = self.get_scheduler()

    # Step 1: load the pretrained pipeline. The transformer stays on CPU in
    # float32 with gradient checkpointing enabled.
    acestep_pipeline = ACEStepPipeline(checkpoint_dir)
    acestep_pipeline.load_checkpoint(acestep_pipeline.checkpoint_dir)
    transformers = acestep_pipeline.ace_step_transformer.float().cpu()
    transformers.enable_gradient_checkpointing()

    # Step 2: attach the LoRA adapter. Use an explicit raise rather than an
    # assert so the check still runs under `python -O`.
    if lora_config_path is None:
        raise ValueError("Please provide a LoRA config path")
    try:
        from peft import LoraConfig
    except ImportError as err:
        raise ImportError("Please install peft library to use LoRA training") from err
    with open(lora_config_path, encoding="utf-8") as f:
        lora_config = json.load(f)
    lora_config = LoraConfig(**lora_config)
    transformers.add_adapter(adapter_config=lora_config, adapter_name=adapter_name)
    self.adapter_name = adapter_name
    self.transformers = transformers

    # Step 3: freeze base weights and leave the selected LoRA weights trainable.
    instrumental_only_flag = getattr(self.hparams, "instrumental_only", False)
    trainable_params, frozen_params = self._configure_lora_trainability(
        instrumental_only_flag
    )
    logger.info(
        f"[Pipeline.__init__] LoRA setup (instrumental_only={instrumental_only_flag}): "
        f"trainable_params={trainable_params}, frozen_params={frozen_params}"
    )
    # Every transformer parameter was just given an explicit requires_grad
    # value, so these totals are also the overall trainable/frozen counts.
    logger.info(
        f"[Pipeline.__init__] transformers params: trainable={trainable_params}, "
        f"frozen={frozen_params}"
    )

    # Frozen auxiliary models: audio autoencoder and text encoder.
    self.dcae = acestep_pipeline.music_dcae.float().cpu()
    self.dcae.requires_grad_(False)
    self.text_encoder_model = acestep_pipeline.text_encoder_model.float().cpu()
    self.text_encoder_model.requires_grad_(False)
    self.text_tokenizer = acestep_pipeline.text_tokenizer

    if self.is_train:
        self.transformers.train()
        self._load_ssl_models(checkpoint_dir)
    # Set regardless of mode: preprocess() reads self.ssl_coeff in both its
    # SSL-enabled and SSL-disabled branches.
    self.ssl_coeff = ssl_coeff

def _configure_lora_trainability(self, instrumental_only):
    """
    Set requires_grad on every transformer parameter and count the result.

    Non-LoRA (base) parameters are always frozen. LoRA parameters are
    trainable, except that when ``instrumental_only`` is true, those whose
    names mention lyric_encoder / lyric_proj / speaker_embedder / vocal /
    speech (case-insensitive) are frozen as well.

    Returns:
        (trainable_count, frozen_count), counted in scalar elements.
    """
    vocal_markers = ("lyric_encoder", "lyric_proj", "speaker_embedder", "vocal", "speech")
    trainable = 0
    frozen = 0
    for name, param in self.transformers.named_parameters():
        lowered = name.lower()
        is_lora = "lora" in lowered
        is_vocalish = any(marker in lowered for marker in vocal_markers)
        should_train = is_lora and not (instrumental_only and is_vocalish)
        param.requires_grad_(should_train)
        if should_train:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return trainable, frozen

def _load_ssl_models(self, checkpoint_dir):
    """
    Load the frozen SSL models that preprocess() uses.

    Loads MERT-v1-330M with a 48->24 kHz resampler and mHuBERT-147 with a
    48->16 kHz resampler, each paired with a Wav2Vec2 feature extractor.
    If the first MERT load fails, it patches the cached MERT config
    (``conv_pos_batch_norm=False``) and retries once.
    """
    mert_repo = "m-a-p/MERT-v1-330M"
    try:
        self.mert_model = AutoModel.from_pretrained(
            mert_repo, trust_remote_code=True, cache_dir=checkpoint_dir
        ).eval()
    except Exception as err:
        # Catch Exception rather than everything, so Ctrl-C / SystemExit
        # still propagate.
        logger.warning(
            f"[Pipeline.__init__] MERT load failed ({err!r}); patching cached "
            "config (conv_pos_batch_norm=False) and retrying"
        )
        # NOTE(review): this blob hash presumably identifies MERT's
        # config.json inside the Hugging Face cache; TODO confirm.
        blob = os.path.join(
            "models--m-a-p--MERT-v1-330M",
            "blobs",
            "14f770758c7fe5c5e8ead4fe0f8e5fa727eb6942",
        )
        default_config = os.path.join(
            os.path.expanduser("~"), ".cache", "huggingface", "hub", blob
        )
        candidates = [default_config]
        if checkpoint_dir is not None:
            # The model is loaded with cache_dir=checkpoint_dir, so its
            # cached files can live there instead of in the default cache.
            candidates.insert(0, os.path.join(checkpoint_dir, blob))
        # If no candidate exists, fall back to the default path so a missing
        # file still surfaces as an error.
        to_patch = [p for p in candidates if os.path.isfile(p)] or [default_config]
        for config_path in to_patch:
            with open(config_path, encoding="utf-8") as f:
                mert_config = json.load(f)
            mert_config["conv_pos_batch_norm"] = False
            with open(config_path, mode="w", encoding="utf-8") as f:
                json.dump(mert_config, f)
        self.mert_model = AutoModel.from_pretrained(
            mert_repo, trust_remote_code=True, cache_dir=checkpoint_dir
        ).eval()
    self.mert_model.requires_grad_(False)
    self.resampler_mert = torchaudio.transforms.Resample(
        orig_freq=48000, new_freq=24000
    )
    # NOTE(review): processor_mert and hubert_model are loaded without
    # cache_dir, unlike the other SSL assets; confirm this is intended.
    self.processor_mert = Wav2Vec2FeatureExtractor.from_pretrained(
        mert_repo, trust_remote_code=True
    )
    self.hubert_model = AutoModel.from_pretrained("utter-project/mHuBERT-147").eval()
    self.hubert_model.requires_grad_(False)
    self.resampler_mhubert = torchaudio.transforms.Resample(
        orig_freq=48000, new_freq=16000
    )
    self.processor_mhubert = Wav2Vec2FeatureExtractor.from_pretrained(
        "utter-project/mHuBERT-147",
        cache_dir=checkpoint_dir,
    )
def _prune_lightning_checkpoints(self, keep: int = 1) -> None:
    """
    Delete all but the ``keep`` newest ``*.ckpt`` files (by modification
    time) in the ``checkpoints`` folder under this run's logger ``log_dir``.

    This is best-effort: it returns quietly when the logger exposes no
    ``log_dir``, when the folder does not exist, or when a file cannot be
    stat'ed. A file that cannot be deleted is skipped.
    """
    try:
        run_dir = getattr(self.logger, "log_dir", None)
    except Exception:
        run_dir = None
    if not run_dir:
        return

    folder = Path(run_dir) / "checkpoints"
    if not folder.is_dir():
        return

    def _modified_at(path):
        return path.stat().st_mtime

    try:
        newest_first = sorted(folder.glob("*.ckpt"), key=_modified_at, reverse=True)
    except OSError:
        return

    for stale_ckpt in newest_first[keep:]:
        try:
            stale_ckpt.unlink()
        except OSError:
            # e.g. the file is still in use; leave it in place.
            continue
def infer_mert_ssl(self, target_wavs, wav_lengths):
    """
    Extract frame-level MERT features for every clip in the batch.

    ``target_wavs`` is N x 2 x T stereo audio at 48 kHz. ``wav_lengths``
    holds each clip's valid sample count at 48 kHz. Each clip is processed
    as follows:
    - downmix to mono and resample to 24 kHz;
    - normalize to zero mean / unit variance over its valid samples only;
    - cut into 5-second chunks, zero-padding the last one;
    - run all chunks through the frozen MERT model in one batch;
    - keep ``ceil(valid_samples / 320)`` frames per chunk (assumes a
      feature stride of 320 samples) and concatenate them back per clip.

    Returns:
        A list of N tensors, each (frames_i, hidden_size).
    """
    mono_24k = self.resampler_mert(target_wavs.mean(dim=1))
    batch_size = target_wavs.shape[0]
    valid_24k = wav_lengths // 2  # 48 kHz -> 24 kHz sample counts

    # Statistics over the unpadded region of each clip only.
    clip_mean = torch.stack(
        [mono_24k[b, : valid_24k[b]].mean() for b in range(batch_size)]
    )
    clip_var = torch.stack(
        [mono_24k[b, : valid_24k[b]].var() for b in range(batch_size)]
    )
    mono_24k = (mono_24k - clip_mean.view(-1, 1)) / torch.sqrt(
        clip_var.view(-1, 1) + 1e-7
    )

    window = 24000 * 5  # 5 seconds at 24 kHz
    chunks_per_clip = (valid_24k + window - 1) // window  # ceiling division

    padded_chunks = []
    valid_in_chunk = []
    for b in range(batch_size):
        clip = mono_24k[b]
        clip_len = valid_24k[b]
        for offset in range(0, clip_len, window):
            stop = min(offset + window, clip_len)
            piece = clip[offset:stop]
            shortfall = window - len(piece)
            if shortfall > 0:
                piece = F.pad(piece, (0, shortfall))  # zero-pad the tail
            padded_chunks.append(piece)
            valid_in_chunk.append(stop - offset)

    stacked = torch.stack(padded_chunks, dim=0)  # (total_chunks, window)
    with torch.no_grad():
        # (total_chunks, seq_len, hidden_size)
        features = self.mert_model(stacked).last_hidden_state

    # Drop the frames that were produced from zero padding.
    trimmed = [
        features[idx, : (n_valid + 319) // 320, :]
        for idx, n_valid in enumerate(valid_in_chunk)
    ]

    per_clip = []
    cursor = 0
    for b in range(batch_size):
        span = chunks_per_clip[b]
        per_clip.append(torch.cat(trimmed[cursor : cursor + span], dim=0))
        cursor += span
    return per_clip
def infer_mhubert_ssl(self, target_wavs, wav_lengths):
    """
    Extract frame-level mHuBERT features for every clip in the batch.

    ``target_wavs`` is N x 2 x T stereo audio at 48 kHz. ``wav_lengths``
    holds each clip's valid sample count at 48 kHz. Clips are processed as
    follows:
    - downmix to mono and resample to 16 kHz;
    - normalize over their valid samples (``wav_lengths // 3``);
    - cut into zero-padded 30-second chunks;
    - run all chunks through the frozen mHuBERT model in one batch;
    - trim each chunk to ``ceil(valid_samples / 320)`` frames (assumes a
      model stride of 320) and concatenate per clip.

    Returns:
        A list of N tensors, each (frames_i, hidden_size).
    """
    wav16 = self.resampler_mhubert(target_wavs.mean(dim=1))
    n_clips = target_wavs.shape[0]
    lengths16 = wav_lengths // 3  # 48 kHz -> 16 kHz sample counts

    # Zero-mean / unit-variance over the real (unpadded) audio of each clip.
    valid_views = [wav16[k, : lengths16[k]] for k in range(n_clips)]
    mu = torch.stack([view.mean() for view in valid_views])
    sigma2 = torch.stack([view.var() for view in valid_views])
    wav16 = (wav16 - mu.view(-1, 1)) / torch.sqrt(sigma2.view(-1, 1) + 1e-7)

    span = 16000 * 30  # 30 seconds at 16 kHz = 480,000 samples
    n_chunks = (lengths16 + span - 1) // span  # ceiling division

    pieces = []
    piece_lengths = []
    for k in range(n_clips):
        limit = lengths16[k]
        for begin in range(0, limit, span):
            finish = min(begin + span, limit)
            segment = wav16[k][begin:finish]
            if len(segment) < span:
                segment = F.pad(segment, (0, span - len(segment)))
            pieces.append(segment)
            piece_lengths.append(finish - begin)

    batch_of_chunks = torch.stack(pieces, dim=0)  # (total_chunks, span)
    with torch.no_grad():
        # (total_chunks, seq_len, hidden_size)
        hidden = self.hubert_model(batch_of_chunks).last_hidden_state

    trimmed = [hidden[j, : (n + 319) // 320, :] for j, n in enumerate(piece_lengths)]

    outputs = []
    start = 0
    for k in range(n_clips):
        count = n_chunks[k]
        outputs.append(torch.cat(trimmed[start : start + count], dim=0))
        start += count
    return outputs
def get_text_embeddings(self, texts, device, text_max_length=256):
    """
    Encode prompt strings with the frozen text encoder.

    Tokenizes ``texts`` (padded to the longest prompt and truncated to
    ``text_max_length`` tokens), moves the token tensors to ``device``
    (also moving the encoder there when its reported device differs), and
    runs the encoder without gradient tracking.

    Returns:
        (last_hidden_state, attention_mask): the encoder's last hidden
        states and the tokenizer's attention mask.
    """
    # Uses the module-level loguru logger; a function-local re-import is
    # unnecessary.
    logger.info("[get_text_embeddings] start (real encoder)")
    # Tokenize on CPU
    inputs = self.text_tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=text_max_length,
    )
    # Move token ids + masks to the target device.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Make sure the encoder itself is on the same device.
    if self.text_encoder_model.device != device:
        logger.info(
            f"[get_text_embeddings] moving text_encoder_model to device={device}"
        )
        self.text_encoder_model.to(device)
    with torch.no_grad():
        outputs = self.text_encoder_model(**inputs)
        last_hidden_states = outputs.last_hidden_state
    attention_mask = inputs["attention_mask"]
    logger.info(
        f"[get_text_embeddings] done, hidden_states.shape={tuple(last_hidden_states.shape)}"
    )
    return last_hidden_states, attention_mask
def preprocess(self, batch, train=True):
    """
    Turn a raw dataloader batch into transformer inputs.

    Steps:
      1. Optionally crop audio to ``hparams.max_audio_seconds`` (48 kHz).
      2. If ``train`` and ``ssl_coeff > 0``, extract MERT/mHuBERT features.
      3. Encode the prompts with the text encoder.
      4. Encode the audio into DCAE latents.
      5. If ``train``, apply classifier-free-guidance dropout to the text,
         speaker and lyric conditions.

    Returns:
        A 10-tuple (keys, target_latents, attention_mask,
        encoder_text_hidden_states, text_attention_mask, speaker_embds,
        lyric_token_ids, lyric_mask, mert_ssl_hidden_states,
        mhubert_ssl_hidden_states). The two SSL entries are None when SSL
        extraction is skipped.
    """
    logger.info("[preprocess] start")
    target_wavs = batch["target_wavs"]
    wav_lengths = batch["wav_lengths"]
    dtype = target_wavs.dtype
    bs = target_wavs.shape[0]
    device = target_wavs.device

    # ------------------------------------------------------------------
    # 1) Optional cropping: avoid feeding entire multi-minute songs
    # ------------------------------------------------------------------
    max_audio_seconds = getattr(self.hparams, "max_audio_seconds", 60.0)
    if max_audio_seconds is not None and max_audio_seconds > 0:
        max_samples_48k = int(max_audio_seconds * 48000)
        if target_wavs.shape[-1] > max_samples_48k:
            old_samples = target_wavs.shape[-1]
            logger.info(
                f"[preprocess] cropping audio from {old_samples} to "
                f"{max_samples_48k} samples "
                f"({old_samples / 48000.0:.1f}s -> {max_audio_seconds:.1f}s)"
            )
            # Trim the time dimension and clamp reported lengths to match.
            target_wavs = target_wavs[..., :max_samples_48k]
            wav_lengths = torch.clamp(wav_lengths, max=max_samples_48k)

    # ------------------------------------------------------------------
    # 2) SSL features (MERT + mHuBERT), controlled by ssl_coeff
    # ------------------------------------------------------------------
    mert_ssl_hidden_states = None
    mhubert_ssl_hidden_states = None
    if train and self.ssl_coeff > 0:
        logger.info(
            f"[preprocess] SSL ENABLED (ssl_coeff={self.ssl_coeff}); "
            "running MERT/mHuBERT"
        )
        # Device-agnostic autocast: fall back to "cpu" for unknown devices.
        # NOTE(review): with float32 inputs, autocast may report the dtype
        # as unsupported and disable itself; confirm this context is needed.
        device_type = device.type if device.type in ["cuda", "cpu", "mps"] else "cpu"
        with torch.amp.autocast(device_type=device_type, dtype=dtype):
            mert_ssl_hidden_states = self.infer_mert_ssl(target_wavs, wav_lengths)
            mhubert_ssl_hidden_states = self.infer_mhubert_ssl(
                target_wavs, wav_lengths
            )
    else:
        logger.info(
            f"[preprocess] SSL DISABLED (train={train}, ssl_coeff={self.ssl_coeff}); "
            "skipping MERT/mHuBERT"
        )

    # ------------------------------------------------------------------
    # 3) Text embeddings
    # ------------------------------------------------------------------
    texts = batch["prompts"]
    logger.info(
        f"[preprocess] before text encoder; batch_size={bs}, "
        f"num_chars_first_prompt={len(texts[0]) if len(texts) > 0 else 0}"
    )
    encoder_text_hidden_states, text_attention_mask = self.get_text_embeddings(
        texts, device
    )
    encoder_text_hidden_states = encoder_text_hidden_states.to(dtype)
    logger.info(
        "[preprocess] after text encoder; "
        f"encoder_text_hidden_states.shape={tuple(encoder_text_hidden_states.shape)}"
    )

    # ------------------------------------------------------------------
    # 4) DCAE encode to latents
    # ------------------------------------------------------------------
    logger.info(
        "[preprocess] before DCAE.encode; "
        f"target_wavs.shape={tuple(target_wavs.shape)}, "
        f"wav_lengths[0]={int(wav_lengths[0].item()) if wav_lengths.numel() > 0 else -1}"
    )
    target_latents, _ = self.dcae.encode(target_wavs, wav_lengths)
    logger.info(
        "[preprocess] after DCAE.encode; "
        f"target_latents.shape={tuple(target_latents.shape)}"
    )
    # All latent frames are marked valid (mask of ones over the last axis).
    attention_mask = torch.ones(
        bs, target_latents.shape[-1], device=device, dtype=dtype
    )
    speaker_embds = batch["speaker_embs"].to(dtype)
    keys = batch["keys"]
    lyric_token_ids = batch["lyric_token_ids"]
    lyric_mask = batch["lyric_masks"]

    # ------------------------------------------------------------------
    # 5) Classifier-free guidance dropout (training only)
    # ------------------------------------------------------------------
    if train:
        # One independent draw per condition. The draw order (text, speaker,
        # lyrics) fixes how the RNG stream is consumed, so keep it stable.
        keep_text = self._cfg_keep_mask(bs, 0.15, device)
        # N x T x 768
        encoder_text_hidden_states = self._zero_dropped_rows(
            encoder_text_hidden_states, keep_text
        )
        keep_speaker = self._cfg_keep_mask(bs, 0.50, device)
        # N x 512
        speaker_embds = self._zero_dropped_rows(speaker_embds, keep_speaker)
        # Lyric ids and their mask share a single draw.
        keep_lyrics = self._cfg_keep_mask(bs, 0.15, device)
        lyric_token_ids = self._zero_dropped_rows(lyric_token_ids, keep_lyrics)
        lyric_mask = self._zero_dropped_rows(lyric_mask, keep_lyrics)

    logger.info("[preprocess] done")
    return (
        keys,
        target_latents,
        attention_mask,
        encoder_text_hidden_states,
        text_attention_mask,
        speaker_embds,
        lyric_token_ids,
        lyric_mask,
        mert_ssl_hidden_states,
        mhubert_ssl_hidden_states,
    )

@staticmethod
def _cfg_keep_mask(batch_size, drop_prob, device):
    """Sample a (batch_size,) bool mask that is False (drop) with probability ``drop_prob``."""
    return torch.rand(size=(batch_size,), device=device) >= drop_prob

@staticmethod
def _zero_dropped_rows(tensor, keep):
    """Zero every batch row of ``tensor`` whose entry in the bool mask ``keep`` is False."""
    # Reshape (N,) -> (N, 1, ..., 1) so the mask broadcasts over trailing dims.
    keep = keep.view(-1, *([1] * (tensor.dim() - 1)))
    return torch.where(keep, tensor, torch.zeros_like(tensor))
def get_scheduler(self):
    """Build the flow-matching Euler scheduler with ``self.T`` train steps and ``hparams.shift``."""
    scheduler_kwargs = {
        "num_train_timesteps": self.T,
        "shift": self.hparams.shift,
    }
    return FlowMatchEulerDiscreteScheduler(**scheduler_kwargs)
def configure_optimizers(self):
    """
    Build an AdamW optimizer over the trainable transformer parameters and a
    per-step LambdaLR schedule.

    LR multiplier:
      * steps [0, warmup_steps): linear warmup from 0 towards 1;
      * afterwards, if ``max_steps`` is a positive number: linear decay to 0
        at ``max_steps``, clamped at 0;
      * afterwards, if ``max_steps`` is None or <= 0: constant 1.0.

    Raises:
        RuntimeError: if no transformer parameter has ``requires_grad=True``.
    """
    params_to_train = [
        param for param in self.transformers.parameters() if param.requires_grad
    ]
    if len(params_to_train) == 0:
        raise RuntimeError(
            "[Pipeline.configure_optimizers] No trainable parameters found in "
            "self.transformers; LoRA adapter is not active / all params frozen."
        )

    hp = self.hparams
    optimizer = torch.optim.AdamW(
        params=[{"params": params_to_train}],
        lr=hp.learning_rate,
        weight_decay=hp.weight_decay,
        betas=(0.8, 0.9),
    )

    total_steps = hp.max_steps
    warmup = hp.warmup_steps
    decay_enabled = total_steps is not None and total_steps > 0

    def lr_lambda(current_step):
        if current_step < warmup:
            return float(current_step) / float(max(1, warmup))
        if not decay_enabled:
            return 1.0
        decay_span = float(max(1, total_steps - warmup))
        return max(0.0, 1.0 - float(current_step - warmup) / decay_span)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def train_dataloader(self):
    """
    Create the shuffled training DataLoader over ``hparams.dataset_path``.

    The dataset is also stored on ``self.train_dataset``. Workers are kept
    alive between epochs whenever ``num_workers > 0``.
    """
    hp = self.hparams
    dataset = Text2MusicDataset(
        train=True,
        train_dataset_path=hp.dataset_path,
    )
    self.train_dataset = dataset
    worker_count = hp.num_workers
    loader_kwargs = {
        "shuffle": True,
        "num_workers": worker_count,
        "pin_memory": False,
        "persistent_workers": worker_count > 0,
        "collate_fn": dataset.collate_fn,
    }
    return DataLoader(dataset, **loader_kwargs)
def get_sd3_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
    """
    Look up the scheduler sigma for each entry of ``timesteps``.

    Each timestep must equal exactly one value in
    ``self.scheduler.timesteps``; ``.item()`` raises otherwise. The 1-D
    result gets trailing singleton dims until it has ``n_dim`` dims, so it
    broadcasts against a batch of latents.
    """
    all_sigmas = self.scheduler.sigmas.to(device=device, dtype=dtype)
    schedule = self.scheduler.timesteps.to(device)
    wanted = timesteps.to(device)

    positions = []
    for t in wanted:
        positions.append((schedule == t).nonzero().item())

    sigma = all_sigmas[positions].flatten()
    missing_dims = n_dim - sigma.dim()
    if missing_dims > 0:
        sigma = sigma.reshape(tuple(sigma.shape) + (1,) * missing_dims)
    return sigma
def get_timestep(self, bsz, device):
    """
    Sample ``bsz`` training timesteps from ``self.scheduler.timesteps``.

    Only ``hparams.timestep_densities_type == "logit_normal"`` is supported
    (SD3 paper, Sec. 3.1, ``rf/lognorm(m, s)``):
      * draw u ~ N(logit_mean, logit_std) on CPU;
      * squash u with a sigmoid;
      * scale it to an integer index clamped to [0, num_train_timesteps - 1];
      * gather the matching scheduler timesteps onto ``device``.

    Raises:
        ValueError: for any other ``timestep_densities_type``. Without this
            check the function would fail with an UnboundLocalError.
    """
    density = self.hparams.timestep_densities_type
    if density != "logit_normal":
        raise ValueError(
            f"Unsupported timestep_densities_type={density!r}; "
            "only 'logit_normal' is implemented"
        )
    num_steps = self.scheduler.config.num_train_timesteps
    u = torch.normal(
        mean=self.hparams.logit_mean,
        std=self.hparams.logit_std,
        size=(bsz,),
        device="cpu",
    )
    u = torch.sigmoid(u)
    indices = torch.clamp((u * num_steps).long(), 0, num_steps - 1)
    return self.scheduler.timesteps[indices].to(device)
def run_step(self, batch, batch_idx):
    """
    Run one flow-matching training step and return the scalar loss.

    Steps:
      1. Noise the DCAE latents at a sampled timestep:
         x_t = sigma * noise + (1 - sigma) * latents.
      2. Run the LoRA-adapted transformer on x_t.
      3. Form ``pred * (-sigma) + x_t`` and take a masked MSE against the
         clean latents.
      4. Average any projection (SSL) losses returned by the transformer and
         add them with weight ``self.ssl_coeff``.

    Logs the denoising loss, each projection loss, the total loss and, if a
    scheduler exists, the learning rate.

    NOTE: mid-training diffusion previews (plot_step / predict_step /
    diffusion_process) are intentionally not invoked here because they
    interfered with Lightning's autograd/precision contexts.

    Raises:
        RuntimeError: if the denoising loss does not require grad.
    """
    (
        _keys,
        latents,
        latent_mask,
        text_states,
        text_mask,
        speaker_embeds,
        lyric_ids,
        lyric_mask,
        mert_states,
        mhubert_states,
    ) = self.preprocess(batch, train=True)

    device = latents.device
    dtype = latents.dtype
    bsz = latents.shape[0]

    # Noise first, then timesteps: this fixes the RNG consumption order.
    noise = torch.randn_like(latents, device=device)
    timesteps = self.get_timestep(bsz, device)
    sigmas = self.get_sd3_sigmas(
        timesteps=timesteps,
        device=device,
        n_dim=latents.ndim,
        dtype=dtype,
    )
    noisy_latents = sigmas * noise + (1.0 - sigmas) * latents

    # Pass whichever SSL feature lists are present (None when SSL is off).
    ssl_states = None
    if mert_states is not None or mhubert_states is not None:
        ssl_states = [s for s in (mert_states, mhubert_states) if s is not None]

    output = self.transformers(
        hidden_states=noisy_latents,
        attention_mask=latent_mask,
        encoder_text_hidden_states=text_states,
        text_attention_mask=text_mask,
        speaker_embeds=speaker_embeds,
        lyric_token_idx=lyric_ids,
        lyric_mask=lyric_mask,
        timestep=timesteps.to(device).to(dtype),
        ssl_hidden_states=ssl_states,
    )
    proj_losses = output.proj_losses
    # Output preconditioning, following Sec. 5 of
    # https://arxiv.org/abs/2206.00364.
    prediction = output.sample * (-sigmas) + noisy_latents

    # Broadcast the (N, T) latent mask over the channel/width axes.
    channels, width = latents.shape[1], latents.shape[2]
    mask = latent_mask[:, None, None, :].expand(-1, channels, width, -1)
    pred_flat = (prediction * mask).reshape(bsz, -1).contiguous()
    target_flat = (latents * mask).reshape(bsz, -1).contiguous()
    per_sample = F.mse_loss(pred_flat, target_flat, reduction="none").mean(1)
    loss = (per_sample * mask.reshape(bsz, -1).mean(1)).mean()

    # Fail loudly if something has cut the autograd graph.
    if not loss.requires_grad:
        logger.error(
            f"[run_step] loss has no grad at global_step={self.global_step}, "
            f"batch_idx={batch_idx}; something disabled grad tracking "
            "around the training step."
        )
        raise RuntimeError(
            "Loss tensor does not require grad in run_step; likely an interaction "
            "with a no_grad/autocast context."
        )

    prefix = "train"
    step_log = {"on_step": True, "on_epoch": False, "prog_bar": True}
    self.log(f"{prefix}/denoising_loss", loss, **step_log)

    total_proj_loss = 0.0
    for loss_name, loss_value in proj_losses:
        self.log(f"{prefix}/{loss_name}_loss", loss_value, **step_log)
        total_proj_loss += loss_value
    if len(proj_losses) > 0:
        total_proj_loss = total_proj_loss / len(proj_losses)

    loss = loss + total_proj_loss * self.ssl_coeff
    self.log(f"{prefix}/loss", loss, **step_log)

    scheduler = self.lr_schedulers()
    if scheduler is not None:
        self.log(f"{prefix}/learning_rate", scheduler.get_last_lr()[0], **step_log)
    return loss
# The commented-out run_step below is an earlier variant that also ran
# diffusion previews (plot_step). It was disabled because invoking
# plot_step caused errors during training.
# def run_step(self, batch, batch_idx):
# # ------------------------------------------------------------------
# # 1) Optional eval preview under no_grad
# # ------------------------------------------------------------------
# try:
# every_plot = getattr(self.hparams, "every_plot_step", 0)
# if (
# every_plot
# and self.global_step > 0
# and (self.global_step % every_plot) == 0
# ):
# # Run plotting / diffusion fully under no_grad so it cannot
# # interfere with the actual training graph.
# with torch.no_grad():
# self.plot_step(batch, batch_idx)
# except Exception as e:
# logger.warning(
# f"[run_step] plot_step failed at global_step={self.global_step}, "
# f"batch_idx={batch_idx}: {e}"
# )
# # ------------------------------------------------------------------
# # 2) Actual training step, with gradients forced ON
# # ------------------------------------------------------------------
# with torch.set_grad_enabled(True):
# (
# keys,
# target_latents,
# attention_mask,
# encoder_text_hidden_states,
# text_attention_mask,
# speaker_embds,
# lyric_token_ids,
# lyric_mask,
# mert_ssl_hidden_states,
# mhubert_ssl_hidden_states,
# ) = self.preprocess(batch, train=True)
# target_image = target_latents
# device = target_image.device
# dtype = target_image.dtype
# # Step 1: Generate random noise, initialize settings
# noise = torch.randn_like(target_image, device=device)
# bsz = target_image.shape[0]
# timesteps = self.get_timestep(bsz, device)
# # Add noise according to flow matching.
# sigmas = self.get_sd3_sigmas(
# timesteps=timesteps,
# device=device,
# n_dim=target_image.ndim,
# dtype=dtype,
# )
# noisy_image = sigmas * noise + (1.0 - sigmas) * target_image
# # This is the flow-matching target for vanilla SD3.
# target = target_image
# # SSL constraints for CLAP and vocal_latent_channel2
# all_ssl_hiden_states = None
# if mert_ssl_hidden_states is not None or mhubert_ssl_hidden_states is not None:
# all_ssl_hiden_states = []
# if mert_ssl_hidden_states is not None:
# all_ssl_hiden_states.append(mert_ssl_hidden_states)
# if mhubert_ssl_hidden_states is not None:
# all_ssl_hiden_states.append(mhubert_ssl_hidden_states)
# # N x H -> N x c x W x H
# x = noisy_image
# # Step 5: Predict noise
# transformer_output = self.transformers(
# hidden_states=x,
# attention_mask=attention_mask,
# encoder_text_hidden_states=encoder_text_hidden_states,
# text_attention_mask=text_attention_mask,
# speaker_embeds=speaker_embds,
# lyric_token_idx=lyric_token_ids,
# lyric_mask=lyric_mask,
# timestep=timesteps.to(device).to(dtype),
# ssl_hidden_states=all_ssl_hiden_states,
# )
# model_pred = transformer_output.sample
# proj_losses = transformer_output.proj_losses
# # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# # Preconditioning of the model outputs.
# model_pred = model_pred * (-sigmas) + noisy_image
# # Compute loss. Only calculate loss where chunk_mask is 1 and there is no padding
# # N x T x 64
# # N x T -> N x c x W x T
# mask = (
# attention_mask.unsqueeze(1)
# .unsqueeze(1)
# .expand(-1, target_image.shape[1], target_image.shape[2], -1)
# )
# selected_model_pred = (model_pred * mask).reshape(bsz, -1).contiguous()
# selected_target = (target * mask).reshape(bsz, -1).contiguous()
# loss = F.mse_loss(selected_model_pred, selected_target, reduction="none")
# loss = loss.mean(1)
# loss = loss * mask.reshape(bsz, -1).mean(1)
# loss = loss.mean()
# prefix = "train"
# self.log(
# f"{prefix}/denoising_loss",
# loss,
# on_step=True,
# on_epoch=False,
# prog_bar=True,
# )
# total_proj_loss = 0.0
# for k, v in proj_losses:
# self.log(
# f"{prefix}/{k}_loss",
# v,
# on_step=True,
# on_epoch=False,
# prog_bar=True,
# )
# total_proj_loss += v
# if len(proj_losses) > 0:
# total_proj_loss = total_proj_loss / len(proj_losses)
# loss = loss + total_proj_loss * self.ssl_coeff
# self.log(
# f"{prefix}/loss",
# loss,
# on_step=True,
# on_epoch=False,
# prog_bar=True,
# )
# # Sanity check: if this ever trips again, we *know* gradients are off.
# if not loss.requires_grad:
# logger.error(
# f"[run_step] loss has no grad at global_step={self.global_step}, "
# f"batch_idx={batch_idx}; torch.is_grad_enabled()={torch.is_grad_enabled()}"
# )
# raise RuntimeError(
# "Loss tensor does not require grad in run_step; likely an "
# "interaction with a no_grad/autocast context."
# )
# # Log learning rate if scheduler exists
# if self.lr_schedulers() is not None:
# learning_rate = self.lr_schedulers().get_last_lr()[0]
# self.log(
# f"{prefix}/learning_rate",
# learning_rate,
# on_step=True,
# on_epoch=False,
# prog_bar=True,
# )