-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[TTS] Add code for training semantic codec #15524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,7 @@ | |
| import os | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Dict, List, Optional | ||
| from typing import Dict, List, Optional, Tuple | ||
|
|
||
| import librosa | ||
| import soundfile as sf | ||
|
|
@@ -30,6 +30,7 @@ | |
| filter_dataset_by_duration, | ||
| get_weighted_sampler, | ||
| load_audio, | ||
| resample_batch, | ||
| sample_audio, | ||
| stack_tensors, | ||
| ) | ||
|
|
@@ -56,7 +57,7 @@ class DatasetSample: | |
| audio_dir: Path | ||
|
|
||
|
|
||
| def audio_collate_fn(batch: List[dict]): | ||
| def audio_collate_fn(batch: List[dict], resample_rates: Optional[Tuple[int]] = None): | ||
| dataset_name_list = [] | ||
| audio_filepath_list = [] | ||
| audio_list = [] | ||
|
|
@@ -73,6 +74,14 @@ def audio_collate_fn(batch: List[dict]): | |
|
|
||
| batch_audio = stack_tensors(audio_list, max_lens=[audio_max_len]) | ||
|
|
||
| if resample_rates: | ||
| batch_audio, batch_audio_len = resample_batch( | ||
| audio=batch_audio, | ||
| audio_len=batch_audio_len, | ||
| input_sample_rate=resample_rates[0], | ||
| output_sample_rate=resample_rates[1], | ||
| ) | ||
|
|
||
| batch_dict = { | ||
| "dataset_names": dataset_name_list, | ||
| "audio_filepaths": audio_filepath_list, | ||
|
|
@@ -117,7 +126,8 @@ class VocoderDataset(Dataset): | |
| Args: | ||
| dataset_meta: Dict of dataset names (string) to dataset metadata. | ||
| sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will | ||
| be resampled. | ||
| be resampled using librosa. | ||
| resample_rate: Optional sample rate to resample to, using torch-based resampling. | ||
| n_samples: Optional int, if provided then n_samples samples will be randomly sampled from the full | ||
| audio file. | ||
| weighted_sampling_steps_per_epoch: Optional int, If provided, then data will be sampled (with replacement) based on | ||
|
|
@@ -135,6 +145,7 @@ def __init__( | |
| self, | ||
| dataset_meta: Dict, | ||
| sample_rate: int, | ||
| resample_rate: Optional[int] = None, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could you update the docstring to add this argument?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added docstring. The functionality might be a bit confusing, because the feature that is actually being added is the option to resample using batched NeMo code instead of librosa. |
||
| n_samples: Optional[int] = None, | ||
| weighted_sampling_steps_per_epoch: Optional[int] = None, | ||
| feature_processors: Optional[Dict[str, FeatureProcessor]] = None, | ||
|
|
@@ -146,6 +157,11 @@ def __init__( | |
| super().__init__() | ||
|
|
||
| self.sample_rate = sample_rate | ||
| if resample_rate and self.sample_rate != resample_rate: | ||
| self.resample_rates = [sample_rate, resample_rate] | ||
| else: | ||
| self.resample_rates = None | ||
|
|
||
| self.n_samples = n_samples | ||
| self.trunc_duration = trunc_duration | ||
| self.volume_norm = volume_norm | ||
|
|
@@ -221,7 +237,7 @@ def __getitem__(self, index): | |
| return example | ||
|
|
||
| def collate_fn(self, batch): | ||
| return audio_collate_fn(batch) | ||
| return audio_collate_fn(batch, resample_rates=self.resample_rates) | ||
|
|
||
|
|
||
| class TarredVocoderDataset(IterableDataset): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.