Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions nemo/collections/tts/data/vocoder_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import librosa
import soundfile as sf
Expand All @@ -30,6 +30,7 @@
filter_dataset_by_duration,
get_weighted_sampler,
load_audio,
resample_batch,
sample_audio,
stack_tensors,
)
Expand All @@ -56,7 +57,7 @@ class DatasetSample:
audio_dir: Path


def audio_collate_fn(batch: List[dict]):
def audio_collate_fn(batch: List[dict], resample_rates: Optional[Tuple[int]] = None):
dataset_name_list = []
audio_filepath_list = []
audio_list = []
Expand All @@ -73,6 +74,14 @@ def audio_collate_fn(batch: List[dict]):

batch_audio = stack_tensors(audio_list, max_lens=[audio_max_len])

if resample_rates:
batch_audio, batch_audio_len = resample_batch(
audio=batch_audio,
audio_len=batch_audio_len,
input_sample_rate=resample_rates[0],
output_sample_rate=resample_rates[1],
)

batch_dict = {
"dataset_names": dataset_name_list,
"audio_filepaths": audio_filepath_list,
Expand Down Expand Up @@ -117,7 +126,8 @@ class VocoderDataset(Dataset):
Args:
dataset_meta: Dict of dataset names (string) to dataset metadata.
sample_rate: Sample rate to load audio as. If the audio is stored at a different sample rate, then it will
be resampled.
be resampled using librosa.
resample_rate: Optional sample rate to resample to, using torch-based resampling.
n_samples: Optional int, if provided then n_samples samples will be randomly sampled from the full
audio file.
weighted_sampling_steps_per_epoch: Optional int, If provided, then data will be sampled (with replacement) based on
Expand All @@ -135,6 +145,7 @@ def __init__(
self,
dataset_meta: Dict,
sample_rate: int,
resample_rate: Optional[int] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you update the docstring to add this argument?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added docstring. The functionality might be a bit confusing, because the feature that is actually being added is the option to resample using batched NeMo code instead of librosa.

n_samples: Optional[int] = None,
weighted_sampling_steps_per_epoch: Optional[int] = None,
feature_processors: Optional[Dict[str, FeatureProcessor]] = None,
Expand All @@ -146,6 +157,11 @@ def __init__(
super().__init__()

self.sample_rate = sample_rate
if resample_rate and self.sample_rate != resample_rate:
self.resample_rates = [sample_rate, resample_rate]
else:
self.resample_rates = None

self.n_samples = n_samples
self.trunc_duration = trunc_duration
self.volume_norm = volume_norm
Expand Down Expand Up @@ -221,7 +237,7 @@ def __getitem__(self, index):
return example

def collate_fn(self, batch):
return audio_collate_fn(batch)
return audio_collate_fn(batch, resample_rates=self.resample_rates)


class TarredVocoderDataset(IterableDataset):
Expand Down
Loading
Loading