Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,24 @@ def get_code_switched_dataset(
tokenizer: Optional['TokenizerSpec'] = None,
augmentor: Optional['AudioAugmentor'] = None,
) -> CodeSwitchedDataset:
"""
Create a code-switched (multilingual) dataset by combining multiple monolingual datasets.

Each manifest filepath is loaded as a separate dataset (char-based, BPE-based, or tarred),
then wrapped in a ``CodeSwitchedDataset`` that mixes segments from different languages.

Args:
config: Dataset configuration dict. Must contain ``manifest_filepath`` and ``code_switched``
parameter groups. Optionally contains ``tarred_audio_filepaths``.
shuffle_n: Buffer size for shuffling tarred datasets.
global_rank: Global rank of the current process in distributed training.
world_size: Total number of distributed processes.
tokenizer: Optional tokenizer for BPE-based datasets. If None, char-based datasets are used.
augmentor: Optional audio augmentor applied to the combined dataset.

Returns:
A ``CodeSwitchedDataset`` instance that produces mixed-language audio samples.
"""
if 'manifest_filepath' not in config:
raise ValueError("`manifest_filepath` must be provided in the dataset config if `is_code_switched=True`")
if 'code_switched' not in config:
Expand Down Expand Up @@ -910,6 +927,21 @@ def close_output_file(self):


def convert_to_config_list(initial_list):
"""
Normalize manifest or tarred audio file paths into a ``ListConfig`` of lists.

Handles string inputs (comma-separated), flat lists, and nested lists. Ensures the output
is always a ``ListConfig`` of ``ListConfig`` entries, which is required for bucketed training.

Args:
initial_list: A string, list, or ``ListConfig`` of file paths. Strings are split by commas.

Returns:
A ``ListConfig`` of ``ListConfig`` entries, where each inner list corresponds to one bucket.

Raises:
ValueError: If the input is empty or contains mixed types.
"""
if type(initial_list) is str:
initial_list = initial_list.split(",")
if initial_list is None or initial_list == []:
Expand All @@ -928,6 +960,27 @@ def convert_to_config_list(initial_list):


def get_chain_dataset(datasets, ds_config, rank=0):
"""
Chain multiple bucketed datasets using the specified bucketing strategy.

When multiple datasets are provided (one per bucket), this function optionally wraps each
in a ``BucketingDataset`` with adaptive batch sizes and chains them together using the
configured bucketing strategy.

Args:
datasets: List of datasets, one per bucket.
ds_config: Dataset configuration dict containing ``bucketing_batch_size``,
``batch_size``, and ``bucketing_strategy`` keys.
rank: Process rank for seeding randomization in ``fully_randomized`` strategy.

Returns:
A single dataset if only one bucket, otherwise a ``ChainDataset`` or
``RandomizedChainDataset`` depending on the bucketing strategy.

Raises:
ValueError: If ``bucketing_strategy`` is not one of ``fixed_order``,
``synced_randomized``, or ``fully_randomized``.
"""
if len(datasets) > 1:
if ds_config.get('bucketing_batch_size', None) is not None:
bucketing_batch_sizes = calc_bucketing_batch_sizes(ds_config, len(datasets))
Expand Down Expand Up @@ -959,6 +1012,26 @@ def get_chain_dataset(datasets, ds_config, rank=0):


def calc_bucketing_batch_sizes(ds_config, datasets_len):
"""
Calculate per-bucket batch sizes for adaptive bucketing.

Supports two modes: linear scaling (integer ``bucketing_batch_size``) where shorter-duration
buckets get proportionally larger batches, and explicit assignment (list of integers) where
each bucket's batch size is specified directly.

Args:
ds_config: Dataset configuration dict containing ``bucketing_batch_size`` (int or list),
``batch_size`` (must be 1 when bucketing is enabled), and optionally
``bucketing_weights`` for upsampled buckets.
datasets_len: Number of buckets (datasets).

Returns:
List of batch sizes, one per bucket.

Raises:
ValueError: If ``batch_size`` is not 1, ``bucketing_batch_size`` is not an int or list,
or the resulting batch size list length doesn't match the number of buckets.
"""
bucketing_batch_size = ds_config['bucketing_batch_size']
bucketing_weights = ds_config.get('bucketing_weights', None) # To adjust for upsampled buckets

Expand Down
Loading