From c29b7afbd8f4f1fe63e65fc8721035941ad5c512 Mon Sep 17 00:00:00 2001
From: Ryan
Date: Wed, 18 Mar 2026 16:36:44 -0700
Subject: [PATCH 1/3] [TTS] Add code for training semantic codec

Signed-off-by: Ryan
---
 nemo/collections/tts/data/vocoder_dataset.py  |  21 +-
 nemo/collections/tts/models/audio_codec.py    | 224 +++++++++++-------
 .../tts/modules/audio_codec_modules.py        | 153 +++++++++++-
 .../tts/parts/utils/tts_dataset_utils.py      |  13 +
 .../extend_lhotse_shards_with_audio_codes.py  |  11 +-
 5 files changed, 320 insertions(+), 102 deletions(-)

diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py
index 84be82f47959..05c1213c3a53 100644
--- a/nemo/collections/tts/data/vocoder_dataset.py
+++ b/nemo/collections/tts/data/vocoder_dataset.py
@@ -16,7 +16,7 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import librosa
 import soundfile as sf
@@ -30,6 +30,7 @@
     filter_dataset_by_duration,
     get_weighted_sampler,
     load_audio,
+    resample_batch,
     sample_audio,
     stack_tensors,
 )
@@ -56,7 +57,7 @@ class DatasetSample:
     audio_dir: Path
 
 
-def audio_collate_fn(batch: List[dict]):
+def audio_collate_fn(batch: List[dict], resample_rates: Optional[Tuple[int]] = None):
     dataset_name_list = []
     audio_filepath_list = []
     audio_list = []
@@ -73,6 +74,14 @@ def audio_collate_fn(batch: List[dict]):
 
     batch_audio = stack_tensors(audio_list, max_lens=[audio_max_len])
 
+    if resample_rates:
+        batch_audio, batch_audio_len = resample_batch(
+            audio=batch_audio,
+            audio_len=batch_audio_len,
+            input_sample_rate=resample_rates[0],
+            output_sample_rate=resample_rates[1],
+        )
+
     batch_dict = {
         "dataset_names": dataset_name_list,
         "audio_filepaths": audio_filepath_list,
@@ -135,6 +144,7 @@ def __init__(
         self,
         dataset_meta: Dict,
         sample_rate: int,
+        resample_rate: Optional[int] = None,
         n_samples: Optional[int] = None,
         weighted_sampling_steps_per_epoch: Optional[int] = None,
         feature_processors: Optional[Dict[str, FeatureProcessor]] = None,
@@ -146,6 +156,11 @@ def __init__(
         super().__init__()
 
         self.sample_rate = sample_rate
+        if resample_rate and self.sample_rate != resample_rate:
+            self.resample_rates = [sample_rate, resample_rate]
+        else:
+            self.resample_rates = None
+
         self.n_samples = n_samples
         self.trunc_duration = trunc_duration
         self.volume_norm = volume_norm
@@ -221,7 +236,7 @@ def __getitem__(self, index):
         return example
 
     def collate_fn(self, batch):
-        return audio_collate_fn(batch)
+        return audio_collate_fn(batch, resample_rates=self.resample_rates)
 
 
 class TarredVocoderDataset(IterableDataset):
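For orientation, a hedged usage sketch of the new collate-time resampling path (the dataset_meta layout and file paths here are assumptions for illustration, not taken from this patch):

# Sketch: load audio at 22.05 kHz and let collate_fn resample finished batches to
# 16 kHz, e.g. so a 16 kHz semantic codec can consume the same batch.
import torch
from nemo.collections.tts.data.vocoder_dataset import VocoderDataset

dataset = VocoderDataset(
    dataset_meta={"train_set": {"manifest_path": "train.json", "audio_dir": "audio/"}},  # hypothetical schema
    sample_rate=22050,    # rate audio is loaded at
    resample_rate=16000,  # triggers resample_batch() inside collate_fn
)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)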
diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py
index 2cdd5f0f8c9c..fae48fcb0221 100644
--- a/nemo/collections/tts/models/audio_codec.py
+++ b/nemo/collections/tts/models/audio_codec.py
@@ -25,8 +25,8 @@
 from lightning.pytorch import Trainer
 from omegaconf import DictConfig, OmegaConf, open_dict
 
-from nemo.collections.audio.parts.utils.transforms import Resample, resample
-from nemo.collections.common.parts.utils import mask_sequence_tensor
+from nemo.collections.audio.parts.utils.transforms import Resample
+from nemo.collections.tts.data.vocoder_dataset import VocoderDataset
 from nemo.collections.tts.losses.audio_codec_loss import (
     FeatureMatchingLoss,
     MultiResolutionMelLoss,
@@ -39,6 +39,7 @@
 from nemo.collections.tts.modules.common import GaussianDropout
 from nemo.collections.tts.parts.utils.callbacks import LoggingCallback
 from nemo.collections.tts.parts.utils.helpers import get_batch_size, get_num_workers
+from nemo.collections.tts.parts.utils.tts_dataset_utils import resample_batch
 from nemo.core import ModelPT
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.neural_types.elements import (
@@ -110,7 +111,43 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.audio_decoder = instantiate(cfg.audio_decoder)
 
         # Discriminator setup
-        self.discriminator = instantiate(cfg.discriminator)
+        if cfg.get("discriminator"):
+            self.discriminator = instantiate(cfg.discriminator)
+        else:
+            self.discriminator = None
+
+        # If a 'semantic_codec' config is provided, the semantic codec is built from it. If 'semantic_codec_path'
+        # is provided instead, it is restored from that path. Either way it is registered as a submodule and
+        # automatically reloaded from the 'semantic_codec' config field afterwards.
+        if cfg.get("semantic_codec"):
+            semantic_codec_cfg = cfg.get("semantic_codec")
+            semantic_codec = AudioCodecModel(cfg=semantic_codec_cfg)
+        elif cfg.get("semantic_codec_path"):
+            semantic_codec_path = cfg.get("semantic_codec_path")
+            semantic_codec = AudioCodecModel.restore_from(semantic_codec_path)
+        else:
+            semantic_codec = None
+
+        if semantic_codec is not None:
+            self.register_nemo_submodule(name="semantic_codec", config_field="semantic_codec", model=semantic_codec)
+            self.semantic_codec.eval()
+            self.semantic_codec.freeze()
+        else:
+            self.semantic_codec = None
+
+        # Optional config for using semantic distillation loss
+        self.use_slm_loss = cfg.get("use_slm_loss", False)
+        if self.use_slm_loss:
+            self.slm_encoder = instantiate(cfg.get("slm_encoder"))
+            self.slm_encoder.eval()
+            self.slm_encoder.freeze()
+            self.slm_decoder = instantiate(cfg.slm_decoder)
+            self.slm_loss_fn = torch.nn.MSELoss()
+            self.slm_loss_scale = cfg.get("slm_loss_scale", 1.0)
+        else:
+            self.slm_encoder = None
+            self.slm_decoder = None
+            self.slm_loss_fn = None
+            self.slm_loss_scale = None
 
         # Mel loss setup
         loss_resolutions = cfg.loss_resolutions
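A rough sketch of how the new config fields might be wired together (field names come from the code above; the _target_ paths and sizes are assumptions, and patch 3 below later renames slm_decoder to slm_predictor):

# Sketch: illustrative config enabling the frozen semantic codec and SLM distillation.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "semantic_codec_path": "/models/semantic_codec.nemo",  # hypothetical checkpoint path
        "use_slm_loss": True,
        "slm_encoder": {"_target_": "nemo.collections.tts.modules.audio_codec_modules.SLMEncoder"},
        "slm_decoder": {
            "_target_": "nemo.collections.tts.modules.audio_codec_modules.SLMDecoder",
            "in_channels": 512,   # hypothetical sizes
            "hidden_dim": 256,
            "out_channels": 1024,
        },
        "slm_loss_scale": 1.0,
    }
)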
@@ -145,6 +182,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.gen_loss_fn = instantiate(cfg.generator_loss)
         self.disc_loss_fn = instantiate(cfg.discriminator_loss)
 
+        self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0)
+
         if "mmd_loss" in cfg:
             self.mmd_loss_fn = instantiate(cfg.mmd_loss)
             self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
@@ -192,16 +231,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
                 orig_freq=self.sample_rate, new_freq=self.speaker_encoder.audio_config["sample_rate"]
             )
 
-        # Disabled for now as it is not used in final model
-        self.use_asr_consitency_loss = False
-        self.acl_loss_scale = False
-        # self.use_asr_consitency_loss = cfg.get("use_asr_consitency_loss", False)
-        # self.acl_loss_scale = cfg.get("acl_loss_scale", False)
-        # if self.use_asr_consitency_loss:
-        #     self.phoneme_asr_model = PhonemeASR(input_sr=self.sample_rate)
-        #     self.phoneme_asr_model.freeze()
-        #     # self.acl_loss = CrossEntropyLoss()
-        #     print("Phoneme ASR model loaded and frozen !!")
         self.disc_start_epoch = cfg.get("disc_start_epoch", 0)
 
         # Log setup
@@ -237,8 +266,11 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
                 del state_dict[key]
-            if "discriminator" in key and ".slm_model.ssl_model." in key:
+            if "discriminator" in key and ".slm_model.slm_model." in key:
+                del state_dict[key]
+            if key.startswith("slm_encoder."):
                 del state_dict[key]
+
         return state_dict
 
     def load_state_dict(self, state_dict, strict=True):
@@ -246,7 +278,9 @@ def load_state_dict(self, state_dict, strict=True):
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
                 del state_dict[key]
-            if "discriminator" in key and ".slm_model.ssl_model." in key:
+            if "discriminator" in key and ".slm_model.slm_model." in key:
+                del state_dict[key]
+            if key.startswith("slm_encoder."):
                 del state_dict[key]
         super().load_state_dict(state_dict, strict=False)
@@ -284,8 +318,19 @@ def encode_audio(
         Returns:
             Encoder output `encoded` and its length in number of frames `encoded_len`
         """
-        audio, audio_len = self.preprocess_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
-        encoded, encoded_len = self.audio_encoder(audio=audio, audio_len=audio_len)
+        if not sample_rate:
+            sample_rate = self.sample_rate
+
+        audio_preprocessed, audio_preprocessed_len = self.preprocess_audio(
+            audio=audio, audio_len=audio_len, sample_rate=sample_rate
+        )
+        encoded, encoded_len = self.audio_encoder(audio=audio_preprocessed, audio_len=audio_preprocessed_len)
+
+        if self.semantic_codec is not None:
+            semantic, _ = self.semantic_codec.encode_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
+            semantic = semantic.detach()
+            encoded = torch.concat([semantic, encoded], dim=1)
+
         return encoded, encoded_len
 
     @typecheck(
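The concatenation in encode_audio stacks semantic and acoustic features along the channel axis; a standalone sketch of the shape arithmetic (channel sizes are hypothetical):

# Sketch: semantic and acoustic encodings are concatenated on dim=1, so both
# encoders must produce the same number of frames T for a given input.
import torch

B, T = 4, 100
semantic = torch.randn(B, 256, T)  # [B, D_semantic, T], hypothetical D
acoustic = torch.randn(B, 512, T)  # [B, D_acoustic, T], hypothetical D
encoded = torch.concat([semantic, acoustic], dim=1)
print(encoded.shape)  # torch.Size([4, 768, 100])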
@@ -489,13 +534,9 @@ def pad_audio(self, audio, audio_len, samples_per_frame):
 
     def preprocess_audio(self, audio, audio_len, sample_rate):
         if sample_rate and sample_rate != self.sample_rate:
-            audio = resample(waveform=audio, orig_freq=sample_rate, new_freq=self.sample_rate)
-            audio_len_scaled = audio_len.long() * self.sample_rate
-            new_audio_len = audio_len_scaled / sample_rate
-            # To avoid rounding issues at lower precisions, do not call torch.ceil when the length is divisible by the sample rate
-            audio_len = torch.where(audio_len_scaled % sample_rate == 0, new_audio_len, torch.ceil(new_audio_len))
-            audio_len = audio_len.int()
-            audio = mask_sequence_tensor(audio, audio_len)
+            audio, audio_len = resample_batch(
+                audio=audio, audio_len=audio_len, input_sample_rate=sample_rate, output_sample_rate=self.sample_rate
+            )
 
         audio, audio_len = self.pad_audio(audio=audio, audio_len=audio_len, samples_per_frame=self.samples_per_frame)
         return audio, audio_len
@@ -531,7 +572,14 @@ def _process_batch(self, batch):
 
         # [B, T]
         audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)
 
-        return audio, audio_len, audio_gen, commit_loss, encoded
+        if self.training and self.use_slm_loss:
+            slm_emb = self.slm_encoder(audio=audio)
+            slm_emb_pred = self.slm_decoder(inputs=encoded)
+        else:
+            slm_emb = None
+            slm_emb_pred = None
+
+        return audio, audio_len, audio_gen, commit_loss, encoded, slm_emb, slm_emb_pred
 
     @property
     def disc_update_prob(self) -> float:
@@ -549,16 +597,20 @@ def should_update_disc(self, batch_idx) -> bool:
         return disc_update_step < self.disc_updates_per_period
 
     def training_step(self, batch, batch_idx):
-        optim_gen, optim_disc = self.optimizers()
+        if self.discriminator is None:
+            optim_gen = self.optimizers()
+            optim_disc = None
+        else:
+            optim_gen, optim_disc = self.optimizers()
 
-        audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)
+        audio, audio_len, audio_gen, commit_loss, codes, slm_emb, slm_emb_pred = self._process_batch(batch)
 
         metrics = {
             "global_step": self.global_step,
            "lr": optim_gen.param_groups[0]['lr'],
         }
 
-        if self.should_update_disc(batch_idx):
+        if optim_disc is not None and self.should_update_disc(batch_idx):
             # Train discriminator
             disc_scores_real, disc_scores_gen, _, _ = self.discriminator(
                 audio_real=audio, audio_gen=audio_gen.detach()
             )
@@ -599,17 +651,19 @@ def training_step(self, batch, batch_idx):
             metrics["g_loss_si_sdr"] = loss_si_sdr
             generator_losses.append(self.si_sdr_loss_scale * loss_si_sdr)
 
-        _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen)
+        if optim_disc is not None:
+            _, disc_scores_gen, fmaps_real, fmaps_gen = self.discriminator(audio_real=audio, audio_gen=audio_gen)
 
-        if self.gen_loss_scale:
-            loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen)
-            metrics["g_loss_gen"] = loss_gen
-            generator_losses.append(self.gen_loss_scale * loss_gen)
+            if self.gen_loss_scale:
+                loss_gen = self.gen_loss_fn(disc_scores_gen=disc_scores_gen)
+                metrics["g_loss_gen"] = loss_gen
+                generator_losses.append(self.gen_loss_scale * loss_gen)
 
-        if self.feature_loss_scale:
-            loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
-            metrics["g_loss_feature"] = loss_feature
-            generator_losses.append(self.feature_loss_scale * loss_feature)
+            if self.feature_loss_scale:
+                loss_feature = self.feature_loss_fn(fmaps_real=fmaps_real, fmaps_gen=fmaps_gen)
+                metrics["g_loss_feature"] = loss_feature
+                generator_losses.append(self.feature_loss_scale * loss_feature)
 
         if self.commit_loss_scale:
             metrics["g_loss_commit"] = commit_loss
@@ -627,6 +681,11 @@ def training_step(self, batch, batch_idx):
             if self.current_epoch >= self.mmd_loss_start_epoch:
                 generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time)
 
+        if self.use_slm_loss:
+            loss_slm = self.slm_loss_fn(input=slm_emb_pred, target=slm_emb)
+            metrics["g_loss_slm"] = loss_slm
+            generator_losses.append(self.slm_loss_scale * loss_slm)
+
         # compute embeddings for speaker consistency loss
         if self.use_scl_loss:
             # concate generated and GT waveforms
             audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
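A minimal sketch of the distillation objective added above: MSE between predicted and frozen-target SLM embeddings, both [B, D, T] (shapes here are hypothetical):

# Sketch: in the model, slm_emb comes from the frozen SLMEncoder and slm_emb_pred
# from the trainable SLM decoder; gradients reach the codec encoder through the latter.
import torch

slm_emb = torch.randn(4, 1024, 100)                           # frozen target
slm_emb_pred = torch.randn(4, 1024, 100, requires_grad=True)  # predicted embedding
loss_slm = torch.nn.MSELoss()(slm_emb_pred, slm_emb)
loss_slm.backward()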
@@ -644,19 +703,6 @@ def training_step(self, batch, batch_idx):
             metrics["g_loss_scl"] = loss_scl
             generator_losses.append(metrics["g_loss_scl"])
 
-        if self.use_asr_consitency_loss:
-            # concate generated and GT waveforms
-            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
-
-            logits, _ = self.phoneme_asr_model(audios_batch)
-
-            logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
-            # labels_gt, labels_pred = torch.chunk(labels, 2, dim=0)
-
-            loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
-            metrics["g_loss_acl"] = loss_acl
-            generator_losses.append(metrics["g_loss_acl"])
-
         loss_gen_all = sum(generator_losses)
 
         optim_gen.zero_grad()
@@ -672,7 +718,7 @@ def on_train_epoch_end(self):
         self.update_lr("epoch")
 
     def validation_step(self, batch, batch_idx):
-        audio, audio_len, audio_gen, _, _ = self._process_batch(batch)
+        audio, audio_len, audio_gen, *_ = self._process_batch(batch)
 
         loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
             audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len
@@ -709,45 +755,34 @@ def validation_step(self, batch, batch_idx):
             metrics["val_loss_scl"] = loss_scl
             metrics["val_loss"] += metrics["val_loss_scl"]
 
-        if self.use_asr_consitency_loss:
-            # concate generated and GT waveforms
-            audios_batch = torch.cat((audio.squeeze(1), audio_gen.squeeze(1)), dim=0)
-
-            logits, _ = self.phoneme_asr_model(audios_batch)
-            logits_gt, logits_pred = torch.chunk(logits, 2, dim=0)
-
-            loss_acl = torch.nn.functional.mse_loss(logits_pred, logits_gt) * self.acl_loss_scale
-            metrics["val_loss_acl"] = loss_acl
-            metrics["val_loss"] += metrics["val_loss_acl"]
-
         self.log_dict(metrics, on_epoch=True, sync_dist=True)
 
-    def get_dataset(self, cfg):
-        with open_dict(cfg):
-            is_sharded = cfg.dataset.pop('is_sharded', False)
-
+    def get_dataset(self, cfg, is_sharded=False):
         if is_sharded:
             with open_dict(cfg):
                 cfg.dataset.global_rank = self.global_rank
                 cfg.dataset.world_size = self.world_size
                 cfg.dataset._target_ = 'nemo.collections.tts.data.vocoder_dataset.TarredVocoderDataset'
-
-        dataset = instantiate(cfg.dataset)
+            dataset = instantiate(cfg.dataset)
+        elif '_target_' in cfg.dataset:
+            dataset = instantiate(cfg.dataset)
+        else:
+            dataset = VocoderDataset(**cfg.dataset.dataset_args)
 
         sampler = dataset.get_sampler(cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
-        return dataset, sampler
-
-    def _setup_train_dataloader(self, cfg):
-        dataset, sampler = self.get_dataset(cfg)
         data_loader = torch.utils.data.DataLoader(
             dataset, collate_fn=dataset.collate_fn, sampler=sampler, **cfg.dataloader_params
         )
         return data_loader
 
+    def _setup_train_dataloader(self, cfg):
+        with open_dict(cfg):
+            is_sharded = cfg.dataset.pop('is_sharded', False)
+
+        return self.get_dataset(cfg, is_sharded=is_sharded)
+
     def _setup_test_dataloader(self, cfg):
-        dataset = instantiate(cfg.dataset)
-        data_loader = torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
-        return data_loader
+        return self.get_dataset(cfg)
 
     def setup_training_data(self, cfg):
         self._train_dl = self._setup_train_dataloader(cfg)
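A rough sketch of a training dataset config that exercises the refactored get_dataset paths (values and nesting are illustrative assumptions):

# Sketch: 'is_sharded' is popped by _setup_train_dataloader; without a '_target_',
# get_dataset falls back to VocoderDataset(**cfg.dataset.dataset_args).
from omegaconf import OmegaConf

train_ds_cfg = OmegaConf.create(
    {
        "dataset": {
            "is_sharded": False,  # True would route to TarredVocoderDataset with rank/world_size injected
            "dataset_args": {"dataset_meta": {}, "sample_rate": 22050, "resample_rate": 16000},
        },
        "dataloader_params": {"batch_size": 16, "num_workers": 4},
    }
)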
@@ -807,20 +842,25 @@ def configure_optimizers(self):
         sched_config = optim_config.pop("sched", None)
         OmegaConf.set_struct(optim_config, True)
 
-        asr_ph_params = self.phoneme_asr_model.parameters() if self.use_asr_consitency_loss else []
         se_params = self.speaker_encoder.parameters() if self.use_scl_loss else []
         vq_params = self.vector_quantizer.parameters() if self.vector_quantizer else []
         gen_params = itertools.chain(
-            self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params, asr_ph_params, se_params
+            self.audio_encoder.parameters(), self.audio_decoder.parameters(), vq_params, se_params
         )
         optim_g = instantiate(optim_config, params=gen_params)
 
-        disc_params = self.discriminator.parameters()
-        optim_d = instantiate(optim_config, params=disc_params)
+        if self.discriminator is None:
+            optim_d = None
+        else:
+            disc_params = self.discriminator.parameters()
+            optim_d = instantiate(optim_config, params=disc_params)
 
         if sched_config is None:
             logging.debug('Scheduler is not used')
-            return [optim_g, optim_d]
+            if optim_d is None:
+                return optim_g
+            else:
+                return optim_g, optim_d
 
         logging.debug('Setting up schedulers')
         OmegaConf.set_struct(sched_config, False)
@@ -831,20 +871,26 @@ def configure_optimizers(self):
             optimizer=optim_g, scheduler_config=sched_config, train_dataloader=self._train_dl
         )
 
-        scheduler_d = prepare_lr_scheduler(
-            optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
-        )
-
         self.lr_schedule_interval = scheduler_g["interval"]
-        return [optim_g, optim_d], [scheduler_g, scheduler_d]
+
+        if optim_d is None:
+            return [optim_g], [scheduler_g]
+        else:
+            scheduler_d = prepare_lr_scheduler(
+                optimizer=optim_d, scheduler_config=sched_config, train_dataloader=self._train_dl
+            )
+            return [optim_g, optim_d], [scheduler_g, scheduler_d]
 
     def update_lr(self, interval="step"):
         schedulers = self.lr_schedulers()
-        if schedulers is not None and self.lr_schedule_interval == interval:
-            sch1, sch2 = schedulers
-            sch1.step()
-            sch2.step()
+        if schedulers is None or self.lr_schedule_interval != interval:
+            return
+
+        if self.discriminator is None:
+            schedulers.step()
+        else:
+            schedulers[0].step()
+            schedulers[1].step()
 
     def configure_callbacks(self):
         if not self.log_config:
diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py
index d34722249167..73f271121b96 100755
--- a/nemo/collections/tts/modules/audio_codec_modules.py
+++ b/nemo/collections/tts/modules/audio_codec_modules.py
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModel
+from transformers import AutoFeatureExtractor, AutoModel, Wav2Vec2BertModel
 
 from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor
 from nemo.collections.audio.parts.utils.transforms import MelSpectrogram, Resample
@@ -177,6 +177,157 @@ def forward(self, audio_real, audio_gen):
         return [y_d_r.unsqueeze(1)], [y_d_g.unsqueeze(1)], [fmap_r], [fmap_g]
 
 
+class SLMEncoder(NeuralModule):
+    """Encoder wrapping a speech language model (SLM) which produces semantic embeddings for use in semantic distillation.
+
+    Args:
+        slm_model_name: Name of Hugging Face model.
+        slm_sr: Sample rate the SLM model requires for input.
+        input_sr: Sampling rate of audio that will be input to this encoder.
+        hidden_layer: Index of hidden layer to extract embeddings from.
+            Defaults to 16, which research suggests is effective for w2v-bert and TTS.
+        padding: Number of audio samples to pad before encoding to ensure output has a frame rate compatible with the audio codec.
+        scaling_factor: Constant factor to scale output embedding by.
+    """
+
+    def __init__(
+        self,
+        slm_model_name="facebook/w2v-bert-2.0",
+        slm_sr=16000,
+        input_sr=22050,
+        hidden_layer=16,
+        padding=80,
+        scaling_factor=5.0,
+    ):
+        super().__init__()
+
+        self.slm_sr = slm_sr
+        if input_sr == self.slm_sr:
+            self.resample = None
+        else:
+            self.resample = Resample(orig_freq=input_sr, new_freq=self.slm_sr)
+
+        self.feature_extractor = AutoFeatureExtractor.from_pretrained(slm_model_name)
+        self.semantic_model = Wav2Vec2BertModel.from_pretrained(slm_model_name, output_hidden_states=True)
+        self.semantic_model.eval()
+
+        self.hidden_layer = hidden_layer
+        self.padding = padding
+        self.scaling_factor = scaling_factor
+
+    @property
+    def input_types(self):
+        return {
+            "audio": NeuralType(('B', 'T'), AudioSignal()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "slm_embeddings": [NeuralType(('B', 'D', 'T'), VoidType())],
+        }
+
+    @typecheck()
+    def forward(self, audio):
+        if self.resample is not None:
+            audio = self.resample(audio)
+
+        audio = torch.nn.functional.pad(audio, (0, self.padding))
+        feats = self.feature_extractor(audio.cpu(), sampling_rate=self.slm_sr, return_tensors="pt").data[
+            'input_features'
+        ]
+        feats = feats.to(audio.device)
+
+        with torch.no_grad():
+            out = self.semantic_model(feats)
+            slm_emb = out.hidden_states[self.hidden_layer] / self.scaling_factor
+
+        slm_emb = rearrange(slm_emb, 'B T D -> B D T')
+
+        return slm_emb
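A hedged usage sketch of the encoder above (the first call downloads the Hugging Face checkpoint; the output frame count and hidden size are approximate properties of w2v-bert-2.0, not stated in this patch):

# Sketch: extract frozen w2v-bert embeddings for a batch of 22.05 kHz audio.
# w2v-bert-2.0 produces roughly 50 frames/sec with hidden size 1024.
import torch
from nemo.collections.tts.modules.audio_codec_modules import SLMEncoder

encoder = SLMEncoder(input_sr=22050)  # uses facebook/w2v-bert-2.0 by default
audio = torch.randn(2, 22050)         # one second of (random) audio per item
slm_emb = encoder(audio=audio)        # -> approximately [2, 1024, 50]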
+
+
+class SLMDecoder(NeuralModule):
+    """Decoder for predicting SLM embeddings for semantic distillation. This decoder uses transposed convolutions to upsample from
+    the codec's frame rate to the frame rate of the SLM model.
+
+    Args:
+        in_channels: Input dimension of quantized codec encoding.
+        hidden_dim: Hidden dimension that input will be projected to.
+        out_channels: Dimension of decoder embedding.
+        up_sample_rate: Rate to upsample by to match the SLM frame rate.
+        kernel_size: Kernel size of convolutions.
+        padding_mode: Padding used with convolutions.
+        activation: Activation to use in between convolutions.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_dim: int,
+        out_channels: int,
+        up_sample_rate: int = 1,
+        kernel_size: int = 3,
+        padding_mode: str = "replicate",
+        activation: str = "lrelu",
+    ):
+        super().__init__()
+        padding = get_padding(kernel_size=kernel_size)
+        self.activation = CodecActivation(activation=activation)
+        self.input_layer = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=hidden_dim,
+            kernel_size=kernel_size,
+            padding=padding,
+            padding_mode=padding_mode,
+        )
+        self.output_layer = nn.Conv1d(
+            in_channels=hidden_dim,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            padding_mode=padding_mode,
+        )
+
+        if up_sample_rate > 1:
+            up_kernel_size = 2 * up_sample_rate
+            up_padding, output_padding = get_up_sample_padding(up_kernel_size, up_sample_rate)
+            self.upsample_layer = nn.Sequential(
+                nn.ConvTranspose1d(
+                    in_channels=hidden_dim,
+                    out_channels=hidden_dim,
+                    kernel_size=up_kernel_size,
+                    stride=up_sample_rate,
+                    padding=up_padding,
+                    output_padding=output_padding,
+                ),
+                self.activation,
+            )
+        else:
+            self.upsample_layer = nn.Identity()
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": NeuralType(('B', 'D', 'T'), VoidType()),
+        }
+
+    @property
+    def output_types(self):
+        return {
+            "output": NeuralType(('B', 'C', 'T'), VoidType()),
+        }
+
+    @typecheck()
+    def forward(self, inputs):
+        out = self.input_layer(inputs)
+        out = self.activation(out)
+        out = self.upsample_layer(out)
+        out = self.activation(out)
+        out = self.output_layer(out)
+        return out
+
+
 # Torch version of transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
 def zero_mean_unit_var_norm(input_values):
     """
diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
index e8d5416d7598..f39b42a0749f 100644
--- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py
+++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
@@ -28,6 +28,8 @@
 from torch.special import gammaln
 
 from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
+from nemo.collections.audio.parts.utils.transforms import resample
+from nemo.collections.common.parts.utils import mask_sequence_tensor
 
 
 def get_abs_rel_paths(input_path: Path, base_path: Path) -> Tuple[Path, Path]:
@@ -756,3 +758,14 @@ def chunk_text_for_inference(
         tokens_len = tokens_tensor.shape[0]
 
     return [tokens_tensor], [tokens_len], [text]
+
+
+def resample_batch(audio, audio_len, input_sample_rate, output_sample_rate):
+    audio = resample(waveform=audio, orig_freq=input_sample_rate, new_freq=output_sample_rate)
+    audio_len_scaled = audio_len.long() * output_sample_rate
+    new_audio_len = audio_len_scaled / input_sample_rate
+    # To avoid rounding issues at lower precisions, do not call torch.ceil when the length is divisible by the sample rate
+    audio_len = torch.where(audio_len_scaled % input_sample_rate == 0, new_audio_len, torch.ceil(new_audio_len))
+    audio_len = audio_len.int()
+    audio = mask_sequence_tensor(audio, audio_len)
+    return audio, audio_len
\ No newline at end of file
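To make the rounding comment in resample_batch concrete, a small worked example (values chosen for illustration):

# Sketch: length arithmetic for 22050 -> 16000 Hz. new_len = ceil(len * 16000 / 22050),
# except when len * 16000 is exactly divisible by 22050, where ceil is skipped to avoid
# low-precision float rounding pushing an exact result up by one sample.
import math

exact = 22050 * 16000 / 22050          # 16000.0, divisible -> no ceil, new_len = 16000
approx = math.ceil(22000 * 16000 / 22050)  # 15964, ceil applied
print(exact, approx)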
diff --git a/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py
index 28bca9233256..6516ae195c51 100644
--- a/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py
+++ b/scripts/magpietts/extend_lhotse_shards_with_audio_codes.py
@@ -407,20 +407,13 @@ def forward(
         target_audio_lens = target_audio_lens.to(self.device)
         context_audios = context_audios.to(self.device)
         context_audio_lens = context_audio_lens.to(self.device)
-        # NOTE: we avoided directly calling `self.codec_model.encode()` because it pads audios again.
         with torch.inference_mode():
-            target_audios_encoded, target_audios_encoded_len = self.codec_model.audio_encoder(
+            target_tokens, target_audios_encoded_len = self.codec_model.encode(
                 audio=target_audios, audio_len=target_audio_lens
             )
-            target_tokens = self.codec_model.quantize(
-                encoded=target_audios_encoded, encoded_len=target_audios_encoded_len
-            )
-            context_audios_encoded, context_audios_encoded_len = self.codec_model.audio_encoder(
+            context_tokens, context_audios_encoded_len = self.codec_model.encode(
                 audio=context_audios, audio_len=context_audio_lens
            )
-            context_tokens = self.codec_model.quantize(
-                encoded=context_audios_encoded, encoded_len=context_audios_encoded_len
-            )
         return {
             "target_codes": target_tokens.to(dtype=torch.uint16, device="cpu"),
             "target_codes_lengths": target_audios_encoded_len.to(device="cpu"),
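The script now goes through encode() so that semantic features are concatenated before quantization; a hedged sketch of the call (checkpoint path and tensor sizes are hypothetical):

# Sketch: encode() wraps preprocessing, the audio encoder (plus semantic concat when
# a semantic codec is attached), and quantization, returning integer tokens and lengths.
import torch
from nemo.collections.tts.models import AudioCodecModel

codec_model = AudioCodecModel.restore_from("codec.nemo")  # hypothetical checkpoint
audio = torch.randn(2, 44100)                             # [B, T_samples]
audio_len = torch.tensor([44100, 32000])
with torch.inference_mode():
    tokens, token_len = codec_model.encode(audio=audio, audio_len=audio_len)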
From 0b037e1d215acebfdbb7166bd5a5ae79f28658de Mon Sep 17 00:00:00 2001
From: rlangman
Date: Thu, 19 Mar 2026 22:59:03 +0000
Subject: [PATCH 2/3] Apply isort and black reformatting

Signed-off-by: rlangman
---
 nemo/collections/tts/parts/utils/tts_dataset_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
index f39b42a0749f..e8a90f5cd860 100644
--- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py
+++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
@@ -768,4 +768,4 @@ def resample_batch(audio, audio_len, input_sample_rate, output_sample_rate):
     audio_len = torch.where(audio_len_scaled % input_sample_rate == 0, new_audio_len, torch.ceil(new_audio_len))
     audio_len = audio_len.int()
     audio = mask_sequence_tensor(audio, audio_len)
-    return audio, audio_len
\ No newline at end of file
+    return audio, audio_len

From ca1c5ddd376a68c39632bc55bfb6556080db5121 Mon Sep 17 00:00:00 2001
From: rlangman
Date: Thu, 19 Mar 2026 22:59:03 +0000
Subject: [PATCH 3/3] Rename slm decoder to predictor

Signed-off-by: Ryan
---
 nemo/collections/tts/data/vocoder_dataset.py |  3 +-
 nemo/collections/tts/models/audio_codec.py   | 28 ++++++++++---------
 .../tts/modules/audio_codec_modules.py       |  6 ++--
 3 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/nemo/collections/tts/data/vocoder_dataset.py b/nemo/collections/tts/data/vocoder_dataset.py
index 05c1213c3a53..ac04a01c30c8 100644
--- a/nemo/collections/tts/data/vocoder_dataset.py
+++ b/nemo/collections/tts/data/vocoder_dataset.py
@@ -126,7 +126,8 @@ class VocoderDataset(Dataset):
     Args:
         dataset_meta: Dict of dataset names (string) to dataset metadata.
         sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will
-            be resampled.
+            be resampled using librosa.
+        resample_rate: Optional sample rate to resample to, using torch-based resampling.
         n_samples: Optional int, if provided then n_samples samples will be randomly sampled from the full audio
            file.
         weighted_sampling_steps_per_epoch: Optional int, If provided, then data will be sampled (with replacement) based on
diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py
index fae48fcb0221..3589cb31a6c2 100644
--- a/nemo/collections/tts/models/audio_codec.py
+++ b/nemo/collections/tts/models/audio_codec.py
@@ -16,7 +16,7 @@
 from contextlib import nullcontext
 from math import ceil
 from pathlib import Path
-from typing import List, Tuple
+from typing import Iterable, List, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -128,9 +128,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             semantic_codec = None
 
         if semantic_codec is not None:
+            semantic_codec.eval()
+            semantic_codec.freeze()
             self.register_nemo_submodule(name="semantic_codec", config_field="semantic_codec", model=semantic_codec)
-            self.semantic_codec.eval()
-            self.semantic_codec.freeze()
         else:
             self.semantic_codec = None
 
@@ -140,12 +140,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             self.slm_encoder = instantiate(cfg.get("slm_encoder"))
             self.slm_encoder.eval()
             self.slm_encoder.freeze()
-            self.slm_decoder = instantiate(cfg.slm_decoder)
+            self.slm_predictor = instantiate(cfg.slm_predictor)
             self.slm_loss_fn = torch.nn.MSELoss()
             self.slm_loss_scale = cfg.get("slm_loss_scale", 1.0)
         else:
             self.slm_encoder = None
-            self.slm_decoder = None
+            self.slm_predictor = None
             self.slm_loss_fn = None
             self.slm_loss_scale = None
 
@@ -261,7 +261,7 @@ def codebook_size(self):
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if hasattr(self, '_no_state_dict') and self._no_state_dict:
             return {}
-        # Don't save the speaker verification and codec model in the state dict
+        # Avoid saving weights of frozen pretrained models
         state_dict = super().state_dict(destination, prefix, keep_vars)
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
@@ -274,7 +274,7 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
         return state_dict
 
     def load_state_dict(self, state_dict, strict=True):
-        # Override to load all the keys except .speaker_encoder. and WavLM model
+        # Avoid loading weights of frozen pretrained models
         for key in list(state_dict.keys()):
             if self.use_scl_loss and "speaker_encoder." in key:
                 del state_dict[key]
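The intent of the two comment hunks above, restated as a sketch (key names are illustrative):

# Sketch: keys belonging to frozen pretrained submodules are stripped from checkpoints,
# and loading uses strict=False so their absence is tolerated.
state_dict = {"audio_encoder.w": 1, "slm_encoder.layer.w": 2, "speaker_encoder.w": 3}
filtered = {k: v for k, v in state_dict.items() if not k.startswith(("slm_encoder.", "speaker_encoder."))}
print(sorted(filtered))  # ['audio_encoder.w']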
@@ -327,8 +327,10 @@ def encode_audio(
         encoded, encoded_len = self.audio_encoder(audio=audio_preprocessed, audio_len=audio_preprocessed_len)
 
         if self.semantic_codec is not None:
-            semantic, _ = self.semantic_codec.encode_audio(audio=audio, audio_len=audio_len, sample_rate=sample_rate)
-            semantic = semantic.detach()
+            with torch.no_grad():
+                semantic, _ = self.semantic_codec.encode_audio(
+                    audio=audio, audio_len=audio_len, sample_rate=sample_rate
+                )
             encoded = torch.concat([semantic, encoded], dim=1)
 
         return encoded, encoded_len
@@ -574,7 +576,7 @@ def _process_batch(self, batch):
 
         if self.training and self.use_slm_loss:
             slm_emb = self.slm_encoder(audio=audio)
-            slm_emb_pred = self.slm_decoder(inputs=encoded)
+            slm_emb_pred = self.slm_predictor(inputs=encoded)
         else:
             slm_emb = None
             slm_emb_pred = None
@@ -886,11 +888,11 @@ def update_lr(self, interval="step"):
         if schedulers is None or self.lr_schedule_interval != interval:
             return
 
-        if self.discriminator is None:
+        if not isinstance(schedulers, Iterable):
             schedulers.step()
         else:
-            schedulers[0].step()
-            schedulers[1].step()
+            for sch in schedulers:
+                sch.step()
 
     def configure_callbacks(self):
         if not self.log_config:
diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py
index 73f271121b96..56cb6721401e 100755
--- a/nemo/collections/tts/modules/audio_codec_modules.py
+++ b/nemo/collections/tts/modules/audio_codec_modules.py
@@ -187,7 +187,7 @@ class SLMEncoder(NeuralModule):
         hidden_layer: Index of hidden layer to extract embeddings from.
             Defaults to 16, which research suggests is effective for w2v-bert and TTS.
         padding: Number of audio samples to pad before encoding to ensure output has a frame rate compatible with the audio codec.
-        scaling_factor: Constant factor to scale output embedding by.
+        scaling_factor: Constant factor to divide output embedding by. Defaults to 5 to produce embeddings with values in [-1, 1].
     """
 
     def __init__(
@@ -247,8 +247,8 @@ def forward(self, audio):
         return slm_emb
 
 
-class SLMDecoder(NeuralModule):
-    """Decoder for predicting SLM embeddings for semantic distillation. This decoder uses transposed convolutions to upsample from
-    the codec's frame rate to the frame rate of the SLM model.
+class SLMPredictor(NeuralModule):
+    """Module for predicting SLM embeddings for semantic distillation. This module uses transposed convolutions to upsample from
+    the codec's frame rate to the frame rate of the SLM model.
 
     Args:
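For reference, a hedged sketch of how the renamed predictor bridges frame rates; the channel sizes and up_sample_rate here are hypothetical, not taken from any config in this series:

# Sketch: upsample codec-rate features toward the SLM frame rate, then regress
# SLM embeddings of dimension out_channels.
import torch
from nemo.collections.tts.modules.audio_codec_modules import SLMPredictor

predictor = SLMPredictor(in_channels=512, hidden_dim=256, out_channels=1024, up_sample_rate=2)
codec_feats = torch.randn(2, 512, 25)     # [B, D, T_codec]
slm_pred = predictor(inputs=codec_feats)  # -> [2, 1024, 50], matching the SLM frames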