From bbe1693555edfdf2654c657f8601aef5bf003102 Mon Sep 17 00:00:00 2001 From: stanley1208 Date: Sat, 28 Mar 2026 22:31:07 -0700 Subject: [PATCH] fix: add missing docstrings to data pipeline utility functions Signed-off-by: stanley1208 Made-with: Cursor --- .../asr/data/audio_to_text_dataset.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index afb7a86a5f0e..bc114b3af3f8 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -413,7 +413,24 @@ def get_code_switched_dataset( tokenizer: Optional['TokenizerSpec'] = None, augmentor: Optional['AudioAugmentor'] = None, ) -> CodeSwitchedDataset: + """ + Create a code-switched (multilingual) dataset by combining multiple monolingual datasets. + + Each manifest filepath is loaded as a separate dataset (char-based, BPE-based, or tarred), + then wrapped in a ``CodeSwitchedDataset`` that mixes segments from different languages. + + Args: + config: Dataset configuration dict. Must contain ``manifest_filepath`` and ``code_switched`` + parameter groups. Optionally contains ``tarred_audio_filepaths``. + shuffle_n: Buffer size for shuffling tarred datasets. + global_rank: Global rank of the current process in distributed training. + world_size: Total number of distributed processes. + tokenizer: Optional tokenizer for BPE-based datasets. If None, char-based datasets are used. + augmentor: Optional audio augmentor applied to the combined dataset. + Returns: + A ``CodeSwitchedDataset`` instance that produces mixed-language audio samples. + """ if 'manifest_filepath' not in config: raise ValueError("`manifest_filepath` must be provided in the dataset config if `is_code_switched=True`") if 'code_switched' not in config: @@ -910,6 +927,21 @@ def close_output_file(self): def convert_to_config_list(initial_list): + """ + Normalize manifest or tarred audio file paths into a ``ListConfig`` of lists. + + Handles string inputs (comma-separated), flat lists, and nested lists. Ensures the output + is always a ``ListConfig`` of ``ListConfig`` entries, which is required for bucketed training. + + Args: + initial_list: A string, list, or ``ListConfig`` of file paths. Strings are split by commas. + + Returns: + A ``ListConfig`` of ``ListConfig`` entries, where each inner list corresponds to one bucket. + + Raises: + ValueError: If the input is empty or contains mixed types. + """ if type(initial_list) is str: initial_list = initial_list.split(",") if initial_list is None or initial_list == []: @@ -928,6 +960,27 @@ def convert_to_config_list(initial_list): def get_chain_dataset(datasets, ds_config, rank=0): + """ + Chain multiple bucketed datasets using the specified bucketing strategy. + + When multiple datasets are provided (one per bucket), this function optionally wraps each + in a ``BucketingDataset`` with adaptive batch sizes and chains them together using the + configured bucketing strategy. + + Args: + datasets: List of datasets, one per bucket. + ds_config: Dataset configuration dict containing ``bucketing_batch_size``, + ``batch_size``, and ``bucketing_strategy`` keys. + rank: Process rank for seeding randomization in ``fully_randomized`` strategy. + + Returns: + A single dataset if only one bucket, otherwise a ``ChainDataset`` or + ``RandomizedChainDataset`` depending on the bucketing strategy. + + Raises: + ValueError: If ``bucketing_strategy`` is not one of ``fixed_order``, + ``synced_randomized``, or ``fully_randomized``. + """ if len(datasets) > 1: if ds_config.get('bucketing_batch_size', None) is not None: bucketing_batch_sizes = calc_bucketing_batch_sizes(ds_config, len(datasets)) @@ -959,6 +1012,26 @@ def get_chain_dataset(datasets, ds_config, rank=0): def calc_bucketing_batch_sizes(ds_config, datasets_len): + """ + Calculate per-bucket batch sizes for adaptive bucketing. + + Supports two modes: linear scaling (integer ``bucketing_batch_size``) where shorter-duration + buckets get proportionally larger batches, and explicit assignment (list of integers) where + each bucket's batch size is specified directly. + + Args: + ds_config: Dataset configuration dict containing ``bucketing_batch_size`` (int or list), + ``batch_size`` (must be 1 when bucketing is enabled), and optionally + ``bucketing_weights`` for upsampled buckets. + datasets_len: Number of buckets (datasets). + + Returns: + List of batch sizes, one per bucket. + + Raises: + ValueError: If ``batch_size`` is not 1, ``bucketing_batch_size`` is not an int or list, + or the resulting batch size list length doesn't match the number of buckets. + """ bucketing_batch_size = ds_config['bucketing_batch_size'] bucketing_weights = ds_config.get('bucketing_weights', None) # To adjust for upsampled buckets