From d2457dda688189f1f091e0c8064f6fb187f2fac1 Mon Sep 17 00:00:00 2001 From: subhankar-ghosh Date: Tue, 24 Mar 2026 17:04:08 -0700 Subject: [PATCH 01/10] Add MagpieTTS finetuning docs Signed-off-by: subhankar-ghosh --- docs/source/tts/intro.rst | 1 + docs/source/tts/magpietts-finetuning.rst | 240 +++++++++++++++++++++++ docs/source/tts/models.rst | 2 +- 3 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 docs/source/tts/magpietts-finetuning.rst diff --git a/docs/source/tts/intro.rst b/docs/source/tts/intro.rst index 2dab390cf0e2..bbae7a90f0d4 100644 --- a/docs/source/tts/intro.rst +++ b/docs/source/tts/intro.rst @@ -17,6 +17,7 @@ We will illustrate details in the following sections. configs g2p magpietts + magpietts-finetuning magpietts-po magpietts-longform diff --git a/docs/source/tts/magpietts-finetuning.rst b/docs/source/tts/magpietts-finetuning.rst new file mode 100644 index 000000000000..35e7e9844fbe --- /dev/null +++ b/docs/source/tts/magpietts-finetuning.rst @@ -0,0 +1,240 @@ +.. _magpie-tts-finetuning: + +====================== +Magpie-TTS Finetuning +====================== + +Finetuning a pretrained Magpie-TTS checkpoint lets you adapt the model to new voices or new languages without training from scratch. The pretrained model has already learned general speech patterns, prosody, and acoustic modeling, so finetuning requires far less data and compute than pretraining. This guide covers two common finetuning scenarios: + +- **Adding new speakers in an existing language** — adapt the model to speak in voices not seen during pretraining, using a small dataset of target-speaker audio. +- **Adding a new language** — extend the model to synthesize speech in a language absent from the pretraining data, using a multilingual dataset configuration. + +For preference optimization (DPO/GRPO) on top of a finetuned checkpoint, see :doc:`Magpie-TTS Preference Optimization `. + + +Prerequisites +############# + +Before finetuning, you will need: + +- A pretrained Magpie-TTS checkpoint (``pretrained.ckpt`` or ``pretrained.nemo``). Public checkpoints (``https://huggingface.co/nvidia/magpie_tts_multilingual_357m``) are available on Hugging Face. +- The audio codec model (``https://huggingface.co/nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps``), available on Hugging Face alongside the TTS checkpoint. +- A prepared dataset. For faster finetuning audio codec tokens must be pre-extracted from your audio files. See the *Dataset Preparation* section below. +- NeMo installed from source or via the NeMo container. See the `NeMo GitHub page `_ for installation instructions. + + +Dataset Preparation +------------------- + +For faster finetuning, Magpie-TTS audio codec tokens can be pre-computed and stored alongside each audio file. Run the codec model over your audio dataset to generate and cache the audio codes before launching any training job: + +.. code-block:: bash + + python scripts/magpietts/extract_audio_codes.py \ + --manifest_path /path/to/your_manifest.json \ + --audio_dir /path/to/audio \ + --codecmodel_path nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps \ + --output_manifest /path/to/your_manifest_withAudioCodes.json + +Each manifest entry should be a JSON line with at minimum: + +.. 
code-block:: json + + { + "audio_filepath": "relative/path/to/audio.wav", + "text": "transcript of the utterance", + "duration": 5.2, + "context_audio_filepath": "relative/path/to/context.wav", + "context_text": "transcript of the context audio" + } + +The ``context_audio_filepath`` is the reference audio that the model uses for voice cloning during training. It should come from the same speaker as ``audio_filepath``. A minimum context duration of 3 seconds and a speaker similarity score of at least 0.6 (measured by TitaNet) are recommended for best results. + +For Lhotse-style dataset loading (used in the English SFT scenario), manifest entries are organized into a YAML ``input_cfg`` file instead of being passed directly. See the Lhotse configuration examples below. + + +.. _magpie-tts-new-speaker: + +Adding New Speakers in an Existing Language +########################################### + +This scenario adapts a pretrained checkpoint to a set of target speakers in a specific language already present in the pretrained checkpoint. The model already knows the language; you are teaching it new voice characteristics. You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face. + +Key training choices for speaker finetuning: + +- **Low learning rate** (``5e-6``): the pretrained model is already well-converged. A high LR will destroy the learned representations. +- **Disable alignment prior** (``alignment_loss_scale=0.0``, ``prior_scaling_factor=null``): the alignment prior is beneficial during pretraining to enforce monotonicity, but during finetuning it can over-constrain adaptation and hurt quality. +- **Context audio filtering**: ensure each training sample has a high-quality context audio from the same speaker. Use ``min_context_speaker_similarity: 0.6`` in your Lhotse manifest to filter low-quality pairs. + +The dataset uses Lhotse's ``input_cfg`` YAML format, which supports bucketed batching by duration for efficient GPU utilization: + +.. code-block:: yaml + + # train_input_cfg.yaml + - input_path: /path/to/your_dataset_shards/ + type: nemo_tarred + weight: 1.0 + +.. code-block:: bash + + python examples/tts/magpietts.py \ + --config-path=examples/tts/conf/magpietts \ + --config-name=magpietts_en_v2_lhotse \ + +init_from_ptl_ckpt=/path/to/pretrained.ckpt \ + +exp_manager.explicit_log_dir=/path/to/output \ + model.codecmodel_path=nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps \ + model.train_ds.dataset.input_cfg=/path/to/train_input_cfg.yaml \ + model.train_ds.dataset.batch_duration=500 \ + "+model.train_ds.dataset.bucket_duration_bins=[4.96,5.92,6.8,7.6,8.4,9.2,10.0,10.72,11.46,12.24,13.07,13.92,14.82,15.79,16.8,17.92,18.96,19.6,19.92]" \ + model.validation_ds.dataset.input_cfg=/path/to/val_input_cfg.yaml \ + model.validation_ds.dataset.batch_duration=300 \ + model.optim.lr=5e-6 \ + ~model.optim.sched \ + model.alignment_loss_scale=0.0 \ + model.prior_scaling_factor=null \ + trainer.max_steps=15000 \ + trainer.precision=32 \ + trainer.devices=8 \ + trainer.num_nodes=1 + +The ``+init_from_ptl_ckpt`` flag loads the pretrained checkpoint weights before training begins. The ``+`` prefix is required because this key is not present in the base config. + +``~model.optim.sched`` removes the learning rate schedule so the LR stays constant throughout finetuning. For short finetuning runs a fixed low LR is more stable than a decaying schedule. 
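As a quick reference, the ``+`` and ``~`` prefixes above follow Hydra's standard command-line override grammar. The sketch below is a generic illustration of that grammar (the values are placeholders, not Magpie-TTS-specific settings):

.. code-block:: bash

    model.optim.lr=5e-6      # plain key: override a value that exists in the base config
    +init_from_ptl_ckpt=...  # "+" prefix: add a key that is absent from the base config
    ~model.optim.sched       # "~" prefix: delete a key (and its subtree) from the config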
+ +``trainer.precision=32`` is recommended for finetuning stability. Mixed precision (``bf16`` or ``16``) can cause loss instability when adapting to small datasets. + + +.. _magpie-tts-new-language: + +Adding a New Language +##################### + +This scenario extends the model to synthesize speech in one or more languages not present in the pretraining data. The multilingual finetuning config uses the non-Lhotse ``train_ds_meta`` dataset format, which is better suited for combining multiple language-specific manifests with per-language sample weights. You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face. + +Key differences from the English SFT scenario: + +- **Byte-level tokenizer** (``google/byt5-small``): when adding languages whose characters fall outside the pretraining vocabulary, use a byte-level tokenizer. This avoids out-of-vocabulary tokens for new scripts or phoneme sets and lets the model learn new language representations without modifying the vocabulary. +- **Per-language dataset entries** (``train_ds_meta.``): each language is registered as a separate entry with its own manifest path, audio directory, and ``sample_weight``. This makes it straightforward to control the language balance during training. +- **Sample weight upsampling**: languages with less data can be upsampled via ``sample_weight``. For instance, a language with only 30 minutes of training data might use ``sample_weight=10.0`` alongside a resource-rich language at ``sample_weight=1.0`` to prevent the larger language from dominating training. + +Dataset manifest entries for multilingual finetuning use the IPA phoneme representation. Ensure your manifests include IPA-transcribed text fields. Audio codes must be pre-extracted as described in the *Dataset Preparation* section. + +.. 
code-block:: bash + + python examples/tts/magpietts.py \ + --config-name=magpietts_multilingual_v1 \ + +init_from_ptl_ckpt=/path/to/pretrained.ckpt \ + exp_manager.exp_dir=/path/to/output \ + +model.text_tokenizers.your_language_chartokenizer._target_=AutoTokenizer \ + +model.text_tokenizers.your_language_chartokenizer.pretrained_model="google/byt5-small" \ + +train_ds_meta.your_language.manifest_path=/path/to/your_lang_train.json \ + +train_ds_meta.your_language.audio_dir=/path/to/your_lang_audio \ + +train_ds_meta.your_language.feature_dir=/path/to/your_lang_audio \ + +train_ds_meta.your_language.sample_weight=1.0 \ + "+train_ds_meta.your_language.tokenizer_names=[your_language_chartokenizer]" \ + +val_ds_meta.your_language_dev.manifest_path=/path/to/your_lang_val.json \ + +val_ds_meta.your_language_dev.audio_dir=/path/to/your_lang_audio \ + +val_ds_meta.your_language_dev.feature_dir=/path/to/your_lang_audio \ + +val_ds_meta.your_language_dev.sample_weight=1.0 \ + "+val_ds_meta.your_language_dev.tokenizer_names=[your_language_chartokenizer]" \ + model.codecmodel_path=nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps \ + model.context_duration_min=5.0 \ + model.context_duration_max=5.0 \ + model.alignment_loss_scale=0.0 \ + model.prior_scaling_factor=null \ + model.optim.lr=1e-5 \ + ~model.optim.sched \ + model.load_cached_codes_if_available=true \ + trainer.precision=32 \ + trainer.devices=8 \ + trainer.num_nodes=1 \ + max_epochs=500 + + +Mixing Multiple Languages +-------------------------- + +To train on several languages simultaneously, add one ``train_ds_meta`` block per language. Languages with less data should use a higher ``sample_weight`` to compensate: +You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face. + +.. code-block:: bash + + # High-resource languages — standard weight + +train_ds_meta.spanish.manifest_path=/path/to/spanish_train.json \ + +train_ds_meta.spanish.audio_dir=/path/to/spanish_audio \ + +train_ds_meta.spanish.feature_dir=/path/to/spanish_audio \ + +train_ds_meta.spanish.sample_weight=1.0 \ + "+train_ds_meta.spanish.tokenizer_names=[chartokenizer]" \ + +train_ds_meta.french.manifest_path=/path/to/french_train.json \ + +train_ds_meta.french.audio_dir=/path/to/french_audio \ + +train_ds_meta.french.feature_dir=/path/to/french_audio \ + +train_ds_meta.french.sample_weight=1.0 \ + "+train_ds_meta.french.tokenizer_names=[chartokenizer]" \ + # Low-resource language — upsampled 10x + +train_ds_meta.low_resource_lang.manifest_path=/path/to/low_resource_train.json \ + +train_ds_meta.low_resource_lang.audio_dir=/path/to/low_resource_audio \ + +train_ds_meta.low_resource_lang.feature_dir=/path/to/low_resource_audio \ + +train_ds_meta.low_resource_lang.sample_weight=10.0 \ + "+train_ds_meta.low_resource_lang.tokenizer_names=[chartokenizer]" + +The ``model.load_cached_codes_if_available=true`` flag skips re-computing audio codes at training time when they are already stored in the manifest. This can substantially reduce data loading overhead when audio codes are pre-extracted. + + +Preference Optimization After Finetuning +######################################### + +After supervised finetuning, you can further improve output quality with GRPO (Group Relative Policy Optimization). 
GRPO generates multiple candidate outputs per item online, scores them with automatic metrics (CER, SSIM, PESQ), and trains the model to prefer higher-scoring outputs. + +To run GRPO starting from a finetuned checkpoint: + +.. code-block:: bash + + python examples/tts/magpietts.py \ + --config-name=magpietts_multilingual_v1 \ + +init_from_ptl_ckpt=/path/to/finetuned.ckpt \ + +mode=onlinepo_train \ + +model.loss_type=grpo \ + +model.num_generations_per_item=12 \ + +model.cer_reward_weight=0.45 \ + +model.ssim_reward_weight=0.45 \ + +model.pesq_reward_weight=0.1 \ + +model.use_pesq=true \ + +model.reward_asr_model=whisper \ + model.optim.lr=2e-7 \ + ~model.optim.sched \ + trainer.precision=32 \ + trainer.devices=8 \ + trainer.num_nodes=1 + +For full GRPO configuration options and the complete pipeline, see :doc:`Magpie-TTS Preference Optimization `. + + +Key Hyperparameter Reference +############################# + +.. list-table:: + :widths: 35 25 40 + :header-rows: 1 + + * - Parameter + - Typical Value + - Notes + * - ``model.optim.lr`` + - ``5e-6`` (English SFT), ``1e-5`` (multilingual) + - Much lower than pretraining LR to preserve learned features + * - ``trainer.max_steps`` + - ``10000``–``15000`` + - Shorter runs for small datasets; monitor validation loss + * - ``model.alignment_loss_scale`` + - ``0.0`` + - Disable alignment prior during finetuning + * - ``model.prior_scaling_factor`` + - ``null`` + - Disable alignment prior during finetuning + * - ``trainer.precision`` + - ``32`` + - Recommended for finetuning stability + * - ``model.cfg_unconditional_prob`` + - ``0.1`` + - Classifier-free guidance dropout rate during training + diff --git a/docs/source/tts/models.rst b/docs/source/tts/models.rst index 1a4cd8032f01..7ff1872ee8d4 100644 --- a/docs/source/tts/models.rst +++ b/docs/source/tts/models.rst @@ -33,7 +33,7 @@ End-to-End LLM-based TTS MagpieTTS ~~~~~~~~~ -MagpieTTS is an encoder-decoder transformer TTS model that operates on discrete audio tokens from a neural audio codec. It uses monotonic alignment (CTC loss and attention priors) to reduce hallucinations and supports voice cloning via audio or text context conditioning. For architecture, training, inference, and preference optimization (DPO/GRPO), see :doc:`Magpie-TTS documentation `. +MagpieTTS is an encoder-decoder transformer TTS model that operates on discrete audio tokens from a neural audio codec. It uses monotonic alignment (CTC loss and attention priors) to reduce hallucinations and supports voice cloning via audio or text context conditioning. For architecture, training, inference, and preference optimization (DPO/GRPO), see :doc:`Magpie-TTS documentation `. To adapt a pretrained checkpoint to new speakers or new languages, see :doc:`Magpie-TTS Finetuning `. 
Vocoders


From 7b4fc36e94ec29f3dbe1b16b24e8376641d155f8 Mon Sep 17 00:00:00 2001
From: subhankar-ghosh 
Date: Tue, 31 Mar 2026 09:03:09 -0700
Subject: [PATCH 02/10] Finetuning docs review changes

Signed-off-by: subhankar-ghosh 
---
 docs/source/tts/magpietts-finetuning.rst | 140 ++++++++++------------
 1 file changed, 60 insertions(+), 80 deletions(-)

diff --git a/docs/source/tts/magpietts-finetuning.rst b/docs/source/tts/magpietts-finetuning.rst
index 35e7e9844fbe..4c58da1ae990 100644
--- a/docs/source/tts/magpietts-finetuning.rst
+++ b/docs/source/tts/magpietts-finetuning.rst
@@ -26,17 +26,11 @@ Before finetuning, you will need:
 Dataset Preparation
 -------------------
 
-For faster finetuning, Magpie-TTS audio codec tokens can be pre-computed and stored alongside each audio file. Run the codec model over your audio dataset to generate and cache the audio codes before launching any training job:
+Training uses ``MagpieTTSDataset`` with ``dataset_meta`` entries (see ``DatasetMeta`` in ``nemo/collections/tts/data/text_to_speech_dataset.py``). Each line in the ``manifest_path`` file is one training example.
 
-.. code-block:: bash
-
-   python scripts/magpietts/extract_audio_codes.py \
-       --manifest_path /path/to/your_manifest.json \
-       --audio_dir /path/to/audio \
-       --codecmodel_path nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps \
-       --output_manifest /path/to/your_manifest_withAudioCodes.json
+**Optional cached codec codes.** If each line includes ``target_audio_codes_path`` and ``context_audio_codes_path`` (paths to saved tensors) and ``model.load_cached_codes_if_available=true``, the dataloader can skip on-the-fly codec encoding. If those keys are absent, the codec runs during training on waveforms loaded from ``audio_filepath`` and ``context_audio_filepath`` (slower, but no separate extraction step).
 
-Each manifest entry should be a JSON line with at minimum:
+**Minimum fields** (paths relative to ``audio_dir`` / ``feature_dir`` unless you use absolute paths):
 
 .. code-block:: json
 
@@ -45,12 +39,14 @@ Each manifest entry should be a JSON line with at minimum:
 
     {
         "audio_filepath": "relative/path/to/audio.wav",
         "text": "transcript of the utterance",
         "duration": 5.2,
         "context_audio_filepath": "relative/path/to/context.wav",
-        "context_text": "transcript of the context audio"
+        "context_text": "transcript of the context audio",
+        "target_audio_codes_path": "/optional/path/to/target_codes.pt",
+        "context_audio_codes_path": "/optional/path/to/context_codes.pt"
     }
 
-The ``context_audio_filepath`` is the reference audio that the model uses for voice cloning during training. It should come from the same speaker as ``audio_filepath``. A minimum context duration of 3 seconds and a speaker similarity score of at least 0.6 (measured by TitaNet) are recommended for best results.
+The ``context_audio_filepath`` is the reference audio used for voice cloning during training. It should come from the same speaker as ``audio_filepath``. A minimum context duration of about 3 seconds and a high speaker similarity (for example ≥ 0.6 with TitaNet) are recommended for best results.
 
-For Lhotse-style dataset loading (used in the English SFT scenario), manifest entries are organized into a YAML ``input_cfg`` file instead of being passed directly. See the Lhotse configuration examples below.
+**Registering datasets in config.** For each named split in ``train_ds_meta`` and ``val_ds_meta``, set ``manifest_path``, ``audio_dir``, ``feature_dir``, ``sample_weight`` (training), and ``tokenizer_names``: a list of keys that exist under ``model.text_tokenizers`` in the config. 
The dataloader picks the tokenizer for each sample from that list (see ``DatasetMeta``).
 
 
 .. _magpie-tts-new-speaker:
 
@@ -58,50 +54,52 @@ For Lhotse-style dataset loading (used in the English SFT scenario), manifest en
 Adding New Speakers in an Existing Language
 ###########################################
 
-This scenario adapts a pretrained checkpoint to a set of target speakers in a specific language already present in the pretrained checkpoint. The model already knows the language; you are teaching it new voice characteristics. You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face.
-
-Key training choices for speaker finetuning:
-
-- **Low learning rate** (``5e-6``): the pretrained model is already well-converged. A high LR will destroy the learned representations.
-- **Disable alignment prior** (``alignment_loss_scale=0.0``, ``prior_scaling_factor=null``): the alignment prior is beneficial during pretraining to enforce monotonicity, but during finetuning it can over-constrain adaptation and hurt quality.
-- **Context audio filtering**: ensure each training sample has a high-quality context audio from the same speaker. Use ``min_context_speaker_similarity: 0.6`` in your Lhotse manifest to filter low-quality pairs.
+This scenario adapts a pretrained checkpoint to new speakers in a language the model already supports (for example, adding new English speakers to a checkpoint trained on English data). You are teaching new voice characteristics while keeping the same text tokenizer. Mixing in some public Magpie-TTS data can reduce regression; see the `Magpie-TTS dataset `_ on Hugging Face.
 
-The dataset uses Lhotse's ``input_cfg`` YAML format, which supports bucketed batching by duration for efficient GPU utilization:
+Key training choices:
 
-.. code-block:: yaml
+- **Low learning rate** (``5e-6``): the pretrained model is already well-converged; a high LR can destroy learned representations.
+- **Disable alignment prior** (``alignment_loss_scale=0.0``, ``prior_scaling_factor=null``): the prior helps pretraining but can over-constrain finetuning.
+- **Tokenizer**: use ``tokenizer_names: [english_phoneme]`` (or the tokenizer that matches your transcripts) on each ``train_ds_meta`` / ``val_ds_meta`` entry.
 
-   # train_input_cfg.yaml
-   - input_path: /path/to/your_dataset_shards/
-     type: nemo_tarred
-     weight: 1.0
+The ``magpietts.yaml`` config trains with ``max_epochs`` and a top-level ``batch_size``. Validation mixes all ``val_ds_meta`` entries in a single dataloader (joint validation metrics).
 
 .. 
code-block:: bash python examples/tts/magpietts.py \ --config-path=examples/tts/conf/magpietts \ - --config-name=magpietts_en_v2_lhotse \ + --config-name=magpietts \ +init_from_ptl_ckpt=/path/to/pretrained.ckpt \ - +exp_manager.explicit_log_dir=/path/to/output \ + exp_manager.exp_dir=/path/to/output \ + +train_ds_meta.en_sft.manifest_path=/path/to/train.json \ + +train_ds_meta.en_sft.audio_dir=/path/to/audio \ + +train_ds_meta.en_sft.feature_dir=/path/to/features \ + +train_ds_meta.en_sft.sample_weight=1.0 \ + "+train_ds_meta.en_sft.tokenizer_names=[english_phoneme]" \ + +val_ds_meta.en_val.manifest_path=/path/to/val.json \ + +val_ds_meta.en_val.audio_dir=/path/to/audio \ + +val_ds_meta.en_val.feature_dir=/path/to/audio \ + +val_ds_meta.en_val.sample_weight=1.0 \ + "+val_ds_meta.en_val.tokenizer_names=[english_phoneme]" \ model.codecmodel_path=nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps \ - model.train_ds.dataset.input_cfg=/path/to/train_input_cfg.yaml \ - model.train_ds.dataset.batch_duration=500 \ - "+model.train_ds.dataset.bucket_duration_bins=[4.96,5.92,6.8,7.6,8.4,9.2,10.0,10.72,11.46,12.24,13.07,13.92,14.82,15.79,16.8,17.92,18.96,19.6,19.92]" \ - model.validation_ds.dataset.input_cfg=/path/to/val_input_cfg.yaml \ - model.validation_ds.dataset.batch_duration=300 \ - model.optim.lr=5e-6 \ - ~model.optim.sched \ + model.context_duration_min=5.0 \ + model.context_duration_max=5.0 \ model.alignment_loss_scale=0.0 \ model.prior_scaling_factor=null \ - trainer.max_steps=15000 \ + model.optim.lr=5e-6 \ + ~model.optim.sched \ + model.load_cached_codes_if_available=true \ trainer.precision=32 \ trainer.devices=8 \ - trainer.num_nodes=1 + trainer.num_nodes=1 \ + batch_size=16 \ + max_epochs=500 The ``+init_from_ptl_ckpt`` flag loads the pretrained checkpoint weights before training begins. The ``+`` prefix is required because this key is not present in the base config. -``~model.optim.sched`` removes the learning rate schedule so the LR stays constant throughout finetuning. For short finetuning runs a fixed low LR is more stable than a decaying schedule. +``~model.optim.sched`` removes the learning rate schedule so the LR stays constant during finetuning. -``trainer.precision=32`` is recommended for finetuning stability. Mixed precision (``bf16`` or ``16``) can cause loss instability when adapting to small datasets. +``trainer.precision=32`` is recommended for finetuning stability. Mixed precision (``bf16`` or ``16``) can cause loss instability on small datasets. .. _magpie-tts-new-language: @@ -109,20 +107,27 @@ The ``+init_from_ptl_ckpt`` flag loads the pretrained checkpoint weights before Adding a New Language ##################### -This scenario extends the model to synthesize speech in one or more languages not present in the pretraining data. The multilingual finetuning config uses the non-Lhotse ``train_ds_meta`` dataset format, which is better suited for combining multiple language-specific manifests with per-language sample weights. You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face. +This scenario extends the model to one or more languages not present in the pretraining data. Use the same ``magpietts`` config and combine multiple manifests with per-language ``sample_weight``. 
-Key differences from the English SFT scenario: +**Tokenizers** -- **Byte-level tokenizer** (``google/byt5-small``): when adding languages whose characters fall outside the pretraining vocabulary, use a byte-level tokenizer. This avoids out-of-vocabulary tokens for new scripts or phoneme sets and lets the model learn new language representations without modifying the vocabulary. -- **Per-language dataset entries** (``train_ds_meta.``): each language is registered as a separate entry with its own manifest path, audio directory, and ``sample_weight``. This makes it straightforward to control the language balance during training. -- **Sample weight upsampling**: languages with less data can be upsampled via ``sample_weight``. For instance, a language with only 30 minutes of training data might use ``sample_weight=10.0`` alongside a resource-rich language at ``sample_weight=1.0`` to prevent the larger language from dominating training. +- Define each new tokenizer under ``model.text_tokenizers`` (for example an ``AutoTokenizer`` with ``google/byt5-small`` for scripts outside the IPA vocabulary). +- **How it is applied:** each ``train_ds_meta`` / ``val_ds_meta`` entry lists ``tokenizer_names`` (keys under ``model.text_tokenizers``). The dataloader uses those names to select which tokenizer encodes each sample’s transcript (see ``DatasetMeta`` in ``nemo/collections/tts/data/text_to_speech_dataset.py``). -Dataset manifest entries for multilingual finetuning use the IPA phoneme representation. Ensure your manifests include IPA-transcribed text fields. Audio codes must be pre-extracted as described in the *Dataset Preparation* section. +**Per-language entries** + +Each language is a separate key under ``train_ds_meta`` / ``val_ds_meta`` with ``manifest_path``, ``audio_dir``, ``feature_dir``, ``sample_weight``, and ``tokenizer_names``. + +**Sample weights** + +Upsample low-resource languages with a higher ``sample_weight`` so they are not drowned out by high-resource languages. + +Align transcript format with the tokenizer you choose (IPA phonemes for ``english_phoneme`` / IPA-style tokenizers, raw text for byte-level models, and so on). Audio codes can be cached as in *Dataset Preparation*. .. code-block:: bash python examples/tts/magpietts.py \ - --config-name=magpietts_multilingual_v1 \ + --config-name=magpietts \ +init_from_ptl_ckpt=/path/to/pretrained.ckpt \ exp_manager.exp_dir=/path/to/output \ +model.text_tokenizers.your_language_chartokenizer._target_=AutoTokenizer \ @@ -154,8 +159,7 @@ Dataset manifest entries for multilingual finetuning use the IPA phoneme represe Mixing Multiple Languages -------------------------- -To train on several languages simultaneously, add one ``train_ds_meta`` block per language. Languages with less data should use a higher ``sample_weight`` to compensate: -You should consider mixing the new data with some of the publicly available existing data to prevent degradation of the model. You can find the publicly available data in the `Magpie-TTS dataset `_ on Hugging Face. +Add one ``train_ds_meta`` entry per language. Increase ``sample_weight`` for low-resource languages. You can mix public Magpie-TTS data with your own; see the `Magpie-TTS dataset `_ on Hugging Face. .. 
code-block:: bash
 
     # High-resource languages — standard weight
     +train_ds_meta.spanish.manifest_path=/path/to/spanish_train.json \
     +train_ds_meta.spanish.audio_dir=/path/to/spanish_audio \
     +train_ds_meta.spanish.feature_dir=/path/to/spanish_audio \
     +train_ds_meta.spanish.sample_weight=1.0 \
-    "+train_ds_meta.spanish.tokenizer_names=[chartokenizer]" \
+    "+train_ds_meta.spanish.tokenizer_names=[spanish_phoneme_or_chartokenizer]" \
     +train_ds_meta.french.manifest_path=/path/to/french_train.json \
     +train_ds_meta.french.audio_dir=/path/to/french_audio \
     +train_ds_meta.french.feature_dir=/path/to/french_audio \
     +train_ds_meta.french.sample_weight=1.0 \
-    "+train_ds_meta.french.tokenizer_names=[chartokenizer]" \
-    # Low-resource language — upsampled 10x
+    "+train_ds_meta.french.tokenizer_names=[french_chartokenizer]" \
+    # Low-resource language — upsampled 5x
     +train_ds_meta.low_resource_lang.manifest_path=/path/to/low_resource_train.json \
     +train_ds_meta.low_resource_lang.audio_dir=/path/to/low_resource_audio \
     +train_ds_meta.low_resource_lang.feature_dir=/path/to/low_resource_audio \
-    +train_ds_meta.low_resource_lang.sample_weight=10.0 \
-    "+train_ds_meta.low_resource_lang.tokenizer_names=[chartokenizer]"
+    +train_ds_meta.low_resource_lang.sample_weight=5.0 \
+    "+train_ds_meta.low_resource_lang.tokenizer_names=[low_resource_chartokenizer]"
 
-The ``model.load_cached_codes_if_available=true`` flag skips re-computing audio codes at training time when they are already stored in the manifest. This can substantially reduce data loading overhead when audio codes are pre-extracted.
+With ``model.load_cached_codes_if_available=true``, precomputed ``target_audio_codes_path`` / ``context_audio_codes_path`` entries in the manifest avoid recomputing codec codes at train time.
 
 
 Preference Optimization After Finetuning
 #########################################
 
-After supervised finetuning, you can further improve output quality with GRPO (Group Relative Policy Optimization). GRPO generates multiple candidate outputs per item online, scores them with automatic metrics (CER, SSIM, PESQ), and trains the model to prefer higher-scoring outputs.
-
-To run GRPO starting from a finetuned checkpoint:
-
-.. code-block:: bash
-
-    python examples/tts/magpietts.py \
-        --config-name=magpietts_multilingual_v1 \
-        +init_from_ptl_ckpt=/path/to/finetuned.ckpt \
-        +mode=onlinepo_train \
-        +model.loss_type=grpo \
-        +model.num_generations_per_item=12 \
-        +model.cer_reward_weight=0.45 \
-        +model.ssim_reward_weight=0.45 \
-        +model.pesq_reward_weight=0.1 \
-        +model.use_pesq=true \
-        +model.reward_asr_model=whisper \
-        model.optim.lr=2e-7 \
-        ~model.optim.sched \
-        trainer.precision=32 \
-        trainer.devices=8 \
-        trainer.num_nodes=1
-
-For full GRPO configuration options and the complete pipeline, see :doc:`Magpie-TTS Preference Optimization `.
+After supervised finetuning, you can further improve quality with GRPO. For commands and hyperparameters, see :doc:`Magpie-TTS Preference Optimization ` (the GRPO example uses ``--config-name=magpietts`` with ``+mode=onlinepo_train``). 
Key Hyperparameter Reference @@ -220,10 +201,10 @@ Key Hyperparameter Reference - Typical Value - Notes * - ``model.optim.lr`` - - ``5e-6`` (English SFT), ``1e-5`` (multilingual) + - ``5e-6`` (same-language speakers), ``1e-5`` (multilingual) - Much lower than pretraining LR to preserve learned features - * - ``trainer.max_steps`` - - ``10000``–``15000`` + * - ``max_epochs`` + - tens to hundreds - Shorter runs for small datasets; monitor validation loss * - ``model.alignment_loss_scale`` - ``0.0`` @@ -237,4 +218,3 @@ Key Hyperparameter Reference * - ``model.cfg_unconditional_prob`` - ``0.1`` - Classifier-free guidance dropout rate during training - From 2244fa04b586bde2608c7c22b86162f9d090247e Mon Sep 17 00:00:00 2001 From: Charlie Truong Date: Wed, 25 Mar 2026 17:05:07 -0500 Subject: [PATCH 03/10] ci: Update docs build job to exclude cu12 extra (#15553) * Test no extras docs build Signed-off-by: Charlie Truong * ci: Update docs job to use 0.83.0 templates Signed-off-by: Charlie Truong * Uncomment push cases for build-docs github action Signed-off-by: Charlie Truong --------- Signed-off-by: Charlie Truong --- .github/workflows/build-docs.yml | 3 ++- .github/workflows/release-docs.yml | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-docs.yml b/.github/workflows/build-docs.yml index 45f9e80941ab..af7419d6945d 100644 --- a/.github/workflows/build-docs.yml +++ b/.github/workflows/build-docs.yml @@ -39,10 +39,11 @@ jobs: build-docs: needs: [pre-flight] if: needs.pre-flight.outputs.is_deployment_workflow != 'true' - uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_build_docs.yml@v0.80.2 + uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_build_docs.yml@v0.83.0 with: docs-directory: docs/source sync-all: true + no-extras: "--no-extra cu12" build-docs-summary: needs: [pre-flight, build-docs] diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml index 37114f899b57..c0eb33ede883 100644 --- a/.github/workflows/release-docs.yml +++ b/.github/workflows/release-docs.yml @@ -78,11 +78,12 @@ on: jobs: build-docs: - uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_build_docs.yml@v0.74.0 + uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_build_docs.yml@v0.83.0 with: ref: ${{ inputs.github-ref }} docs-directory: docs/source sync-all: true + no-extras: "--no-extra cu12" publish-docs: runs-on: ubuntu-latest From 7dd32cfcff25eebc7dcae87e0a4a68b561990175 Mon Sep 17 00:00:00 2001 From: Subhankar Ghosh Date: Wed, 25 Mar 2026 19:58:35 -0400 Subject: [PATCH 04/10] Rename index for attention prior weights (#15551) Signed-off-by: Subhankar Ghosh --- nemo/collections/tts/models/magpietts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 63b6958e5b9d..d9794338ffe6 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -4210,9 +4210,9 @@ def _initialize_chunked_attn_prior( prior_weights = self.chunked_inference_config.prior_weights_init _attn_prior[_idx, :, :current_starting_point] = prior_epsilon * prior_epsilon for offset, weight in enumerate(prior_weights[:5]): - idx = current_starting_point + offset - if idx < max_text_len: - _attn_prior[_idx, :, idx] = weight + current_offset_idx = current_starting_point + offset + if current_offset_idx < max_text_len: + _attn_prior[_idx, :, current_offset_idx] = weight return _attn_prior From 15249c0fe9084b2823b49cb4a071b89c421f6143 
Mon Sep 17 00:00:00 2001 From: Vladimir Bataev Date: Thu, 26 Mar 2026 13:18:13 +0400 Subject: [PATCH 05/10] Ignore PnC for WER calculation: streaming ASR inference (#15550) Signed-off-by: Vladimir Bataev --- .../rnnt/speech_to_text_streaming_infer_rnnt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 298a3d5b4b55..c7aa2327a314 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -485,6 +485,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: langid=cfg.langid, use_cer=cfg.use_cer, output_filename=None, + ignore_punctuation=True, + ignore_capitalization=True, ) if output_manifest_w_wer: logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") From 591c9ed0f18b18e5582760b18c038067c4c99bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Mon, 30 Mar 2026 10:27:28 +0200 Subject: [PATCH 06/10] ci: upgrade GitHub Actions for Node.js 24 compatibility (#15537) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upgrades actions to versions compatible with the Node.js 24 runtime: - actions/checkout: → v6 - actions/upload-artifact: → v6 - actions/download-artifact: → v7 - actions/github-script: → v8 - actions/setup-python: → v6 Mirrors: https://github.com/NVIDIA/Megatron-LM/commit/1d5e68b0749f0fc075250fae4e36081d972379a8 Signed-off-by: oliver könig --- .github/actions/test-template/action.yml | 4 ++-- .github/workflows/_build_container.yml | 2 +- .github/workflows/_bump_mcore_tag.yml | 2 +- .github/workflows/cicd-approve-test-queue.yml | 4 ++-- .github/workflows/cicd-main-speech.yml | 6 +++--- .github/workflows/cicd-main-unit-tests.yml | 6 +++--- .github/workflows/cicd-main.yml | 16 +++++++-------- .github/workflows/cicd-relabel-bot.yml | 2 +- .github/workflows/claude-answer.yml | 6 +++--- .github/workflows/claude-fix.yml | 8 ++++---- .github/workflows/claude-review.yml | 2 +- .github/workflows/code-formatting.yml | 4 ++-- .github/workflows/code-init-file-checker.yml | 4 ++-- .github/workflows/code-linting.yml | 2 +- .github/workflows/codeql.yml | 2 +- .github/workflows/install-test.yml | 20 +++++++++---------- .github/workflows/release-freeze.yml | 2 +- .github/workflows/secrets-detector.yml | 2 +- .github/workflows/update-buildcache.yml | 2 +- 19 files changed, 48 insertions(+), 48 deletions(-) diff --git a/.github/actions/test-template/action.yml b/.github/actions/test-template/action.yml index 2213a6e100f8..9ec8ce51d1db 100644 --- a/.github/actions/test-template/action.yml +++ b/.github/actions/test-template/action.yml @@ -83,7 +83,7 @@ runs: echo "id=$(uuidgen)" >> "$GITHUB_OUTPUT" - name: Checkout NeMo - uses: actions/checkout@v4 + uses: actions/checkout@v6 env: DIR: ${{ github.run_id }} with: @@ -213,7 +213,7 @@ runs: docker exec -t nemo_container_${{ github.run_id }}_${{ inputs.runner }} coverage report -i - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 if: ${{ steps.check.outputs.coverage_report != 'none' }} with: name: ${{ steps.check.outputs.coverage_report }} diff --git a/.github/workflows/_build_container.yml b/.github/workflows/_build_container.yml index 97d4a2a960c2..5d877fe59028 100644 --- a/.github/workflows/_build_container.yml +++ 
b/.github/workflows/_build_container.yml @@ -23,7 +23,7 @@ jobs: cache-from: ${{ steps.cache_from.outputs.LAST_PRS }} steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Parse manifest.json id: manifest diff --git a/.github/workflows/_bump_mcore_tag.yml b/.github/workflows/_bump_mcore_tag.yml index 0789a5f76005..56e316f933b9 100644 --- a/.github/workflows/_bump_mcore_tag.yml +++ b/.github/workflows/_bump_mcore_tag.yml @@ -18,7 +18,7 @@ jobs: update-branch: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v6 with: ref: ${{ inputs.nemo-target-branch }} diff --git a/.github/workflows/cicd-approve-test-queue.yml b/.github/workflows/cicd-approve-test-queue.yml index ddfb07e367da..608005072b5d 100644 --- a/.github/workflows/cicd-approve-test-queue.yml +++ b/.github/workflows/cicd-approve-test-queue.yml @@ -25,10 +25,10 @@ jobs: environment: main steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.12" diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 49218116c001..cc51bbaa7d05 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -64,7 +64,7 @@ jobs: name: ${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main @@ -196,7 +196,7 @@ jobs: name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main @@ -460,7 +460,7 @@ jobs: name: ${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main diff --git a/.github/workflows/cicd-main-unit-tests.yml b/.github/workflows/cicd-main-unit-tests.yml index 4a4e7621e1bd..e15ec06e17b3 100644 --- a/.github/workflows/cicd-main-unit-tests.yml +++ b/.github/workflows/cicd-main-unit-tests.yml @@ -34,7 +34,7 @@ jobs: name: ${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main @@ -65,7 +65,7 @@ jobs: name: ${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main @@ -91,7 +91,7 @@ jobs: name: ${{ matrix.script }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} - name: main diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 329bb00f897e..2e860621e523 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -48,7 +48,7 @@ jobs: HAS_LABEL: ${{ github.event.label.name == 'Run CICD' }} steps: - name: Checkout branch - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -171,7 +171,7 @@ jobs: echo "id=$(uuidgen)" >> "$GITHUB_OUTPUT" - name: Checkout NeMo - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }}/${{steps.uuid.outputs.id }}/NeMo @@ -212,7 +212,7 @@ jobs: && !cancelled() steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} @@ -274,7 +274,7 @@ jobs: permissions: write-all steps: - name: Checkout - uses: 
actions/checkout@v4 + uses: actions/checkout@v6 - name: Get workflow result id: result @@ -301,7 +301,7 @@ jobs: echo "code=$RESULT" | tee -a $GITHUB_OUTPUT - name: Checkout for GH CLI - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Remove label if not cancelled if: | @@ -359,10 +359,10 @@ jobs: flag: [unit-test, e2e] steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Download coverage reports of current branch - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: pattern: coverage-${{ matrix.flag }}-* @@ -387,7 +387,7 @@ jobs: flags: ${{ matrix.flag }} - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: coverage-${{ matrix.flag }}-aggregated path: | diff --git a/.github/workflows/cicd-relabel-bot.yml b/.github/workflows/cicd-relabel-bot.yml index 7f4eaf63b9a3..c7546df322a2 100644 --- a/.github/workflows/cicd-relabel-bot.yml +++ b/.github/workflows/cicd-relabel-bot.yml @@ -17,7 +17,7 @@ jobs: permissions: write-all steps: - name: Checkout repo - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Check if PR was already labeled with `Run CICD` id: pre-flight diff --git a/.github/workflows/claude-answer.yml b/.github/workflows/claude-answer.yml index 29016e6c49ee..c6a2ed7d2e7a 100644 --- a/.github/workflows/claude-answer.yml +++ b/.github/workflows/claude-answer.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check team membership - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.ORG_TEAM_READ_TOKEN }} script: | @@ -40,7 +40,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add eyes reaction to comment - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | await github.rest.reactions.createForIssueComment({ @@ -54,7 +54,7 @@ jobs: needs: acknowledge runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: anthropics/claude-code-action@v1 with: prompt: | diff --git a/.github/workflows/claude-fix.yml b/.github/workflows/claude-fix.yml index 22ba31ee3413..1f7f208ea5bb 100644 --- a/.github/workflows/claude-fix.yml +++ b/.github/workflows/claude-fix.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check team membership - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.ORG_TEAM_READ_TOKEN }} script: | @@ -41,7 +41,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add eyes reaction to comment - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | await github.rest.reactions.createForIssueComment({ @@ -55,7 +55,7 @@ jobs: needs: acknowledge runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - uses: anthropics/claude-code-action@v1 id: claude with: @@ -93,7 +93,7 @@ jobs: - name: Label PR with agent-contribution if: steps.claude.outputs.branch_name - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | const prs = await github.rest.pulls.list({ diff --git a/.github/workflows/claude-review.yml b/.github/workflows/claude-review.yml index b792a28d9417..e4f3eb453399 100644 --- a/.github/workflows/claude-review.yml +++ b/.github/workflows/claude-review.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add eyes reaction to comment - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | await github.rest.reactions.createForIssueComment({ 
diff --git a/.github/workflows/code-formatting.yml b/.github/workflows/code-formatting.yml index 90ab8177cc3e..2fc272bbab65 100644 --- a/.github/workflows/code-formatting.yml +++ b/.github/workflows/code-formatting.yml @@ -27,7 +27,7 @@ jobs: contents: write steps: - name: Checkout branch - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: # setup repository and ref for PRs, see # https://github.com/EndBug/add-and-commit?tab=readme-ov-file#working-with-prs @@ -45,7 +45,7 @@ jobs: **.py - name: Setup Python env - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" diff --git a/.github/workflows/code-init-file-checker.yml b/.github/workflows/code-init-file-checker.yml index ecb55f953a71..7d0444b136af 100644 --- a/.github/workflows/code-init-file-checker.yml +++ b/.github/workflows/code-init-file-checker.yml @@ -9,10 +9,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6 with: python-version: "3.11" diff --git a/.github/workflows/code-linting.yml b/.github/workflows/code-linting.yml index 0ed4bfd813ae..eff0a71756ac 100644 --- a/.github/workflows/code-linting.yml +++ b/.github/workflows/code-linting.yml @@ -17,7 +17,7 @@ jobs: DOMAIN: ${{ matrix.domain }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Select filter id: filter diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index d82e99872853..28c060be0990 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL diff --git a/.github/workflows/install-test.yml b/.github/workflows/install-test.yml index ad43a906f9a3..1822f4fbf6b1 100644 --- a/.github/workflows/install-test.yml +++ b/.github/workflows/install-test.yml @@ -21,7 +21,7 @@ jobs: installer: ["pip-install", "nemo-install"] steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v6 - name: Check disk space before cleanup run: df -h @@ -39,7 +39,7 @@ jobs: - name: Check disk space after cleanup run: df -h - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: "${{ matrix.python }}" @@ -80,7 +80,7 @@ jobs: installer: ["pip-install", "nemo-install"] steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v6 - name: Check disk space before cleanup run: df -h @@ -102,7 +102,7 @@ jobs: run: df -h - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} @@ -137,7 +137,7 @@ jobs: python: ["3.10", "3.11", "3.12"] steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v6 - name: Check disk space before cleanup run: df -h @@ -159,7 +159,7 @@ jobs: run: df -h - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} @@ -186,7 +186,7 @@ jobs: installer: ["pip-install", "nemo-install"] steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v6 - name: Check disk space before cleanup run: df -h @@ -208,7 +208,7 @@ jobs: run: df -h - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} @@ -243,7 +243,7 @@ jobs: python: ["3.10", "3.11", "3.12"] steps: - name: Checkout repo - uses: actions/checkout@v2 + uses: actions/checkout@v6 - name: Check disk space before cleanup run: df -h @@ -265,7 +265,7 @@ jobs: run: df -h - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python }} diff --git a/.github/workflows/release-freeze.yml b/.github/workflows/release-freeze.yml index 836135f37c88..6462d2fe65d4 100644 --- a/.github/workflows/release-freeze.yml +++ b/.github/workflows/release-freeze.yml @@ -45,7 +45,7 @@ jobs: environment: main steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: ${{ github.run_id }} token: ${{ secrets.PAT }} diff --git a/.github/workflows/secrets-detector.yml b/.github/workflows/secrets-detector.yml index 3fb64e0155b7..243cf5af32c1 100644 --- a/.github/workflows/secrets-detector.yml +++ b/.github/workflows/secrets-detector.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 token: ${{ secrets.NEMO_REFORMAT_TOKEN }} diff --git a/.github/workflows/update-buildcache.yml b/.github/workflows/update-buildcache.yml index 7445d472a2ff..209a1f5bdb96 100644 --- a/.github/workflows/update-buildcache.yml +++ b/.github/workflows/update-buildcache.yml @@ -34,7 +34,7 @@ jobs: cache-from: ${{ steps.cache_from.outputs.LAST_PRS }} steps: - name: Checkout branch - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Parse manifest.json id: manifest From 1b1c664ae9ed3748f54615b726918c7cdddce8c6 Mon Sep 17 00:00:00 2001 From: Jason Date: Tue, 31 Mar 2026 08:54:22 -0400 Subject: [PATCH 07/10] Add VoiceChat to README (#15547) * Update README.md 
Signed-off-by: Jason * Revise Nemotron VoiceChat release details in README Updated the release information for Nemotron VoiceChat and added details about its features and early access. Signed-off-by: zhehuaichen <139396994+zhehuaichen@users.noreply.github.com> --------- Signed-off-by: Jason Signed-off-by: zhehuaichen <139396994+zhehuaichen@users.noreply.github.com> Co-authored-by: zhehuaichen <139396994+zhehuaichen@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 573cdfa90655..4c926c9a9950 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ weight checkpoints and demos! ## Updates +- 2026-03: [Nemotron 3 VoiceChat](https://build.nvidia.com/nvidia/nemotron-voicechat/modelcard) is now released in Early Access. Built on the Nemotron Nano v2 LLM backbone with Nemotron speech and TTS decoder, VoiceChat delivers full-duplex, natural, interruptible conversations with low latency. Try out [the demo](https://build.nvidia.com/nvidia/nemotron-voicechat) and apply for [early access](https://developer.nvidia.com/nemotron-voicechat-early-access). - 2026-03: [Nemotron-Speech-Streaming v2603](https://huggingface.co/nvidia/nemotron-speech-streaming-en-0.6b) has been updated. It has been trained on a larger and more diverse corpus, resulting in lower WER across all latency modes. Try out [the demo](https://huggingface.co/spaces/nvidia/nemotron-speech-streaming-en-0.6b) and check out From a57ac03476148d2eda3b01b6b4c5af71899b57a1 Mon Sep 17 00:00:00 2001 From: "He Huang (Steve)" <105218074+stevehuang52@users.noreply.github.com> Date: Tue, 31 Mar 2026 10:22:12 -0400 Subject: [PATCH 08/10] Add ASR-EOU models and training/eval scripts (#14740) * initial commit for end-of-utterance detection Signed-off-by: Weiqing Wang * change targets to long() type Signed-off-by: Weiqing Wang * change output_types() Signed-off-by: Weiqing Wang * add random padding and refactor for multiple utterances per sample Signed-off-by: stevehuang52 * add handling multiple text groundtruth Signed-off-by: stevehuang52 * update and add eval scripts Signed-off-by: stevehuang52 * drop sou label and add eob label Signed-off-by: stevehuang52 * update hybrid-rnnt-ctc and rnnt models to use eou dataset Signed-off-by: stevehuang52 * set default return eou frame label to false Signed-off-by: stevehuang52 * handle empty utterance Signed-off-by: stevehuang52 * add script for injecting special eou tokens into SPE tokenizer Signed-off-by: stevehuang52 * refactor eou eval utils Signed-off-by: stevehuang52 * add eou rnnt training Signed-off-by: stevehuang52 * update doc Signed-off-by: stevehuang52 * update data augmentation Signed-off-by: stevehuang52 * update data related functions Signed-off-by: stevehuang52 * fix tokenizer with eou tokens Signed-off-by: stevehuang52 * adding eou force aligner Signed-off-by: Weiqing Wang * update for eou Signed-off-by: stevehuang52 * fix the case when 'segments_level_ctm_filepath' is not produced Signed-off-by: Weiqing Wang * fix force aligner Signed-off-by: stevehuang52 * fix aligner Signed-off-by: stevehuang52 * update for asr-eou Signed-off-by: stevehuang52 * clean up and update infer Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * fix rnnt_decoding for empty string Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update padding augment Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: 
stevehuang52 * update cfg Signed-off-by: stevehuang52 * fix eob metric logging Signed-off-by: stevehuang52 * refactor and add hybrid model Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update EOU models Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * refactor percentile calculation Signed-off-by: stevehuang52 * update augmentation Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update model and cfg Signed-off-by: stevehuang52 * update frame eou Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * add adapter to eou Signed-off-by: stevehuang52 * remove pdb Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add cfg Signed-off-by: stevehuang52 * fix eou metric Signed-off-by: stevehuang52 * update adapter Signed-off-by: stevehuang52 * add scripts Signed-off-by: stevehuang52 * update docstring Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update generate eval data Signed-off-by: stevehuang52 * update eou val Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add drop_pnc=true as default for dataloading Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * fix miss rate Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add ignore_eob_label Signed-off-by: stevehuang52 * fix and update Signed-off-by: stevehuang52 * improve lhotse augmentation Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add debug info Signed-off-by: stevehuang52 * improve data augmentation Signed-off-by: stevehuang52 * update utils Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update dataloader Signed-off-by: stevehuang52 * update oomptimizer Signed-off-by: stevehuang52 * update oomptimizer Signed-off-by: stevehuang52 * update eou model Signed-off-by: stevehuang52 * update eou model Signed-off-by: stevehuang52 * update eou model Signed-off-by: stevehuang52 * update augmentation Signed-off-by: stevehuang52 * update aug Signed-off-by: stevehuang52 * update augment Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update drop pnc func Signed-off-by: stevehuang52 * update eou finetune Signed-off-by: stevehuang52 * update transcribe Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * fix cfg Signed-off-by: stevehuang52 * clean up for PR Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * Potential fix for code scanning alert no. 16191: Explicit returns mixed with implicit (fall through) returns Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * Potential fix for code scanning alert no. 
16190: Explicit returns mixed with implicit (fall through) returns Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * Apply isort and black reformatting Signed-off-by: stevehuang52 * Potential fix for code scanning alert no. 16185: File is not always closed Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * clean up Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * fix pylint&flake8 Signed-off-by: stevehuang52 * fix pylint Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * update pr Signed-off-by: stevehuang52 * update adapter Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * update readme, test, etc Signed-off-by: He Huang * Apply isort and black reformatting Signed-off-by: stevehuang52 * update doc Signed-off-by: He Huang * clean up Signed-off-by: He Huang * fix and rename Signed-off-by: He Huang * update doc Signed-off-by: He Huang * clean up Signed-off-by: He Huang * move all length aug to invalid Signed-off-by: He Huang * fix typo Signed-off-by: He Huang * rename and move to scripts/asr_eou Signed-off-by: He Huang * fix ci Signed-off-by: He Huang * fix ci Signed-off-by: He Huang * clean up Signed-off-by: He Huang * clean up Signed-off-by: He Huang * fix linting Signed-off-by: He Huang * fix ci Signed-off-by: He Huang * Apply isort and black reformatting Signed-off-by: stevehuang52 * Potential fix for code scanning alert no. 17270: Explicit export is not defined Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * Potential fix for code scanning alert no. 17271: Explicit export is not defined Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * Potential fix for code scanning alert no. 
17272: Explicit export is not defined Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> --------- Signed-off-by: Weiqing Wang Signed-off-by: stevehuang52 Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Signed-off-by: stevehuang52 Signed-off-by: He Huang Co-authored-by: Weiqing Wang Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: stevehuang52 --- docker/Dockerfile.speech | 2 +- examples/asr/asr_eou/README.md | 304 ++++ .../asr/asr_eou/speech_to_text_eou_eval.py | 112 ++ .../asr_eou/speech_to_text_rnnt_eou_train.py | 350 ++++ .../helpers/convert_nemo_asr_hybrid_to_ctc.py | 2 +- ...rmer_transducer_bpe_streaming_adapter.yaml | 375 +++++ ...former_transducer_bpe_streaming_large.yaml | 338 ++++ ...ormer_transducer_bpe_streaming_xlarge.yaml | 338 ++++ examples/asr/transcribe_speech.py | 2 +- .../asr/data/audio_to_eou_label_lhotse.py | 524 ++++++ nemo/collections/asr/losses/ssl_losses/mlm.py | 11 +- nemo/collections/asr/metrics/wer.py | 24 +- nemo/collections/asr/models/asr_eou_models.py | 967 +++++++++++ nemo/collections/asr/modules/__init__.py | 41 +- .../asr/modules/conformer_encoder.py | 41 +- nemo/collections/asr/modules/lstm_decoder.py | 14 +- nemo/collections/asr/modules/rnnt.py | 22 + .../asr/modules/ssl_modules/__init__.py | 15 +- .../modules/ssl_modules/multi_layer_feat.py | 92 +- nemo/collections/asr/parts/utils/eou_utils.py | 289 ++++ .../common/data/lhotse/dataloader.py | 2 + scripts/asr_eou/add_eob_labels.py | 222 +++ scripts/asr_eou/clean_manifest.py | 648 ++++++++ scripts/asr_eou/conf/data.yaml | 46 + scripts/asr_eou/eval_eou_metrics.py | 176 ++ scripts/asr_eou/generate_noisy_eval_data.py | 224 +++ .../add_special_tokens_to_sentencepiece.py | 214 +++ .../tokenizers/sentencepiece_model_pb2.py | 1442 +++++++++++++++++ scripts/asr_eou/transcribe_speech_sharded.py | 343 ++++ .../checkpoint_averaging.py | 0 .../convert_to_tarred_audio_dataset.py | 31 +- scripts/speech_recognition/oomptimizer.py | 6 +- tests/collections/asr/test_asr_eou.py | 91 ++ tools/nemo_forced_aligner/align_eou.py | 582 +++++++ tools/nemo_forced_aligner/utils/data_prep.py | 1 + 35 files changed, 7779 insertions(+), 112 deletions(-) create mode 100644 examples/asr/asr_eou/README.md create mode 100644 examples/asr/asr_eou/speech_to_text_eou_eval.py create mode 100644 examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py create mode 100644 examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml create mode 100644 examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_large.yaml create mode 100644 examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml create mode 100644 nemo/collections/asr/data/audio_to_eou_label_lhotse.py create mode 100644 nemo/collections/asr/models/asr_eou_models.py create mode 100644 nemo/collections/asr/parts/utils/eou_utils.py create mode 100644 scripts/asr_eou/add_eob_labels.py create mode 100644 scripts/asr_eou/clean_manifest.py create mode 100644 scripts/asr_eou/conf/data.yaml create mode 100644 scripts/asr_eou/eval_eou_metrics.py create mode 100644 scripts/asr_eou/generate_noisy_eval_data.py create mode 100644 scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py create mode 100644 scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py create mode 100644 
scripts/asr_eou/transcribe_speech_sharded.py
 mode change 100644 => 100755 scripts/checkpoint_averaging/checkpoint_averaging.py
 create mode 100644 tests/collections/asr/test_asr_eou.py
 create mode 100644 tools/nemo_forced_aligner/align_eou.py

diff --git a/docker/Dockerfile.speech b/docker/Dockerfile.speech
index 65e55987dc44..ba32799b00a5 100644
--- a/docker/Dockerfile.speech
+++ b/docker/Dockerfile.speech
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.07-py3
+ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:25.12-py3
 
 # build an image that includes only the nemo dependencies, ensures that dependencies
 # are included first for optimal caching, and useful for building a development
diff --git a/examples/asr/asr_eou/README.md b/examples/asr/asr_eou/README.md
new file mode 100644
index 000000000000..728205c79c74
--- /dev/null
+++ b/examples/asr/asr_eou/README.md
@@ -0,0 +1,304 @@
+# Finetuning a streaming ASR model for integrated end-of-utterance (EOU) detection
+
+This tutorial shows how to finetune a streaming ASR model (e.g., [nvidia/nemotron-speech-streaming-en-0.6b](https://huggingface.co/nvidia/nemotron-speech-streaming-en-0.6b)) for integrated EOU detection (e.g., [nvidia/parakeet_realtime_eou_120m-v1](https://huggingface.co/nvidia/parakeet_realtime_eou_120m-v1)).
+
+We use [Nemotron-Speech-Streaming-En-0.6b](https://huggingface.co/nvidia/nemotron-speech-streaming-en-0.6b) as an example of a pretrained ASR model.
+
+## Steps
+
+1. Model preparation
+2. Dataset preparation
+3. Model training
+4. Model evaluation
+
+## 1. Model preparation
+
+### 1.1. Download pretrained model
+
+Download the [Nemotron-Speech-Streaming-En-0.6b](https://huggingface.co/nvidia/nemotron-speech-streaming-en-0.6b) model from Hugging Face via:
+```bash
+wget https://huggingface.co/nvidia/nemotron-speech-streaming-en-0.6b/resolve/main/nemotron-speech-streaming-en-0.6b.nemo
+```
+
+### 1.2. Add special tokens to tokenizer
+
+By default, we use `<EOU>` and `<EOB>` for "end-of-utterance" and "end-of-backchannel" respectively, and add these two special tokens to the tokenizer of the pretrained model:
+```bash
+python /scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py \
+    --input_file /path/to/nemotron-speech-streaming-en-0.6b.nemo \
+    --output_dir /path/to/asr_eou_tokenizer_dir
+```
+The output directory `/path/to/asr_eou_tokenizer_dir` will contain the updated tokenizer to be used when updating the model config.
+
+The special tokens are added to the end of the original vocabulary. For example, if the original vocabulary size is 1024, the new vocabulary size will be 1026, and the special tokens will be added at indices 1024 and 1025.
+
+### 1.3. Update model config for ASR-EOU model
+
+We can extract the model config from the downloaded .nemo file by:
+```bash
+tar -xvf /path/to/nemotron-speech-streaming-en-0.6b.nemo -C /path/to/asr_model_dir
+```
+The output file `/path/to/asr_model_dir/model_config.yaml` is the model config to be updated for finetuning the ASR model into an ASR-EOU model.
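+
+Before editing the config, you can optionally verify that the two special tokens were appended as expected. The following is a minimal sanity-check sketch, assuming the updated SentencePiece model is written as `tokenizer.model` inside the output tokenizer directory:
+```python
+import sentencepiece as spm
+
+# Load the updated tokenizer produced by add_special_tokens_to_sentencepiece.py
+# (the path and file layout here are illustrative)
+sp = spm.SentencePieceProcessor(model_file="/path/to/asr_eou_tokenizer_dir/tokenizer.model")
+
+# Expect the original vocabulary size + 2, e.g. 1026 for a 1024-token base vocabulary
+print(sp.vocab_size())
+
+# The special tokens should occupy the last two indices
+print(sp.id_to_piece(sp.vocab_size() - 2), sp.id_to_piece(sp.vocab_size() - 1))
+```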
+
+In the model config file, we need to change the tokenizer to use the new tokenizer with special tokens:
+```yaml
+tokenizer:
+  dir: /path/to/asr_eou_tokenizer_dir
+  type: bpe
+```
+
+We also need to add some additional configurations to the model section to specify how we want to initialize the weights for the special tokens:
+```yaml
+model:
+  token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant']
+  token_init_weight_value: null # only applicable when token_init_method='constant'
+  token_init_bias_value: -1000.0 # only applicable when token_init_method='constant'
+```
+
+You may also need to change the optimization and loss parameters to suit your use case. We empirically find that setting `fastemit_lambda` to `3e-2` is a good start.
+
+```yaml
+loss:
+  loss_name: "default"
+  warprnnt_numba_kwargs:
+    # FastEmit regularization: https://arxiv.org/abs/2010.11148
+    # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming
+    # You may set it to lower values like 1e-3 for models with larger right context
+    fastemit_lambda: 3e-2
+```
+
+We also need to change the training, validation and test data paths in the model config file, based on how we prepare the EOU-labeled dataset as illustrated in the next section.
+
+For a full example of the model config file, please refer to `examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml`.
+
+
+## 2. Dataset preparation
+
+When finetuning the ASR model for EOU detection, we need to prepare the dataset in a specific format. More importantly, we need to make sure that each sample in the finetuning dataset contains a single utterance, otherwise the model's EOU prediction accuracy will be degraded. When only a small EOU dataset is available, we can blend it with some normal ASR data that does not necessarily contain EOU labels, so that the ASR WER is not significantly degraded. For the lowest possible EOU latency, we recommend dropping punctuation from the transcriptions to simplify the text processing.
+
+
+### 2.1 Manifest format
+
+We expect the input data manifest to be in JSONL format, with each line containing the following fields:
+```json
+{
+    "audio_filepath": "/path/to/audio.wav",
+    "text": "The text of the audio.",  # transcript of the utterance
+    "offset": 0.0,  # offset of the audio, in seconds
+    "duration": 3.0,  # duration of the audio, in seconds
+    "sou_time": 0.2,  # start of utterance time, in seconds
+    "eou_time": 1.5,  # end of utterance time, in seconds
+    "is_backchannel": false  # [optional] whether the utterance is a backchannel phrase
+}
+```
+
+Your original input manifest should contain the fields `audio_filepath`, `text`, `offset` and `duration`, while the fields `sou_time`, `eou_time` and `is_backchannel` can be obtained by the following steps.
+
+### 2.2 Getting timestamps for end-of-utterance (EOU)
+
+We recommend using forced alignment to get the timestamps for EOU. One way to do this is to use the [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) tool.
+
+```bash
+python /tools/nemo_forced_aligner/align_eou.py \
+    pretrained_name="nvidia/parakeet-ctc-0.6b" \
+    manifest_filepath=/path/to/asr_manifest.jsonl \
+    output_manifest_filepath=/path/to/asr_eou_manifest.jsonl
+```
+The output manifest will contain the fields `audio_filepath`, `text`, `offset`, `duration`, `sou_time` and `eou_time`.
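+
+For reference, a line in the aligned output manifest might look like the following (the values shown are illustrative):
+```json
+{"audio_filepath": "/path/to/audio.wav", "text": "hello how are you", "offset": 0.0, "duration": 3.0, "sou_time": 0.2, "eou_time": 1.5}
+```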
+
+
+### 2.3 (Optional) Add end-of-backchannel (EOB) labels to dataset
+
+Backchannel phrases are phrases that are not part of the main conversation, but are used to acknowledge or respond to the speaker, for example, "uh-huh", "yeah", "right", "okay", "thanks", "sorry", etc. We can also train the model to detect backchannel phrases by adding end-of-backchannel (EOB) labels to the dataset, so that a cascaded system can leverage the EOU and EOB predictions to better understand the conversation. Alternatively, we can treat EOB as a special case of EOU, and match the predicted EOU phrases against a list of predefined backchannel phrases to predict EOB, which is more flexible in handling different backchannel phrases.
+
+If you want to add end-of-backchannel (EOB) labels for training, you can use the following script to add the `is_backchannel` field to the manifest:
+
+```bash
+python /scripts/asr_eou/add_eob_labels.py \
+    input_manifest=/path/to/asr_manifest.jsonl \
+    output_manifest=/path/to/asr_eou_eob_manifest.jsonl
+```
+
+The `add_eob_labels.py` file has a list of predefined backchannel phrases, and you can edit it to add more backchannel phrases if needed. An easy way to find backchannel phrases is to look at the most frequent one-, two- or three-word utterances in the dataset, and manually check whether they are backchannel phrases.
+
+### 2.4 Creating tarred datasets for large-scale training
+
+For more efficient training, you can create tarred datasets for the ASR and EOU datasets using the `scripts/speech_recognition/convert_to_tarred_audio_dataset.py` script.
+
+### 2.5 Creating input data config for blending ASR and EOU data
+
+Please refer to the [documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#lhotse-dataloading) for more details on how to specify the dataset configuration in the model config file.
+
+An example of the `train_input_config.yaml` file is shown below, where we use a 0.1 weight for the ASR dataset and a 0.9 weight for the EOU dataset.
+```yaml
+- input_cfg:
+  - corpus: Librispeech
+    language: en
+    type: nemo
+    manifest_filepath: /data/LibriSpeech/train_other_500.jsonl # this is a normal ASR dataset
+    tags:
+      taskname: asr
+  type: group
+  weight: 0.1
+- input_cfg:
+  - corpus: LibriTTS
+    language: en
+    type: nemo
+    manifest_filepath: /data/LibriTTS/train_clean_360_eou.jsonl # this is an EOU manifest with the sou_time and eou_time fields added
+    tags:
+      taskname: eou
+  type: group
+  weight: 0.9
+```
+
+
+### 2.6 Creating evaluation datasets
+
+We can create evaluation datasets by padding the audio signal with non-speech frames and/or adding noise to the clean audio.
+
+Example usage with multiple manifests matching a pattern:
+```bash
+python /scripts/asr_eou/generate_noisy_eval_data.py \
+    output_dir=/path/to/output/dir \
+    data.manifest_filepath=/path/to/manifest/dir/ \
+    data.pattern="*.json" \
+    data.seed=42 \
+    data.noise.manifest_path=/path/to/noise_manifest.json
+```
+
+You can modify the YAML config `scripts/asr_eou/conf/data.yaml` to adjust the augmentation parameters.
+
+
+### 2.7 Configuring dataset in model config
+
+Now we can update the model config to use the prepared training and evaluation data configs:
+
+```yaml
+model:
+  train_ds:
+    input_cfg: /path/to/train_input_config.yaml
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    ignore_eob_label: true # ignore backchannel and treat them the same as EOU
+
+    random_padding:
+      prob: 0.9
+      min_post_pad_duration: 3.0 # minimum duration of post-padding silence in seconds
+      min_pre_pad_duration: 0.0 # minimum duration of pre-padding silence in seconds
+      max_pad_duration: 6.0 # maximum duration of pre/post padding in seconds
+      max_total_duration: 40.0 # maximum total duration of the padded audio in seconds
+      pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal'
+      normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal'
+      normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal'
+
+    augmentor:
+      white_noise:
+        prob: 0.9
+        min_level: -90
+        max_level: -46
+      gain:
+        prob: 0.2
+        min_gain_dbfs: -10.0
+        max_gain_dbfs: 10.0
+      noise:
+        prob: 0.9
+        manifest_path: /path/to/noise_manifest.json
+        min_snr_db: 0
+        max_snr_db: 20
+        max_gain_db: 300.0
+
+  validation_ds:
+    input_cfg: null
+    manifest_filepath: ["/path/to/eval_manifest1.json", "/path/to/eval_manifest2.json", ...]
+    tarred_audio_filepaths: null
+    ignore_eob_label: true # ignore backchannel and treat them the same as EOU
+
+```
+
+For a full example of the model config file, please refer to `examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml`.
+
+
+## 3. Model training
+
+To start the training, you can run the following command:
+```bash
+#!/bin/bash
+
+TRAIN_INPUT_CFG=/path/to/train_input_config.yaml
+VAL_MANIFEST=/path/to/val_manifest.json
+NOISE_MANIFEST=/path/to/noise_manifest.json
+
+PRETRAINED_NEMO=/path/to/nemotron-speech-streaming-en-0.6b.nemo
+
+BATCH_SIZE=16
+NUM_WORKERS=8
+LIMIT_TRAIN_BATCHES=1000
+VAL_CHECK_INTERVAL=1000
+MAX_STEPS=1000000
+
+EXP_NAME=nemotron_speech_streaming_en_0.6b_eou
+SCRIPT=${NEMO_PATH}/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py
+CONFIG_PATH=${NEMO_PATH}/examples/asr/conf/asr_eou
+CONFIG_NAME=fastconformer_transducer_bpe_streaming_xlarge
+
+CUDA_VISIBLE_DEVICES=0 python $SCRIPT \
+    --config-path $CONFIG_PATH \
+    --config-name $CONFIG_NAME \
+    ++init_from_nemo_model=$PRETRAINED_NEMO \
+    model.encoder.att_context_size="[70,1]" \
+    model.train_ds.input_cfg=$TRAIN_INPUT_CFG \
+    model.train_ds.augmentor.noise.manifest_path=$NOISE_MANIFEST \
+    model.validation_ds.manifest_filepath=$VAL_MANIFEST \
+    model.train_ds.batch_size=$BATCH_SIZE \
+    model.train_ds.num_workers=$NUM_WORKERS \
+    model.validation_ds.batch_size=$BATCH_SIZE \
+    model.validation_ds.num_workers=$NUM_WORKERS \
+    ~model.test_ds \
+    trainer.limit_train_batches=$LIMIT_TRAIN_BATCHES \
+    trainer.val_check_interval=$VAL_CHECK_INTERVAL \
+    trainer.max_steps=$MAX_STEPS \
+    exp_manager.name=$EXP_NAME
+```
+
+For the lowest EOU latency, we set `att_context_size` to `[70,1]` in the model config file, which means the model lookahead is 1 frame (80ms), and the input chunk size is thus 2 frames (160ms).
+
+
+## 4. Model evaluation
+
+After training, we can evaluate the model on the evaluation dataset by running the following command:
+```bash
+TEST_MANIFEST="[/path/to/your/test_manifest.json,/path/to/your/test_manifest2.json,...]"
+TEST_NAME="[test_name1,test_name2,...]"
+TEST_BATCH=32
+NUM_WORKERS=8
+
+SAVE_PRED_TO_FILE=/path/to/predictions.json # optional; saving predictions to a file slows down evaluation. Set to `null` to disable.
+PRETRAINED_NEMO=/path/to/EOU/model.nemo
+CONFIG_NAME=fastconformer_transducer_bpe_streaming_xlarge
+
+python speech_to_text_eou_eval.py \
+    --config-name $CONFIG_NAME \
+    ++save_pred_to_file=$SAVE_PRED_TO_FILE \
+    ++init_from_nemo_model=$PRETRAINED_NEMO \
+    ~model.train_ds \
+    ~model.validation_ds \
+    ++model.test_ds.defer_setup=true \
+    ++model.test_ds.sample_rate=16000 \
+    ++model.test_ds.manifest_filepath=$TEST_MANIFEST \
+    ++model.test_ds.name=$TEST_NAME \
+    ++model.test_ds.batch_size=$TEST_BATCH \
+    ++model.test_ds.num_workers=$NUM_WORKERS \
+    ++model.test_ds.drop_last=false \
+    ++model.test_ds.force_finite=true \
+    ++model.test_ds.shuffle=false \
+    ++model.test_ds.pin_memory=true \
+    exp_manager.create_wandb_logger=false
+```
+
+The script reports WER metrics along with EOU metrics such as latency, early cutoff rate, and miss detection rate.
+
+
+## 5. Model deployment with voice agent
+
+Please refer to the [NeMo Voice Agent](https://github.com/NVIDIA-NeMo/NeMo/tree/main/examples/voice_agent/README.md) example for more details on how to deploy the ASR-EOU model with a voice agent.
+
diff --git a/examples/asr/asr_eou/speech_to_text_eou_eval.py b/examples/asr/asr_eou/speech_to_text_eou_eval.py
new file mode 100644
index 000000000000..17d00a385be4
--- /dev/null
+++ b/examples/asr/asr_eou/speech_to_text_eou_eval.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +""" +Example usage: + +```bash +TEST_MANIFEST="[/path/to/your/test_manifest.json,/path/to/your/test_manifest2.json,...]" +TEST_NAME="[test_name1,test_name2,...]" +TEST_BATCH=32 +NUM_WORKERS=8 + +PRETRAINED_NEMO=/path/to/EOU/model.nemo +CONFIG_NAME=fastconformer_transducer_bpe_streaming + +python speech_to_text_eou_eval.py \ + --config-name $CONFIG_NAME \ + ++init_from_nemo_model=$PRETRAINED_NEMO \ + ~model.train_ds \ + ~model.validation_ds \ + ++model.test_ds.defer_setup=true \ + ++model.test_ds.sample_rate=16000 \ + ++model.test_ds.manifest_filepath=$TEST_MANIFEST \ + ++model.test_ds.name=$TEST_NAME \ + ++model.test_ds.batch_size=$TEST_BATCH \ + ++model.test_ds.num_workers=$NUM_WORKERS \ + ++model.test_ds.drop_last=false \ + ++model.test_ds.force_finite=true \ + ++model.test_ds.shuffle=false \ + ++model.test_ds.pin_memory=true \ + exp_manager.create_wandb_logger=false +``` + +""" + + +import lightning.pytorch as pl +import torch + +torch.set_float32_matmul_precision("highest") +from omegaconf import DictConfig, OmegaConf, open_dict + +from nemo.collections.asr.models import ASRModel +from nemo.core.classes import typecheck +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.trainer_utils import resolve_trainer_cfg + +typecheck.set_typecheck_enabled(False) + + +def load_model(cfg: DictConfig, trainer: pl.Trainer) -> ASRModel: + if "init_from_nemo_model" in cfg: + logging.info(f"Loading model from local file: {cfg.init_from_nemo_model}") + model = ASRModel.restore_from(cfg.init_from_nemo_model, trainer=trainer) + elif "init_from_pretrained_model" in cfg: + logging.info(f"Loading model from remote: {cfg.init_from_pretrained_model}") + model = ASRModel.from_pretrained(cfg.init_from_pretrained_model, trainer=trainer) + else: + raise ValueError( + "Please provide either 'init_from_nemo_model' or 'init_from_pretrained_model' in the config file." + ) + if cfg.get("init_from_ptl_ckpt", None): + logging.info(f"Loading weights from checkpoint: {cfg.init_from_ptl_ckpt}") + state_dict = torch.load(cfg.init_from_ptl_ckpt, map_location='cpu', weights_only=False)['state_dict'] + model.load_state_dict(state_dict, strict=True) + return model + + +@hydra_runner(config_path="../conf/asr_eou", config_name="fastconformer_transducer_bpe_streaming") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) + exp_manager(trainer, cfg.get("exp_manager", None)) + + asr_model = load_model(cfg, trainer) + asr_model = asr_model.eval() # Set the model to evaluation mode + if hasattr(asr_model, 'wer'): + asr_model.wer.log_prediction = False + + with open_dict(asr_model.cfg): + if "save_pred_to_file" in cfg: + asr_model.cfg.save_pred_to_file = cfg.save_pred_to_file + if "calclate_eou_metrics" in cfg: + asr_model.cfg.calclate_eou_metrics = cfg.calclate_eou_metrics + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + with open_dict(cfg.model.test_ds): + cfg.model.test_ds.pad_eou_label_secs = asr_model.cfg.get('pad_eou_label_secs', 0.0) + asr_model.setup_multiple_test_data(test_data_config=cfg.model.test_ds) + trainer.test(asr_model) + else: + raise ValueError( + "No test dataset provided. Please provide a test dataset in the config file under model.test_ds." 
+        )
+    logging.info("Test completed.")
+
+
+if __name__ == '__main__':
+    main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py b/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py
new file mode 100644
index 000000000000..f39a75123680
--- /dev/null
+++ b/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example usage:
+
+1. Prepare dataset based on /nemo/collections/asr/data/audio_to_eou_label_lhotse.py
+   Specifically, each sample in the jsonl manifest should have the following fields:
+    {
+        "audio_filepath": "/path/to/audio.wav",
+        "text": "The text of the audio.",
+        "offset": 0.0,  # offset of the audio, in seconds
+        "duration": 3.0,  # duration of the audio, in seconds
+        "sou_time": 0.2,  # start of utterance time, in seconds
+        "eou_time": 1.5,  # end of utterance time, in seconds
+    }
+
+2. If using a normal ASR model as initialization:
+    - Add special tokens <EOU> and <EOB> to the tokenizer of the pretrained model, referring to the script
+      /scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py
+    - If the pretrained model is HybridRNNTCTCBPEModel, convert it to RNNT using the script
+      /examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py
+
+3. Run the following command to train the ASR-EOU model:
+```bash
+#!/bin/bash
+
+TRAIN_MANIFEST=/path/to/train_manifest.json
+VAL_MANIFEST=/path/to/val_manifest.json
+NOISE_MANIFEST=/path/to/noise_manifest.json
+
+PRETRAINED_NEMO=/path/to/pretrained_model.nemo
+TOKENIZER_DIR=/path/to/tokenizer_dir
+
+BATCH_SIZE=16
+NUM_WORKERS=8
+LIMIT_TRAIN_BATCHES=1000
+VAL_CHECK_INTERVAL=1000
+MAX_STEPS=1000000
+
+EXP_NAME=fastconformer_transducer_bpe_streaming_eou
+SCRIPT=${NEMO_PATH}/examples/asr/asr_eou/speech_to_text_rnnt_eou_train.py
+CONFIG_PATH=${NEMO_PATH}/examples/asr/conf/asr_eou
+CONFIG_NAME=fastconformer_transducer_bpe_streaming
+
+CUDA_VISIBLE_DEVICES=0 python $SCRIPT \
+    --config-path $CONFIG_PATH \
+    --config-name $CONFIG_NAME \
+    ++init_from_nemo_model=$PRETRAINED_NEMO \
+    model.encoder.att_context_size="[70,1]" \
+    model.tokenizer.dir=$TOKENIZER_DIR \
+    model.train_ds.manifest_filepath=$TRAIN_MANIFEST \
+    model.train_ds.augmentor.noise.manifest_path=$NOISE_MANIFEST \
+    model.validation_ds.manifest_filepath=$VAL_MANIFEST \
+    model.train_ds.batch_size=$BATCH_SIZE \
+    model.train_ds.num_workers=$NUM_WORKERS \
+    model.validation_ds.batch_size=$BATCH_SIZE \
+    model.validation_ds.num_workers=$NUM_WORKERS \
+    ~model.test_ds \
+    trainer.limit_train_batches=$LIMIT_TRAIN_BATCHES \
+    trainer.val_check_interval=$VAL_CHECK_INTERVAL \
+    trainer.max_steps=$MAX_STEPS \
+    exp_manager.name=$EXP_NAME
+```
+
+"""
+
+from dataclasses import is_dataclass
+from typing import Optional
+
+import lightning.pytorch as pl
+from omegaconf import DictConfig, OmegaConf, open_dict
+
+from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCBPEModel, EncDecRNNTBPEModel
+from nemo.collections.asr.models.asr_eou_models import EncDecRNNTBPEEOUModel
+from nemo.collections.asr.modules.rnnt import RNNTDecoder, RNNTJoint
+from nemo.core import adapter_mixins
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+from nemo.utils.exp_manager import exp_manager
+from nemo.utils.trainer_utils import resolve_trainer_cfg
+
+
+def add_global_adapter_cfg(model, global_adapter_cfg):
+    # Convert to DictConfig from dict or Dataclass
+    if is_dataclass(global_adapter_cfg):
+        global_adapter_cfg = OmegaConf.structured(global_adapter_cfg)
+
+    if not isinstance(global_adapter_cfg, DictConfig):
+        global_adapter_cfg = DictConfig(global_adapter_cfg)
+
+    # Update the model.cfg with information about the new adapter global cfg
+    with open_dict(global_adapter_cfg), open_dict(model.cfg):
+        if 'adapters' not in model.cfg:
+            model.cfg.adapters = OmegaConf.create({})
+
+        # Add the global config for adapters to the model's internal config
+        model.cfg.adapters[model.adapter_global_cfg_key] = global_adapter_cfg
+
+        # Update all adapter modules (that already exist) with this global adapter config
+        model.update_adapter_cfg(model.cfg.adapters)
+
+
+def update_model_config_to_support_adapter(model_cfg):
+    with open_dict(model_cfg):
+        # Update encoder adapter compatible config
+        adapter_metadata = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_)
+        if adapter_metadata is not None:
+            model_cfg.encoder._target_ = adapter_metadata.adapter_class_path
+
+
+def setup_adapters(cfg: DictConfig, model: ASRModel):
+    # Setup adapters
+    with open_dict(cfg.model.adapter):
+        # Extract the name of the adapter (must be given for training)
+        adapter_name = cfg.model.adapter.pop("adapter_name")
+        adapter_type = cfg.model.adapter.pop("adapter_type")
+        adapter_module_name = cfg.model.adapter.pop("adapter_module_name", None)
+
+        # Resolve the config of the specified `adapter_type`
+        if adapter_type not in cfg.model.adapter.keys():
+            raise ValueError(
+                f"Adapter type ({adapter_type}) config could not be found. Adapter setup config - \n"
+                f"{OmegaConf.to_yaml(cfg.model.adapter)}"
+            )
+
+        adapter_type_cfg = cfg.model.adapter[adapter_type]
+        print(f"Found `{adapter_type}` config:\n" f"{OmegaConf.to_yaml(adapter_type_cfg)}")
+
+        # Augment adapter name with module name, if not provided by user
+        if adapter_module_name is not None and ':' not in adapter_name:
+            adapter_name = f'{adapter_module_name}:{adapter_name}'
+
+        # Extract the global adapter config, if provided
+        adapter_global_cfg = cfg.model.adapter.pop(model.adapter_global_cfg_key, None)
+        if adapter_global_cfg is not None:
+            add_global_adapter_cfg(model, adapter_global_cfg)
+
+    model.add_adapter(adapter_name, cfg=adapter_type_cfg)
+    assert model.is_adapter_available()
+
+    # Disable all other adapters, enable just the current adapter.
+    model.set_enabled_adapters(enabled=False)  # disable all adapters prior to training
+    model.set_enabled_adapters(adapter_name, enabled=True)  # enable just one adapter by name
+
+    model.freeze()  # freeze whole model by default
+    if not cfg.model.get("freeze_decoder", True):
+        logging.info("Unfreezing decoder weights.")
+        model.decoder.unfreeze()
+    if hasattr(model, 'joint') and not cfg.model.get("freeze_joint", True):
+        logging.info("Unfreezing joint network weights.")
+        model.joint.unfreeze()
+
+    # Activate dropout() and other modules that depend on train mode.
+    model = model.train()
+    # Then, unfreeze just the adapter weights that were enabled above (not the encoder/decoder/joint/etc.)
+    model.unfreeze_enabled_adapters()
+    return model
+
+
+def get_pretrained_model_name(cfg: DictConfig) -> Optional[str]:
+    if hasattr(cfg, 'init_from_ptl_ckpt') and cfg.init_from_ptl_ckpt is not None:
+        raise NotImplementedError(
+            "Currently, for simplicity of a single script for all model types, we only support `init_from_nemo_model` and `init_from_pretrained_model`"
+        )
+    nemo_model_path = cfg.get('init_from_nemo_model', None)
+    pretrained_name = cfg.get('init_from_pretrained_model', None)
+    if nemo_model_path is not None and pretrained_name is not None:
+        raise ValueError("Only pass `init_from_nemo_model` or `init_from_pretrained_model` but not both")
+    elif nemo_model_path is None and pretrained_name is None:
+        return None
+
+    if nemo_model_path:
+        return nemo_model_path
+    if pretrained_name:
+        return pretrained_name
+    return None
+
+
+def init_from_pretrained_nemo(model: EncDecRNNTBPEEOUModel, pretrained_model_path: str, cfg: DictConfig):
+    """
+    Load the pretrained model from a .nemo file or remote checkpoint. If the pretrained model has exactly
+    the same vocabulary size as the current model, the whole model will be loaded directly. Otherwise,
+    the encoder and decoder weights will be loaded separately and the EOU/EOB classes will be handled separately.
+    """
+    if pretrained_model_path.endswith('.nemo'):
+        pretrained_model = ASRModel.restore_from(restore_path=pretrained_model_path)  # type: EncDecRNNTBPEModel
+    else:
+        pretrained_model = ASRModel.from_pretrained(pretrained_model_path)  # type: EncDecRNNTBPEModel
+
+    if not isinstance(pretrained_model, (EncDecRNNTBPEModel, EncDecHybridRNNTCTCBPEModel)):
+        raise TypeError(
+            f"Pretrained model {pretrained_model.__class__} is not EncDecRNNTBPEModel or EncDecHybridRNNTCTCBPEModel."
+        )
+
+    try:
+        model.load_state_dict(pretrained_model.state_dict(), strict=True)
+        logging.info(
+            f"Pretrained model from {pretrained_model_path} has exactly the same model structure, skipping further loading."
+        )
+        return
+    except Exception:
+        logging.warning(
+            f"Pretrained model {pretrained_model_path} has a different model structure; trying to load weights separately and add EOU/EOB classes."
+        )
+
+    # Load encoder state dict into the model
+    model.encoder.load_state_dict(pretrained_model.encoder.state_dict(), strict=True)
+    logging.info(f"Encoder weights loaded from {pretrained_model_path}.")
+
+    # Load decoder state dict into the model
+    decoder = model.decoder  # type: RNNTDecoder
+    pretrained_decoder = pretrained_model.decoder  # type: RNNTDecoder
+    if not isinstance(decoder, RNNTDecoder) or not isinstance(pretrained_decoder, RNNTDecoder):
+        raise TypeError(
+            f"Decoder {decoder.__class__} is not RNNTDecoder or pretrained decoder {pretrained_decoder.__class__} is not RNNTDecoder."
+        )
+
+    decoder.prediction["dec_rnn"].load_state_dict(pretrained_decoder.prediction["dec_rnn"].state_dict(), strict=True)
+
+    decoder_embed_states = decoder.prediction["embed"].state_dict()['weight']  # shape: [num_classes+2, hid_dim]
+    pretrained_decoder_embed_states = pretrained_decoder.prediction["embed"].state_dict()[
+        'weight'
+    ]  # shape: [num_classes, hid_dim]
+    if decoder_embed_states.shape[0] != pretrained_decoder_embed_states.shape[0] + 2:
+        raise ValueError(
+            f"Size mismatch between pretrained ({pretrained_decoder_embed_states.shape[0]}+2) and current model ({decoder_embed_states.shape[0]}), cannot load decoder embedding."
+        )
+
+    decoder_embed_states[:-3, :] = pretrained_decoder_embed_states[:-1, :]  # everything except EOU, EOB and blank
+    decoder_embed_states[-1, :] = pretrained_decoder_embed_states[-1, :]  # blank class
+    decoder.prediction["embed"].load_state_dict({"weight": decoder_embed_states}, strict=True)
+    logging.info(f"Decoder weights loaded from {pretrained_model_path}.")
+
+    # Load joint network weights if new model's joint network has two more classes than the pretrained model
+    joint_network = model.joint  # type: RNNTJoint
+    pretrained_joint_network = pretrained_model.joint  # type: RNNTJoint
+    assert isinstance(joint_network, RNNTJoint), f"Joint network {joint_network.__class__} is not RNNTJoint."
+    assert isinstance(
+        pretrained_joint_network, RNNTJoint
+    ), f"Pretrained joint network {pretrained_joint_network.__class__} is not RNNTJoint."
+    joint_network.pred.load_state_dict(pretrained_joint_network.pred.state_dict(), strict=True)
+    joint_network.enc.load_state_dict(pretrained_joint_network.enc.state_dict(), strict=True)
+
+    if joint_network.num_classes_with_blank != pretrained_joint_network.num_classes_with_blank + 2:
+        raise ValueError(
+            f"Size mismatch between pretrained ({pretrained_joint_network.num_classes_with_blank}+2) and current model ({joint_network.num_classes_with_blank}), cannot load joint network."
+ ) + + # Load the joint network weights + pretrained_joint_state = pretrained_joint_network.joint_net.state_dict() + joint_state = joint_network.joint_net.state_dict() + pretrained_joint_clf_weight = pretrained_joint_state['2.weight'] # shape: [num_classes, hid_dim] + pretrained_joint_clf_bias = pretrained_joint_state['2.bias'] if '2.bias' in pretrained_joint_state else None + + token_init_method = cfg.model.get('token_init_method', 'constant') + # Copy the weights and biases from the pretrained model to the new model + # shape: [num_classes+2, hid_dim] + joint_state['2.weight'][:-3, :] = pretrained_joint_clf_weight[:-1, :] # everything except EOU, EOB and blank + joint_state['2.weight'][-1, :] = pretrained_joint_clf_weight[-1, :] # blank class + + value = None + if token_init_method == 'min': + # set the EOU and EOB class to the minimum value of the pretrained model + value = pretrained_joint_clf_weight.min(dim=0)[0] + elif token_init_method == 'max': + # set the EOU and EOB class to the maximum value of the pretrained model + value = pretrained_joint_clf_weight.max(dim=0)[0] + elif token_init_method == 'mean': + # set the EOU and EOB class to the mean value of the pretrained model + value = pretrained_joint_clf_weight.mean(dim=0) + elif token_init_method == 'constant': + value = cfg.model.get('token_init_weight_value', None) + elif token_init_method: + raise ValueError(f"Unknown token_init_method: {token_init_method}.") + + if value is not None: + joint_state['2.weight'][-2, :] = value # EOB class + joint_state['2.weight'][-3, :] = value # EOU class + + if pretrained_joint_clf_bias is not None and '2.bias' in joint_state: + joint_state['2.bias'][:-3] = pretrained_joint_clf_bias[:-1] # everything except EOU, EOB and blank + joint_state['2.bias'][-1] = pretrained_joint_clf_bias[-1] # blank class + value = None + if token_init_method == 'constant': + value = cfg.model.get('token_init_bias_value', None) + elif token_init_method == 'min': + # set the EOU and EOB class to the minimum value of the pretrained model + value = pretrained_joint_clf_bias.min() + elif token_init_method == 'max': + # set the EOU and EOB class to the maximum value of the pretrained model + value = pretrained_joint_clf_bias.max() + elif token_init_method == 'mean': + # set the EOU and EOB class to the mean value of the pretrained model + value = pretrained_joint_clf_bias.mean() + elif token_init_method: + raise ValueError(f"Unknown token_init_method: {token_init_method}.") + + if value is not None: + joint_state['2.bias'][-2] = value # EOB class + joint_state['2.bias'][-3] = value # EOU class + + # Load the joint network weights + joint_network.joint_net.load_state_dict(joint_state, strict=True) + logging.info(f"Joint network weights loaded from {pretrained_model_path}.") + + +@hydra_runner(config_path="../conf/asr_eou", config_name="fastconformer_transducer_bpe_streaming") +def main(cfg): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + + trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer)) + exp_manager(trainer, cfg.get("exp_manager", None)) + + if cfg.model.get("adapter", None) is not None: + update_model_config_to_support_adapter(cfg.model) + + asr_model = EncDecRNNTBPEEOUModel(cfg=cfg.model, trainer=trainer) + + init_from_model = get_pretrained_model_name(cfg) + if init_from_model: + init_from_pretrained_nemo(asr_model, init_from_model, cfg) + + if cfg.model.get("freeze_encoder", False): + logging.info("Freezing encoder weights.") + asr_model.encoder.freeze() + + if cfg.model.get("adapter", None) is 
not None:
+        asr_model = setup_adapters(cfg, asr_model)
+
+    trainer.fit(asr_model)
+
+    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
+        if asr_model.prepare_test(trainer):
+            trainer.test(asr_model)
+
+
+if __name__ == '__main__':
+    main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py
index 199e399ead11..34afa8309084 100644
--- a/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py
+++ b/examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py
@@ -20,7 +20,7 @@
 in NeMo. The resulting .nemo file will be a pure CTC or RNNT model, and can be used like any other .nemo model
 including in nemo2riva.
 
-Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -m ctc|rnnt
+Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -t ctc|rnnt
 """
 
diff --git a/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml
new file mode 100644
index 000000000000..d839ccd14ad6
--- /dev/null
+++ b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_adapter.yaml
@@ -0,0 +1,375 @@
+# This config contains the default values for training a cache-aware streaming FastConformer-Transducer ASR+EOU model, large size (~120M) with sub-word encoding.
+# It adds trainable adapters to the frozen Conformer encoder, and fully finetunes the RNNT decoder.
+
+# You may find more details:
+# FastConformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer
+# Cache-aware Conformer here: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer
+# FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+
+name: "FastConformer-Transducer-BPE-Streaming-EOU-adapter"
+
+model:
+  token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant']
+  token_init_weight_value: null # only applicable when token_init_method='constant'
+  token_init_bias_value: -1000.0 # only applicable when token_init_method='constant'
+
+  sample_rate: 16000
+  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
+  log_prediction: true # enables logging sample predictions in the output during training
+  skip_nan_grad: false
+
+  model_defaults:
+    enc_hidden: ${model.encoder.d_model}
+    pred_hidden: 640
+    joint_hidden: 640
+
+  adapter:
+    ### Config of the adapter training/eval script ###
+    adapter_name: "eou-adapter" # Name of the adapter, used by the script
+    adapter_type: "linear" # Type of the adapter. Corresponds to the subconfigs below.
+    adapter_module_name: null # Name of the adapter module. Combine multiple modules with '+' between module names.
+    adapter_state_dict_name: "adapters.pt" # If the individual adapters must be saved, a file name can be provided here. null disables this.
+ + ### Adapter Configs ### + # Linear / Houlsby Adapter (https://arxiv.org/abs/1902.00751) + linear: + # Config of the adapter module itself + _target_: nemo.collections.common.parts.adapter_modules.LinearAdapter + in_features: ${model.encoder.d_model} # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter. + dim: 32 # The hidden dimension of the adapter, as chosen by user, but small values are preferred to reduce param count. + activation: swish + norm_position: 'pre' # Can be `pre` or `post` + dropout: 0.0 # float, dropout for the adapter + + # Adapter strategy config + adapter_strategy: + _target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy + stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block. + l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output. + + # Tiny-Attention Adapter (https://arxiv.org/abs/2211.01979) + # NOTE: Only supported for Attention based encoders. Make sure to pass `adapter_module_name` as "encoder" + tiny_attn: + # Config of the adapter module itself + # Defaults to Relative Positional Encoding MHA + # _target_ can instead be .MultiHeadAttentionAdapter if Conformer was originally using Absolute Positional Encoding. + _target_: nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.RelPositionMultiHeadAttentionAdapter + n_feat: ${model.encoder.d_model} # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter. + n_head: 1 # Number of heads for attention. + proj_dim: -1 # Can be `null` - to avoid projection, > 0 for explicit dim, or -1 to default to `n_head` + dropout_rate: 0.0 # float, dropout for the adapter + + # Adapter strategy config + adapter_strategy: + _target_: nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module.MHAResidualAddAdapterStrategy + stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block. + l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output. + + # Optional global config available to all adapters at a global level. + # A global config is shared across every layer of the adapters, defining global properties rather + # than properties local to the adapter (as defined above). + # This can be useful in order to select *which type of adapter* is added, *what adapters to enable*, + # and further global operations that can decide dynamically how to support the requested adapter. + global_cfg: + check_encoder_adapter: True # ASR adapter key, determines whether to check if encoder adapter modules is supported + check_decoder_adapter: False # ASR adapter key, determines whether to check if decoder adapter modules is supported + check_joint_adapter: False # ASR adapter key, determines whether to check if joint adapter modules is supported + + freeze_encoder: True + freeze_decoder: False + freeze_joint: False + + train_ds: + manifest_filepath: ??? 
+ tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: true + drop_last: true + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + random_padding: + prob: 0.99 + min_pad_duration: 1.0 # minimum duration of pre/post padding in seconds + max_pad_duration: 10.0 # maximum duration of pre/post padding in seconds + max_total_duration: 40.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant' + normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal' + normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal' + pre_pad_duration: 0.2 # amount of pre-padding when pad_distribution='constant' + post_pad_duration: 3.0 # amount of post-padding when pad_distribution='constant' + + augmentor: + white_noise: + prob: 0.9 + min_level: -90 + max_level: -46 + gain: + prob: 0.2 + min_gain_dbfs: -10.0 + max_gain_dbfs: 10.0 + noise: + prob: 0.9 + manifest_path: ??? + min_snr_db: 0 + max_snr_db: 20 + max_gain_db: 300.0 + + validation_ds: + manifest_filepath: ??? + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + test_ds: + manifest_filepath: null + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + defer_setup: true + batch_duration: null # you may disable batch_duration by setting it to `null` + batch_size: 16 + shuffle: false + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + ignore_eob_label: true # ignore backchannel and treat them the same as EOU + + # You may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + # We recommend to use vocab size of 1024 with SPE Unigram for most languages + tokenizer: + dir: ??? 
# path to directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
+    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
+
+  preprocessor:
+    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
+    sample_rate: ${model.sample_rate}
+    normalize: "NA" # No normalization for mel-spectrogram makes streaming easier
+    window_size: 0.025
+    window_stride: 0.01
+    window: "hann"
+    features: 80
+    n_fft: 512
+    frame_splicing: 1
+    dither: 0.00001
+    pad_to: 0
+
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 2 # set to zero to disable it
+    time_masks: 10 # set to zero to disable it
+    freq_width: 27
+    time_width: 0.05
+
+  encoder:
+    _target_: nemo.collections.asr.modules.ConformerEncoder
+    feat_in: ${model.preprocessor.features}
+    feat_out: -1 # you may set it if you need a different output size than the default d_model
+    n_layers: 17
+    d_model: 512
+    use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules
+
+    # Sub-sampling parameters
+    subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding
+    subsampling_factor: 8 # must be power of 2 for striding and vggnet
+    subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
+    causal_downsampling: true
+
+    # Feed forward module's params
+    ff_expansion_factor: 4
+
+    # Multi-headed Attention Module's params
+    self_attention_model: rel_pos # rel_pos or abs_pos
+    n_heads: 8 # may need to be lower for smaller d_models
+
+    # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention
+    # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3, as multiple layers may make the effective right context too large
+    # for att_context_style=chunked_limited, the left context needs to be divisible by the right context plus one
+    # look-ahead(secs) = att_context_size[1]*subsampling_factor*window_stride, example: 13*8*0.01=1.04s
+
+    # For multi-lookahead models, you may specify a list of context sizes. During the training, different context sizes would be used randomly with the distribution specified by att_context_probs.
+    # The first item in the list would be the default during test/validation/inference.
+    # An example of settings for multi-lookahead:
+    # att_context_size: [[70,13],[70,6],[70,1],[70,0]]
+    # att_context_probs: [0.25, 0.25, 0.25, 0.25]
+    att_context_size: [70, 1] # -1 means unlimited context
+    att_context_style: chunked_limited # regular or chunked_limited
+    att_context_probs: null
+
+    xscaling: false # scales up the input embeddings by sqrt(d_model)
+    pos_emb_max_len: 5000
+
+    # Convolution module's params
+    conv_kernel_size: 9
+    conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
+
+    # conv_context_size can be "causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
+    # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
+    # We recommend causal convolutions, as they increase the effective right context and therefore the look-ahead significantly
+    conv_context_size: causal
+
+    ### regularization
+    dropout: 0.1 # The dropout used in most of the Conformer Modules
+    dropout_pre_encoder: 0.1 # The dropout used before the encoder
+    dropout_emb: 0.0 # The dropout used for embeddings
+    dropout_att: 0.1 # The dropout for multi-headed attention modules
+
+    # set to non-zero to enable stochastic depth
+    stochastic_depth_drop_prob: 0.0
+    stochastic_depth_mode: linear # linear or uniform
+    stochastic_depth_start_layer: 1
+
+  decoder:
+    _target_: nemo.collections.asr.modules.RNNTDecoder
+    normalization_mode: null # Currently only null is supported for export.
+    random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
+    blank_as_pad: true # This flag must be set in order to support exporting of RNNT models + efficient inference.
+
+    prednet:
+      pred_hidden: ${model.model_defaults.pred_hidden}
+      pred_rnn_layers: 1
+      t_max: null
+      dropout: 0.2
+
+  joint:
+    _target_: nemo.collections.asr.modules.RNNTJoint
+    log_softmax: null # 'null' would set it automatically according to CPU/GPU device
+    preserve_memory: false # dramatically slows down training, but might preserve some memory
+
+    # Fuses the computation of prediction net + joint net + loss + WER calculation
+    # to be run on sub-batches of size `fused_batch_size`.
+    # When this flag is set to true, consider the `batch_size` of *_ds to be just `encoder` batch size.
+    # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
+    # Using small values here will preserve a lot of memory during training, but will make training slower as well.
+    # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
+    # However, to preserve memory, this ratio can be 1:8 or even 1:16.
+    # Extreme case of 1:B (i.e. fused_batch_size=1) should be avoided as training speed would be very slow.
+    fuse_loss_wer: true
+    fused_batch_size: 4
+
+    jointnet:
+      joint_hidden: ${model.model_defaults.joint_hidden}
+      activation: "relu"
+      dropout: 0.2
+
+  decoding:
+    strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
+
+    # greedy strategy config
+    greedy:
+      max_symbols: 10
+
+    # beam strategy config
+    beam:
+      beam_size: 2
+      return_best_hypothesis: False
+      score_norm: true
+      tsd_max_sym_exp: 50 # for Time Synchronous Decoding
+      alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
+
+  # config for InterCTC loss: https://arxiv.org/abs/2102.03216
+  # specify loss weights and which layers to use for InterCTC
+  # e.g., to reproduce the paper results, set loss_weights: [0.3]
+  # and apply_at_layers: [8] (assuming 18 layers). Note that final
+  # layer loss coefficient is automatically adjusted (to 0.7 in above example)
+  interctc:
+    loss_weights: []
+    apply_at_layers: []
+
+  loss:
+    loss_name: "default"
+    warprnnt_numba_kwargs:
+      # FastEmit regularization: https://arxiv.org/abs/2010.11148
+      # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming
+      # You may set it to lower values like 1e-3 for models with larger right context
+      fastemit_lambda: 3e-2 # Recommended values are in the range [1e-4, 3e-2]; we empirically find 3e-2 to be a good start for low-latency EOU models.
+      clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
+
+  optim:
+    name: adamw
+    lr: 5.0 # with NoamAnnealing, lr acts as a scale factor rather than an absolute learning rate
+    # optimizer arguments
+    betas: [0.9, 0.98]
+    weight_decay: 1e-3
+
+    # scheduler setup
+    sched:
+      name: NoamAnnealing # NoamAnnealing or CosineAnnealing
+      # scheduler config override
+      d_model: ${model.encoder.d_model}
+      warmup_steps: 10000
+      warmup_ratio: null
+      min_lr: 1e-6
+
+trainer:
+  devices: -1 # number of GPUs, -1 would use all available GPUs
+  num_nodes: 1
+  max_epochs: -1
+  max_steps: 100000 # computed at runtime if not set
+  val_check_interval: 1000 # an int for number of iterations
+  limit_train_batches: ${trainer.val_check_interval}
+  accelerator: auto
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    gradient_as_bucket_view: true
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  precision: 32 # 16, 32, or bf16
+  log_every_n_steps: 10 # Interval of logging.
+  enable_progress_bar: True
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; setting to 0 disables it
+  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
+  use_distributed_sampler: false
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, first one is used
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 5
+    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  resume_if_exists: false
+  resume_ignore_no_checkpoint: false
+
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
diff --git a/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_large.yaml b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_large.yaml
new file mode 100644
index 000000000000..5c46c9d97130
--- /dev/null
+++ b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_large.yaml
@@ -0,0 +1,338 @@
+# This config contains the default values for training a cache-aware streaming FastConformer-Transducer ASR+EOU model, large size (~120M), with sub-word encoding.
+# Here are the recommended configs for different variants of FastConformer-Transducer-BPE; other parameters are the same as in this config file.
+#
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Model        | d_model | n_heads | n_layers |conv_kernel_size| weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | xscaling   |
+# +==============+=========+=========+==========+================+==============+==========================+=================+============+
+# | Small (14M)  | 176     | 4       | 16       | 9              | 0.0          | 320                      | 1               | True       |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Medium (32M) | 256     | 4       | 16       | 9              | 1e-3         | 640                      | 1               | True       |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Large (120M) | 512     | 8       | 17       | 9              | 1e-3         | 640                      | 1               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | XLarge (616M)| 1024    | 8       | 24       | 9              | 1e-3         | 640                      | 2               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | XXLarge(1.2B)| 1024    | 8       | 42       | 5              | 1e-3         | 640                      | 2               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+
+# You may find more details here:
+# FastConformer: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer
+# Cache-aware Conformer: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer
+# FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+
+name: "FastConformer-Transducer-Large-BPE-Streaming-EOU"
+
+model:
+  token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant']
+  token_init_weight_value: null # only applicable when token_init_method='constant'
+  token_init_bias_value: -1000.0 # only applicable when token_init_method='constant'
+
+  sample_rate: 16000
+  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
+  log_prediction: true # enables logging sample predictions in the output during training
+  skip_nan_grad: false
+
+  model_defaults:
+    enc_hidden: ${model.encoder.d_model}
+    pred_hidden: 640
+    joint_hidden: 640
+
+  train_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: true
+    drop_last: true
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+    random_padding:
+      prob: 0.99
+      min_post_pad_duration: 3.0
+      min_pre_pad_duration: 0.0
+      max_pad_duration: 6.0 # maximum duration of pre/post padding in seconds
+      max_total_duration: 40.0 # maximum total duration of the padded audio in seconds
+      pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant'
+      normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal'
+      normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal'
+      pre_pad_duration: 0.2 # amount of pre-padding when pad_distribution='constant'
+      post_pad_duration: 3.0 # amount of post-padding when pad_distribution='constant'
+
+    augmentor:
+      white_noise:
+        prob: 0.9
+        min_level: -90
+        max_level: -46
+      gain:
+        prob: 0.2
+        min_gain_dbfs: -10.0
+        max_gain_dbfs: 10.0
+      noise:
+        prob: 0.9
+        manifest_path: null
+        min_snr_db: 0
+        max_snr_db: 20
+        max_gain_db: 300.0
+
+  validation_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+  test_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+  # You may find more details on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
+  # We recommend using a vocab size of 1024 with SPE Unigram for most languages
+  tokenizer:
+    dir: ???
# path to the directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
+    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
+
+  preprocessor:
+    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
+    sample_rate: ${model.sample_rate}
+    normalize: "NA" # Disabling normalization of the mel-spectrogram makes streaming easier
+    window_size: 0.025
+    window_stride: 0.01
+    window: "hann"
+    features: 128
+    n_fft: 512
+    frame_splicing: 1
+    dither: 0.00001
+    pad_to: 0
+
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 2 # set to zero to disable it
+    time_masks: 10 # set to zero to disable it
+    freq_width: 27
+    time_width: 0.05
+
+  encoder:
+    _target_: nemo.collections.asr.modules.ConformerEncoder
+    feat_in: ${model.preprocessor.features}
+    feat_out: -1 # you may set it if you need an output size different from the default d_model
+    n_layers: 17
+    d_model: 512
+    use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules
+
+    # Sub-sampling parameters
+    subsampling: dw_striding # vggnet, striding, stacking, stacking_norm, or dw_striding
+    subsampling_factor: 8 # must be a power of 2 for striding and vggnet
+    subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
+    causal_downsampling: true
+
+    # Feed forward module's params
+    ff_expansion_factor: 4
+
+    # Multi-headed Attention Module's params
+    self_attention_model: rel_pos # rel_pos or abs_pos
+    n_heads: 8 # may need to be lower for smaller d_models
+
+    # [left, right] specifies the number of steps to be seen from the left and right of each step in self-attention
+    # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3, since stacking layers may make the effective right context too large
+    # for att_context_style=chunked_limited, the left context needs to be divisible by (right context + 1)
+    # look-ahead (secs) = att_context_size[1] * subsampling_factor * window_stride, example: 13*8*0.01 = 1.04s
+
+    # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly sampled with the distribution specified by att_context_probs.
+    # The first item in the list is the default during test/validation/inference.
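+    # e.g. with the att_context_size of [70, 1] set below, look-ahead = 1 * 8 * 0.01 = 0.08s (80 ms)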
+    # An example of settings for multi-lookahead:
+    # att_context_size: [[70,13],[70,6],[70,1],[70,0]]
+    # att_context_probs: [0.25, 0.25, 0.25, 0.25]
+    att_context_size: [70, 1] # -1 means unlimited context
+    att_context_style: chunked_limited # regular or chunked_limited
+    att_context_probs: null
+
+    xscaling: false # scales up the input embeddings by sqrt(d_model)
+    pos_emb_max_len: 5000
+
+    # Convolution module's params
+    conv_kernel_size: 9
+    conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
+
+    # conv_context_size can be "causal" or a list of two integers [left, right] with conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
+    # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
+    # Causal convolutions are recommended for streaming, since non-causal ones accumulate right context over the layers and increase the effective look-ahead significantly
+    conv_context_size: causal
+
+    ### regularization
+    dropout: 0.1 # The dropout used in most of the Conformer Modules
+    dropout_pre_encoder: 0.1 # The dropout used before the encoder
+    dropout_emb: 0.0 # The dropout used for embeddings
+    dropout_att: 0.1 # The dropout for multi-headed attention modules
+
+    # set to non-zero to enable stochastic depth
+    stochastic_depth_drop_prob: 0.0
+    stochastic_depth_mode: linear # linear or uniform
+    stochastic_depth_start_layer: 1
+
+  decoder:
+    _target_: nemo.collections.asr.modules.RNNTDecoder
+    normalization_mode: null # Currently only null is supported for export.
+    random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
+    blank_as_pad: true # This flag must be set to support export of RNNT models and efficient inference.
+
+    prednet:
+      pred_hidden: ${model.model_defaults.pred_hidden}
+      pred_rnn_layers: 1
+      t_max: null
+      dropout: 0.2
+
+  joint:
+    _target_: nemo.collections.asr.modules.RNNTJoint
+    log_softmax: null # 'null' would set it automatically according to CPU/GPU device
+    preserve_memory: false # dramatically slows down training, but might preserve some memory
+
+    # Fuses the computation of prediction net + joint net + loss + WER calculation
+    # to be run on sub-batches of size `fused_batch_size`.
+    # When this flag is set to true, consider the `batch_size` of *_ds to be just the `encoder` batch size.
+    # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
+    # Using small values here will preserve a lot of memory during training, but will make training slower as well.
+    # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
+    # However, to preserve memory, this ratio can be 1:8 or even 1:16.
+    # The extreme case of 1:B (i.e. fused_batch_size=1) should be avoided, as training would become very slow.
+    fuse_loss_wer: true
+    fused_batch_size: 4
+
+    jointnet:
+      joint_hidden: ${model.model_defaults.joint_hidden}
+      activation: "relu"
+      dropout: 0.2
+
+  decoding:
+    strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
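+    # greedy_batch decodes the whole batch in parallel and is much faster than
+    # per-sample greedy, so it is the usual choice for validation during training.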
+
+    # greedy strategy config
+    greedy:
+      max_symbols: 10
+
+    # beam strategy config
+    beam:
+      beam_size: 2
+      return_best_hypothesis: False
+      score_norm: true
+      tsd_max_sym_exp: 50 # for Time Synchronous Decoding
+      alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
+
+  # config for InterCTC loss: https://arxiv.org/abs/2102.03216
+  # specify loss weights and which layers to use for InterCTC
+  # e.g., to reproduce the paper results, set loss_weights: [0.3]
+  # and apply_at_layers: [8] (assuming 18 layers). Note that the final
+  # layer loss coefficient is automatically adjusted (to 0.7 in the above example)
+  interctc:
+    loss_weights: []
+    apply_at_layers: []
+
+  loss:
+    loss_name: "default"
+    warprnnt_numba_kwargs:
+      # FastEmit regularization: https://arxiv.org/abs/2010.11148
+      # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming
+      # You may set it to lower values like 1e-3 for models with larger right context
+      fastemit_lambda: 3e-2 # recommended values are in [1e-4, 1e-2] and 0.001 is a good start; the larger value here favors earlier emissions (lower latency) at some cost in accuracy
+      clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
+
+  optim:
+    name: adamw
+    lr: 5.0 # with NoamAnnealing, lr acts as a scale factor on the schedule, not as an absolute learning rate
+    # optimizer arguments
+    betas: [0.9, 0.98]
+    weight_decay: 1e-3
+
+    # scheduler setup
+    sched:
+      name: NoamAnnealing
+      # scheduler config override
+      d_model: ${model.encoder.d_model}
+      warmup_steps: 10000
+      warmup_ratio: null
+      min_lr: 1e-6
+
+trainer:
+  devices: -1 # number of GPUs, -1 would use all available GPUs
+  num_nodes: 1
+  max_epochs: -1
+  max_steps: 100000 # computed at runtime if not set
+  val_check_interval: 1000 # an int for number of iterations
+  limit_train_batches: ${trainer.val_check_interval}
+  accelerator: auto
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    gradient_as_bucket_view: true
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  precision: 32 # 16, 32, or bf16
+  log_every_n_steps: 10 # Interval of logging.
+  enable_progress_bar: True
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
+  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
+  use_distributed_sampler: false
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, the first one is used
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 3
+    filename: '${exp_manager.name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}'
+    always_save_nemo: True # saves the checkpoints as .nemo files instead of PTL checkpoints
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; it restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  resume_if_exists: true
+  resume_ignore_no_checkpoint: true
+
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
diff --git a/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml
new file mode 100644
index 000000000000..c0b05f5bad10
--- /dev/null
+++ b/examples/asr/conf/asr_eou/fastconformer_transducer_bpe_streaming_xlarge.yaml
@@ -0,0 +1,338 @@
+# This config contains the default values for training a cache-aware streaming FastConformer-Transducer ASR+EOU model, XLarge size (~616M), with sub-word encoding.
+# Here are the recommended configs for different variants of FastConformer-Transducer-BPE; other parameters are the same as in this config file.
+#
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Model        | d_model | n_heads | n_layers |conv_kernel_size| weight_decay | pred_hidden/joint_hidden | pred_rnn_layers | xscaling   |
+# +==============+=========+=========+==========+================+==============+==========================+=================+============+
+# | Small (14M)  | 176     | 4       | 16       | 9              | 0.0          | 320                      | 1               | True       |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Medium (32M) | 256     | 4       | 16       | 9              | 1e-3         | 640                      | 1               | True       |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | Large (120M) | 512     | 8       | 17       | 9              | 1e-3         | 640                      | 1               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | XLarge (616M)| 1024    | 8       | 24       | 9              | 1e-3         | 640                      | 2               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+# | XXLarge(1.2B)| 1024    | 8       | 42       | 5              | 1e-3         | 640                      | 2               | False      |
+# +--------------+---------+---------+----------+----------------+--------------+--------------------------+-----------------+------------+
+
+# You may find more details here:
+# FastConformer: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#fast-conformer
+# Cache-aware Conformer: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/asr/models.html#cache-aware-streaming-conformer
+# FastConformer-Transducer's architecture config, along with the optimal batch size and precision: NeMo/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml
+
+name: "FastConformer-Transducer-XLarge-BPE-Streaming-EOU"
+
+model:
+  token_init_method: "constant" # choices=['min', 'max', 'mean', 'constant']
+  token_init_weight_value: null # only applicable when token_init_method='constant'
+  token_init_bias_value: -1000.0 # only applicable when token_init_method='constant'
+
+  sample_rate: 16000
+  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
+  log_prediction: true # enables logging sample predictions in the output during training
+  skip_nan_grad: false
+
+  model_defaults:
+    enc_hidden: ${model.encoder.d_model}
+    pred_hidden: 640
+    joint_hidden: 640
+
+  train_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: true
+    drop_last: true
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+    random_padding:
+      prob: 0.99
+      min_post_pad_duration: 3.0
+      min_pre_pad_duration: 0.0
+      max_pad_duration: 6.0 # maximum duration of pre/post padding in seconds
+      max_total_duration: 40.0 # maximum total duration of the padded audio in seconds
+      pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant'
+      normal_mean: 0.5 # mean of normal distribution used when pad_distribution='normal'
+      normal_std: 2.0 # standard deviation of normal distribution used when pad_distribution='normal'
+      pre_pad_duration: 0.2 # amount of pre-padding when pad_distribution='constant'
+      post_pad_duration: 3.0 # amount of post-padding when pad_distribution='constant'
+
+    augmentor:
+      white_noise:
+        prob: 0.9
+        min_level: -90
+        max_level: -46
+      gain:
+        prob: 0.2
+        min_gain_dbfs: -10.0
+        max_gain_dbfs: 10.0
+      noise:
+        prob: 0.9
+        manifest_path: ???
+        min_snr_db: 0
+        max_snr_db: 20
+        max_gain_db: 300.0
+
+  validation_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+  test_ds:
+    input_cfg: null
+    manifest_filepath: null
+    tarred_audio_filepaths: null
+    sample_rate: ${model.sample_rate}
+    max_duration: 30 # you may need to update it for your dataset
+    min_duration: 0.1
+    defer_setup: true
+    batch_duration: null # set a total duration (in seconds) to enable duration-based batching; `null` disables it
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+    quadratic_duration: 30
+    num_buckets: 30
+    num_cuts_for_bins_estimate: 10000
+    bucket_buffer_size: 10000
+    shuffle_buffer_size: 10000
+    ignore_eob_label: true # ignore backchannels and treat them the same as EOU
+
+  # You may find more details on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py
+  # We recommend using a vocab size of 1024 with SPE Unigram for most languages
+  tokenizer:
+    dir: ???
# path to the directory which contains either tokenizer.model (bpe) or vocab.txt (for wpe)
+    type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)
+
+  preprocessor:
+    _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor
+    sample_rate: ${model.sample_rate}
+    normalize: "NA" # Disabling normalization of the mel-spectrogram makes streaming easier
+    window_size: 0.025
+    window_stride: 0.01
+    window: "hann"
+    features: 128
+    n_fft: 512
+    frame_splicing: 1
+    dither: 0.00001
+    pad_to: 0
+
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 2 # set to zero to disable it
+    time_masks: 10 # set to zero to disable it
+    freq_width: 27
+    time_width: 0.05
+
+  encoder:
+    _target_: nemo.collections.asr.modules.ConformerEncoder
+    feat_in: ${model.preprocessor.features}
+    feat_out: -1 # you may set it if you need an output size different from the default d_model
+    n_layers: 24
+    d_model: 1024
+    use_bias: false # whether to apply bias in the feedforward, MHA and convolution modules
+
+    # Sub-sampling parameters
+    subsampling: dw_striding # vggnet, striding, stacking, stacking_norm, or dw_striding
+    subsampling_factor: 8 # must be a power of 2 for striding and vggnet
+    subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model
+    causal_downsampling: true
+
+    # Feed forward module's params
+    ff_expansion_factor: 4
+
+    # Multi-headed Attention Module's params
+    self_attention_model: rel_pos # rel_pos or abs_pos
+    n_heads: 8 # may need to be lower for smaller d_models
+
+    # [left, right] specifies the number of steps to be seen from the left and right of each step in self-attention
+    # for att_context_style=regular, the right context is recommended to be a small number around 0 to 3, since stacking layers may make the effective right context too large
+    # for att_context_style=chunked_limited, the left context needs to be divisible by (right context + 1)
+    # look-ahead (secs) = att_context_size[1] * subsampling_factor * window_stride, example: 13*8*0.01 = 1.04s
+
+    # For multi-lookahead models, you may specify a list of context sizes. During training, a context size is randomly sampled with the distribution specified by att_context_probs.
+    # The first item in the list is the default during test/validation/inference.
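+    # e.g. with the att_context_size of [70, 1] set below, look-ahead = 1 * 8 * 0.01 = 0.08s (80 ms)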
+    # An example of settings for multi-lookahead:
+    # att_context_size: [[70,13],[70,6],[70,1],[70,0]]
+    # att_context_probs: [0.25, 0.25, 0.25, 0.25]
+    att_context_size: [70, 1] # -1 means unlimited context
+    att_context_style: chunked_limited # regular or chunked_limited
+    att_context_probs: null
+
+    xscaling: false # scales up the input embeddings by sqrt(d_model)
+    pos_emb_max_len: 5000
+
+    # Convolution module's params
+    conv_kernel_size: 9
+    conv_norm_type: 'layer_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups)
+
+    # conv_context_size can be "causal" or a list of two integers [left, right] with conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size
+    # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0]
+    # Causal convolutions are recommended for streaming, since non-causal ones accumulate right context over the layers and increase the effective look-ahead significantly
+    conv_context_size: causal
+
+    ### regularization
+    dropout: 0.1 # The dropout used in most of the Conformer Modules
+    dropout_pre_encoder: 0.1 # The dropout used before the encoder
+    dropout_emb: 0.0 # The dropout used for embeddings
+    dropout_att: 0.1 # The dropout for multi-headed attention modules
+
+    # set to non-zero to enable stochastic depth
+    stochastic_depth_drop_prob: 0.0
+    stochastic_depth_mode: linear # linear or uniform
+    stochastic_depth_start_layer: 1
+
+  decoder:
+    _target_: nemo.collections.asr.modules.RNNTDecoder
+    normalization_mode: null # Currently only null is supported for export.
+    random_state_sampling: false # Random state sampling: https://arxiv.org/pdf/1910.11455.pdf
+    blank_as_pad: true # This flag must be set to support export of RNNT models and efficient inference.
+
+    prednet:
+      pred_hidden: ${model.model_defaults.pred_hidden}
+      pred_rnn_layers: 2
+      t_max: null
+      dropout: 0.2
+
+  joint:
+    _target_: nemo.collections.asr.modules.RNNTJoint
+    log_softmax: null # 'null' would set it automatically according to CPU/GPU device
+    preserve_memory: false # dramatically slows down training, but might preserve some memory
+
+    # Fuses the computation of prediction net + joint net + loss + WER calculation
+    # to be run on sub-batches of size `fused_batch_size`.
+    # When this flag is set to true, consider the `batch_size` of *_ds to be just the `encoder` batch size.
+    # `fused_batch_size` is the actual batch size of the prediction net, joint net and transducer loss.
+    # Using small values here will preserve a lot of memory during training, but will make training slower as well.
+    # An optimal ratio of fused_batch_size : *_ds.batch_size is 1:1.
+    # However, to preserve memory, this ratio can be 1:8 or even 1:16.
+    # The extreme case of 1:B (i.e. fused_batch_size=1) should be avoided, as training would become very slow.
+    fuse_loss_wer: true
+    fused_batch_size: 4
+
+    jointnet:
+      joint_hidden: ${model.model_defaults.joint_hidden}
+      activation: "relu"
+      dropout: 0.2
+
+  decoding:
+    strategy: "greedy_batch" # can be greedy, greedy_batch, beam, tsd, alsd.
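+    # greedy_batch decodes the whole batch in parallel and is much faster than
+    # per-sample greedy, so it is the usual choice for validation during training.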
+
+    # greedy strategy config
+    greedy:
+      max_symbols: 10
+
+    # beam strategy config
+    beam:
+      beam_size: 2
+      return_best_hypothesis: False
+      score_norm: true
+      tsd_max_sym_exp: 50 # for Time Synchronous Decoding
+      alsd_max_target_len: 2.0 # for Alignment-Length Synchronous Decoding
+
+  # config for InterCTC loss: https://arxiv.org/abs/2102.03216
+  # specify loss weights and which layers to use for InterCTC
+  # e.g., to reproduce the paper results, set loss_weights: [0.3]
+  # and apply_at_layers: [8] (assuming 18 layers). Note that the final
+  # layer loss coefficient is automatically adjusted (to 0.7 in the above example)
+  interctc:
+    loss_weights: []
+    apply_at_layers: []
+
+  loss:
+    loss_name: "default"
+    warprnnt_numba_kwargs:
+      # FastEmit regularization: https://arxiv.org/abs/2010.11148
+      # You may enable FastEmit to increase the accuracy and reduce the latency of the model for streaming
+      # You may set it to lower values like 1e-3 for models with larger right context
+      fastemit_lambda: 3e-2 # recommended values are in [1e-4, 1e-2] and 0.001 is a good start; the larger value here favors earlier emissions (lower latency) at some cost in accuracy
+      clamp: -1.0 # if > 0, applies gradient clamping in range [-clamp, clamp] for the joint tensor only.
+
+  optim:
+    name: adamw
+    lr: 5.0 # with NoamAnnealing, lr acts as a scale factor on the schedule, not as an absolute learning rate
+    # optimizer arguments
+    betas: [0.9, 0.98]
+    weight_decay: 1e-3
+
+    # scheduler setup
+    sched:
+      name: NoamAnnealing
+      # scheduler config override
+      d_model: ${model.encoder.d_model}
+      warmup_steps: 10000
+      warmup_ratio: null
+      min_lr: 1e-6
+
+trainer:
+  devices: -1 # number of GPUs, -1 would use all available GPUs
+  num_nodes: 1
+  max_epochs: -1
+  max_steps: 100000 # computed at runtime if not set
+  val_check_interval: 1000 # an int for number of iterations
+  limit_train_batches: ${trainer.val_check_interval}
+  accelerator: auto
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    gradient_as_bucket_view: true
+  accumulate_grad_batches: 1
+  gradient_clip_val: 1.0
+  precision: 32 # 16, 32, or bf16
+  log_every_n_steps: 10 # Interval of logging.
+  enable_progress_bar: True
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
+  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
+  use_distributed_sampler: false
+
+exp_manager:
+  exp_dir: null
+  name: ${name}
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, the first one is used
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 3
+    filename: '${exp_manager.name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}'
+    always_save_nemo: True # saves the checkpoints as .nemo files instead of PTL checkpoints
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; it restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  resume_if_exists: true
+  resume_ignore_no_checkpoint: true
+
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: null
+    project: null
diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py
index 2cdc3a30b96d..abca4f374656 100644
--- a/examples/asr/transcribe_speech.py
+++ b/examples/asr/transcribe_speech.py
@@ -302,7 +302,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
             if cfg.decoder_type and cfg.decoder_type != 'rnnt':
                 raise ValueError('RNNT model only support rnnt decoding!')

-        if cfg.decoder_type and hasattr(asr_model.encoder, 'set_default_att_context_size'):
+        if cfg.att_context_size and hasattr(asr_model.encoder, 'set_default_att_context_size'):
             asr_model.encoder.set_default_att_context_size(cfg.att_context_size)

         # Setup decoding strategy
diff --git a/nemo/collections/asr/data/audio_to_eou_label_lhotse.py b/nemo/collections/asr/data/audio_to_eou_label_lhotse.py
new file mode 100644
index 000000000000..725ccd994f04
--- /dev/null
+++ b/nemo/collections/asr/data/audio_to_eou_label_lhotse.py
@@ -0,0 +1,524 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import numpy as np
+import torch.utils.data
+from lhotse.cut import Cut, CutSet, MixedCut
+from lhotse.dataset import AudioSamples
+from lhotse.dataset.collation import collate_vectors
+from omegaconf import DictConfig, OmegaConf
+
+from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
+from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
+from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType
+from nemo.utils import logging
+
+NON_SPEECH_LABEL = 0
+SPEECH_LABEL = 1
+EOU_LABEL = 2
+EOB_LABEL = 3
+EOU_STRING = '<EOU>'
+EOB_STRING = '<EOB>'
+
+# These augmentations are not supported yet, since they would need to change the SOU/EOU timestamps
+EOU_INVALID_AUGMENTATIONS = ['random_segment', 'speed', 'time_stretch']
+
+
+@dataclass
+class AudioToTextEOUBatch:
+    """
+    Data class for ASR-EOU batch.
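+
+    All tensor fields are collated and padded across the batch: raw audio samples and
+    their lengths, BPE text tokens and their lengths, and frame-level EOU targets and
+    their lengths.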
+    """
+
+    sample_ids: List | None = None
+    audio_filepaths: List | None = None
+    audio_signal: torch.Tensor | None = None
+    audio_lengths: torch.Tensor | None = None
+    text_tokens: torch.Tensor | None = None
+    text_token_lengths: torch.Tensor | None = None
+    eou_targets: torch.Tensor | None = None
+    eou_target_lengths: torch.Tensor | None = None
+
+
+@dataclass
+class RandomPaddingConfig:
+    prob: float = 0.9  # probability of applying padding
+    min_pad_duration: float = 0.0  # minimum duration of pre/post padding in seconds
+    max_pad_duration: float = 5.0  # maximum duration of pre/post padding in seconds
+    max_total_duration: float = 40.0  # maximum total duration of the padded audio in seconds
+    min_pre_pad_duration: float = 0.0  # minimum duration of pre-padding in seconds
+    min_post_pad_duration: float = 2.0  # minimum duration of post-padding in seconds
+    pad_distribution: str = 'uniform'  # distribution of padding duration, 'uniform' or 'normal' or 'constant'
+    normal_mean: float = 0.5  # mean of normal distribution for padding duration
+    normal_std: float = 2.0  # standard deviation of normal distribution for padding duration
+    pre_pad_duration: float = 0.2  # amount of left-padding when pad_distribution='constant'
+    post_pad_duration: float = 3.0  # amount of right-padding when pad_distribution='constant'
+
+
+class LhotseSpeechToTextBpeEOUDataset(torch.utils.data.Dataset):
+    """
+    This dataset processes the audio data and the corresponding text data to generate the ASR labels,
+    along with EOU labels for each frame. The audios used in this dataset should only contain speech with
+    NO preceding or following silence. The dataset also randomly pads non-speech frames before and after
+    the audio signal for training the EOU prediction task.
+
+    To generate EOU labels, the last frame of each utterance is marked as "end of utterance" (labeled as `2`),
+    or, if it is a backchannel utterance, as "end of backchannel" (labeled as `3`).
+    The rest of the speech frames are marked as "speech" (labeled as `1`).
+    The padded non-speech signals are marked as "non-speech" (labeled as `0`).
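+
+    For example, a padded clip of 10 output frames with a single utterance spanning
+    frames 2-6 gets the labels [0, 0, 1, 1, 1, 1, 2, 0, 0, 0], with a 3 instead of
+    the 2 at frame 6 if the utterance is a backchannel.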
+
+    Args:
+        cfg: a DictConfig object containing the following keys, usually taken from your `model.train_ds`
+            or `model.validation_ds` config:
+            ```
+            sample_rate: # int, Sample rate of the audio signal
+            window_stride: # float, Window stride for audio encoder
+            subsampling_factor: # Subsampling factor for audio encoder
+            random_padding: # Random padding configuration
+                prob: 0.9 # probability of applying padding
+                min_pad_duration: 0.5 # minimum duration of pre/post padding in seconds
+                max_pad_duration: 2.0 # maximum duration of pre/post padding in seconds
+                max_total_duration: 30.0 # maximum total duration of the padded audio in seconds
+                pad_distribution: 'uniform' # distribution of padding duration, 'uniform' or 'normal' or 'constant'
+                normal_mean: 0.5 # mean of normal distribution for padding duration
+                normal_std: 2.0 # standard deviation of normal distribution for padding duration
+                pre_pad_duration: 0.2 # amount of left-padding when pad_distribution='constant'
+                post_pad_duration: 3.0 # amount of right-padding when pad_distribution='constant'
+            ```
+
+    Returns:
+        audio: torch.Tensor of audio signal
+        audio_lens: torch.Tensor of audio signal length
+        text_tokens: torch.Tensor of text tokens
+        text_token_lens: torch.Tensor of text token length
+        eou_targets (optional): torch.Tensor of EOU labels
+        eou_target_lens (optional): torch.Tensor of EOU label length
+
+    The input manifest should be a jsonl file where each line is a python dictionary.
+    Example manifest sample:
+        {
+            "audio_filepath": "/path/to/audio.wav",
+            "offset": 0.0,
+            "duration": 6.0,
+            "sou_time": [0.3, 4.0],
+            "eou_time": [1.3, 4.5],
+            "utterances": ["Tell me a joke", "Ah-ha"],
+            "is_backchannel": [False, True],
+        }
+
+    Padding logic:
+        0. Don't pad when `random_padding` is None or during validation/test.
+        1. Randomly draw a probability to decide whether to apply padding.
+        2. If not padding, or if the audio duration already exceeds the maximum total duration,
+           return the original audio and EOU labels.
+        3. If padding is applied:
+           1) get the max padding duration based on the maximum total duration and the audio duration
+           2) randomly draw a total padding duration based on the given distribution
+           3) randomly split the total padding duration into pre-padding and post-padding
+           4) generate the non-speech signal (audio signal = 0) for pre-padding and post-padding
+           5) concatenate the pre-padding, audio, and post-padding to get the padded audio signal
+           6) update the EOU labels accordingly
+    """
+
+    @property
+    def output_types(self) -> Optional[Dict[str, NeuralType]]:
+        """Define the output types of the dataset."""
+        return {
+            'audio': NeuralType(('B', 'T'), AudioSignal()),
+            'audio_lens': NeuralType(tuple('B'), LengthsType()),
+            'eou_targets': NeuralType(('B', 'T'), LabelsType()),
+            'eou_target_lens': NeuralType(tuple('B'), LengthsType()),
+            'text_tokens': NeuralType(('B', 'T'), LabelsType(), optional=True),
+            'text_token_lens': NeuralType(tuple('B'), LengthsType(), optional=True),
+        }
+
+    def __init__(self, cfg: DictConfig, tokenizer: TokenizerSpec, return_cuts: bool = False):
+        super().__init__()
+        self.cfg = cfg
+        self.return_cuts = return_cuts
+        self.eou_string = self.cfg.get('eou_string', EOU_STRING)
+        self.eob_string = self.cfg.get('eob_string', EOB_STRING)
+        if cfg.get('check_tokenizer', True):
+            self._check_special_tokens(tokenizer)
+
+        self.tokenizer = TokenizerWrapper(tokenizer)
+        self.load_audio = AudioSamples(fault_tolerant=True)
+        self.sample_rate = self.cfg.get('sample_rate', 16000)
+        self.window_stride = self.cfg.get('window_stride', 0.01)
+        self.num_sample_per_mel_frame = int(
+            self.window_stride * self.sample_rate
+        )  # 160 samples per 10 ms mel frame by default
+        self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8))
+        self.add_sep_before_eou = self.cfg.get('add_sep_before_eou', False)
+        self.add_eou_to_text = self.cfg.get('add_eou_to_text', True)
+        self.pad_eou_label_secs = self.cfg.get('pad_eou_label_secs', 0.0)
+        self.padding_cfg = self.cfg.get('random_padding', None)
+        if self.padding_cfg is not None:
+            self.padding_cfg = OmegaConf.to_container(self.padding_cfg, resolve=True)
+            self.padding_cfg = RandomPaddingConfig(**self.padding_cfg)
+        self.ignore_eob_label = self.cfg.get('ignore_eob_label', False)
+        self.augmentor = None
+        if self.cfg.get('augmentor', None) is not None:
+            augmentor = {}
+            aug_cfg = OmegaConf.to_container(self.cfg.augmentor, resolve=True)
+            for k, v in aug_cfg.items():
+                if k in EOU_INVALID_AUGMENTATIONS:
+                    logging.warning(f"EOU dataset does not support {k} augmentation yet, skipping.")
+                    continue
+                augmentor[k] = v
+
+            if len(augmentor) > 0:
+                logging.info(f"EOU dataset will apply augmentations: {augmentor}")
+                self.augmentor = process_augmentations(augmentor)
+
+    def _check_special_tokens(self, tokenizer: TokenizerSpec):
+        """
+        Check that the special tokens are in the tokenizer vocab.
+        """
+        special_tokens = set([self.eou_string, self.eob_string])
+        vocab_size = tokenizer.vocab_size
+        special_tokens_in_vocab = set([tokenizer.ids_to_text(vocab_size - 1), tokenizer.ids_to_text(vocab_size - 2)])
+        if special_tokens != special_tokens_in_vocab:
+            raise ValueError(
+                f"Input special tokens {special_tokens} don't match the tokenizer vocab {special_tokens_in_vocab}. "
+                f"Please add them to the tokenizer or change the input `eou_string` and/or `eob_string` accordingly. "
+                "Special tokens should be added as the last two tokens in the new tokenizer. "
" + "Please refer to scripts/asr_end_of_utterance/tokenizers/add_special_tokens_to_sentencepiece.py for details." + ) + + def __getitem__(self, cuts: CutSet) -> AudioToTextEOUBatch: + audio, audio_lens, cuts = self.load_audio(cuts) + audio_signals = [] + audio_lengths = [] + eou_targets = [] + text_tokens = [] + sample_ids = [] + audio_filepaths = [] + + for i in range(len(cuts)): + c = cuts[i] + if isinstance(c, MixedCut): + c = c.first_non_padding_cut + + sample_ids.append(c.id) + audio_filepaths.append(c.recording.sources[0].source) + + audio_i = audio[i] + audio_len_i = audio_lens[i] + + # Get EOU labels and text tokens + eou_targets_i = self._get_frame_labels(c, audio_len_i) + text_tokens_i = self._get_text_tokens(c) + + # Maybe apply random padding to both sides of the audio + audio_i, audio_len_i, eou_targets_i = self._random_pad_audio(audio_i, audio_len_i, eou_targets_i) + + # Maybe apply augmentations to the audio signal after padding + audio_i, audio_len_i = self._maybe_augment_audio(audio_i, audio_len_i) + + # Append the processed audio, EOU labels, and text tokens to the lists + audio_signals.append(audio_i) + audio_lengths.append(audio_len_i) + eou_targets.append(eou_targets_i) + text_tokens.append(text_tokens_i) + + audio_signals = collate_vectors(audio_signals, padding_value=0) + audio_lengths = torch.tensor(audio_lengths, dtype=torch.long) + eou_target_lens = torch.tensor([t.size(0) for t in eou_targets], dtype=torch.long) + eou_targets = collate_vectors(eou_targets, padding_value=0) + text_token_lens = torch.tensor([t.size(0) for t in text_tokens], dtype=torch.long) + text_tokens = collate_vectors(text_tokens, padding_value=0) + + if self.return_cuts: + return audio_signals, audio_lengths, cuts + + return AudioToTextEOUBatch( + sample_ids=sample_ids, + audio_filepaths=audio_filepaths, + audio_signal=audio_signals, + audio_lengths=audio_lengths, + text_tokens=text_tokens, + text_token_lengths=text_token_lens, + eou_targets=eou_targets, + eou_target_lengths=eou_target_lens, + ) + + def _audio_len_to_frame_len(self, num_samples: int): + """ + Convert the raw audio length to the number of frames after audio encoder. 
+        """
+        mel_frame_count = math.ceil((num_samples + 1) / self.num_sample_per_mel_frame)
+        hidden_length = math.ceil(mel_frame_count / self.num_mel_frame_per_target_frame)
+        return hidden_length
+
+    def _repeat_eou_labels(self, eou_targets: torch.Tensor) -> torch.Tensor:
+        """
+        Repeat EOU labels according to self.pad_eou_label_secs.
+        Args:
+            eou_targets: torch.Tensor of EOU labels, shape [T]
+        Returns:
+            eou_targets: torch.Tensor of padded EOU labels, shape [T]
+        """
+        if not self.pad_eou_label_secs or self.pad_eou_label_secs <= 0:
+            return eou_targets
+
+        eou_len = self._audio_len_to_frame_len(int(self.pad_eou_label_secs * self.sample_rate))
+
+        i = 0
+        while i < eou_targets.size(0):
+            if eou_targets[i] == EOU_LABEL or eou_targets[i] == EOB_LABEL:
+                # repeat the label over the next eou_len frames
+                start = i
+                end = min(i + eou_len, eou_targets.size(0))
+                j = start + 1
+                while j < end:
+                    if eou_targets[j] != NON_SPEECH_LABEL:
+                        # do not overwrite the label if it's not non-speech
+                        break
+                    j += 1
+                end = min(j, end)
+                # fill the non-speech labels with the current EOU/EOB label
+                eou_targets[start:end] = eou_targets[i]
+                i = end
+            else:
+                i += 1
+        return eou_targets
+
+    def _get_frame_labels(self, cut: Cut, num_samples: int):
+        """
+        Get the frame-level EOU labels for a single audio segment.
+        Args:
+            cut: Cut object
+            num_samples: int, the number of samples in the audio segment
+        Returns:
+            eou_targets: torch.Tensor of EOU labels, shape [T]
+        """
+        hidden_length = self._audio_len_to_frame_len(num_samples)
+        if "sou_time" not in cut.custom or "eou_time" not in cut.custom:
+            # assume a single speech segment
+            text = cut.supervisions[0].text
+            if not text:
+                # skip empty utterances
+                return torch.zeros(hidden_length).long()
+            eou_targets = torch.ones(hidden_length).long()  # speech label
+            eou_targets[-1] = EOU_LABEL  # by default it's end of utterance
+            if cut.has_custom("is_backchannel") and cut.custom["is_backchannel"] and not self.ignore_eob_label:
+                eou_targets[-1] = EOB_LABEL  # end of backchannel
+            return eou_targets
+
+        sou_time = cut.custom["sou_time"]
+        eou_time = cut.custom["eou_time"]
+        if not isinstance(sou_time, list):
+            sou_time = [sou_time]
+        if not isinstance(eou_time, list):
+            eou_time = [eou_time]
+
+        assert len(sou_time) == len(
+            eou_time
+        ), f"Number of SOU time and EOU time do not match: SOU ({sou_time}) vs EOU ({eou_time})"
+
+        if cut.has_custom("is_backchannel"):
+            is_backchannel = cut.custom["is_backchannel"]
+            if not isinstance(is_backchannel, list):
+                is_backchannel = [is_backchannel]
+            assert len(sou_time) == len(
+                is_backchannel
+            ), f"Number of SOU and backchannel do not match: SOU ({len(sou_time)}) vs backchannel ({len(is_backchannel)})"
+        else:
+            is_backchannel = [False] * len(sou_time)
+
+        eou_targets = torch.zeros(hidden_length).long()
+        for i in range(len(sou_time)):
+            if sou_time[i] is None or eou_time[i] is None or sou_time[i] < 0 or eou_time[i] < 0:
+                # skip empty utterances
+                continue
+            sou_idx = self._audio_len_to_frame_len(int((sou_time[i] - cut.start) * self.sample_rate))
+            seg_len_in_secs = eou_time[i] - sou_time[i]
+            seg_len = self._audio_len_to_frame_len(int(seg_len_in_secs * self.sample_rate))
+            eou_targets[sou_idx : sou_idx + seg_len] = SPEECH_LABEL
+            last_idx = min(sou_idx + seg_len - 1, hidden_length - 1)
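+            # The segment's final frame gets the terminal label below; last_idx is
+            # clamped to hidden_length - 1 so rounding can never index past the sequence.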
+ if is_backchannel[i] and not self.ignore_eob_label: + eou_targets[last_idx] = EOB_LABEL # end of backchannel + else: + eou_targets[last_idx] = EOU_LABEL # end of utterance + + return eou_targets + + def _get_text_tokens(self, cut: Cut): + """ + Add EOU labels to the text and get the text tokens for a single audio segment. + Args: + cut: Cut object + Returns: + text_tokens: torch.Tensor of text tokens, shape [T] + """ + if not cut.has_custom("sou_time") or not cut.has_custom("eou_time") or not cut.has_custom("utterances"): + # assume only single speech segment + utterances = [cut.supervisions[0].text] + else: + utterances = cut.custom["utterances"] + + if not isinstance(utterances, list): + utterances = [utterances] + + if cut.has_custom("is_backchannel"): + is_backchannel = cut.custom["is_backchannel"] + if not isinstance(is_backchannel, list): + is_backchannel = [is_backchannel] + assert len(utterances) == len( + is_backchannel + ), f"Number of utterances and backchannel do not match: utterance ({len(utterances)}) vs backchannel ({len(is_backchannel)})" + else: + is_backchannel = [False] * len(utterances) + + total_text = "" + for i, text in enumerate(utterances): + if not text: + # skip empty utterances + continue + if self.add_eou_to_text: + eou_string = self.eob_string if is_backchannel[i] and not self.ignore_eob_label else self.eou_string + if self.add_sep_before_eou: + eou_string = " " + eou_string + else: + eou_string = "" + total_text += text + eou_string + " " + total_text = total_text.strip() + return torch.as_tensor(self.tokenizer(total_text)) + + def _random_pad_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, eou_targets: torch.Tensor): + """ + Randomly pad the audio signal with non-speech signal before and after the audio signal. 
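+        This exposes the model to leading/trailing silence, as encountered in streaming
+        inference, so it learns to emit EOU when speech actually ends rather than at the
+        end of every training clip.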
+
+        Args:
+            audio: torch.Tensor of a single audio signal, shape [T]
+            audio_len: torch.Tensor of audio signal length, shape [1]
+            eou_targets: torch.Tensor of EOU labels, shape [T]
+        Returns:
+            padded_audio: torch.Tensor of padded audio signal, shape [T+padding]
+            padded_audio_len: torch.Tensor of padded audio signal length, shape [1]
+            padded_eou_targets: torch.Tensor of padded EOU labels, shape [T+padding]
+        """
+        p = np.random.rand()
+        if self.padding_cfg is None or p > self.padding_cfg.prob:
+            # don't apply padding
+            eou_targets = self._repeat_eou_labels(eou_targets)
+            return audio, audio_len, eou_targets
+
+        duration = audio_len.item() / self.sample_rate
+        # if already longer than the maximum total duration, return the original audio
+        if duration >= self.padding_cfg.max_total_duration:
+            return audio, audio_len, eou_targets
+
+        # apply padding
+        audio = audio[:audio_len]
+
+        # use local values instead of mutating the shared padding config
+        min_pre_pad_duration = max(self.padding_cfg.min_pre_pad_duration, self.padding_cfg.min_pad_duration)
+        min_post_pad_duration = max(self.padding_cfg.min_post_pad_duration, self.padding_cfg.min_pad_duration)
+
+        max_padding_duration = max(0, self.padding_cfg.max_total_duration - duration)
+        if max_padding_duration <= min_pre_pad_duration + min_post_pad_duration:
+            min_padding_duration = 0
+        else:
+            min_padding_duration = min_pre_pad_duration + min_post_pad_duration
+
+        pre_padding_duration = None
+        post_padding_duration = None
+
+        if self.padding_cfg.pad_distribution == 'uniform':
+            total_padding_duration = np.random.uniform(min_padding_duration, max_padding_duration)
+        elif self.padding_cfg.pad_distribution == 'normal':
+            total_padding_duration = np.random.normal(self.padding_cfg.normal_mean, self.padding_cfg.normal_std)
+            total_padding_duration = max(min_padding_duration, min(max_padding_duration, total_padding_duration))
+        elif self.padding_cfg.pad_distribution == 'constant':
+            pass
+        else:
+            raise ValueError(f"Unknown padding distribution: {self.padding_cfg.pad_distribution}")
+
+        if self.padding_cfg.pad_distribution == 'constant':
+            pre_padding_duration = self.padding_cfg.pre_pad_duration
+            post_padding_duration = self.padding_cfg.post_pad_duration
+        elif min_padding_duration == 0:
+            pre_padding_duration = total_padding_duration / 2
+            post_padding_duration = total_padding_duration / 2
+        else:
+            post_padding_duration = np.random.uniform(
+                min_post_pad_duration, total_padding_duration - min_pre_pad_duration
+            )
+            pre_padding_duration = total_padding_duration - post_padding_duration
+
+        if self.padding_cfg.max_pad_duration is not None:
+            pre_padding_duration = min(pre_padding_duration, self.padding_cfg.max_pad_duration)
+            post_padding_duration = min(post_padding_duration, self.padding_cfg.max_pad_duration)
+
+        pre_padding_len = math.ceil(pre_padding_duration * self.sample_rate)
+        post_padding_len = math.ceil(post_padding_duration * self.sample_rate)
+
+        # pad the audio signal
+        pre_padding = torch.zeros(pre_padding_len, dtype=audio.dtype)
+        post_padding = torch.zeros(post_padding_len, dtype=audio.dtype)
+        padded_audio = torch.cat((pre_padding, audio, post_padding), dim=0)
+        padded_audio_len = audio_len + pre_padding_len + post_padding_len
+
+        # pad the EOU labels
+        pre_padding_eou_len = self._audio_len_to_frame_len(pre_padding_len)
+        post_padding_eou_len = self._audio_len_to_frame_len(post_padding_len)
+        pre_padding_eou = torch.zeros(pre_padding_eou_len, dtype=eou_targets.dtype)
+        post_padding_eou = torch.zeros(post_padding_eou_len, dtype=eou_targets.dtype)
+        padded_eou_targets = torch.cat((pre_padding_eou, eou_targets, post_padding_eou), dim=0)
+        padded_eou_targets = self._repeat_eou_labels(padded_eou_targets)
+        return padded_audio, padded_audio_len, padded_eou_targets
+
+    def _maybe_augment_audio(self, audio: torch.Tensor, audio_len: torch.Tensor):
+        """
+        Apply augmentation to the audio signal if an augmentor is provided.
+        Args:
+            audio: torch.Tensor of a single audio signal, shape [T]
+            audio_len: torch.Tensor of audio signal length, shape [1]
+        Returns:
+            augmented_audio: torch.Tensor of augmented audio signal, shape [T]
+            augmented_audio_len: length of the augmented audio signal in samples
+        """
+        if self.augmentor is None:
+            return audio, audio_len
+
+        # Cast to AudioSegment
+        audio_segment = AudioSegment(
+            samples=audio[:audio_len].numpy(),
+            sample_rate=self.sample_rate,
+            offset=0,
+            duration=audio_len.item() / self.sample_rate,
+        )
+        # Apply augmentation
+        self.augmentor.perturb(audio_segment)
+        audio = torch.from_numpy(audio_segment.samples).float()
+        audio_len = audio.size(0)
+
+        return audio, audio_len
diff --git a/nemo/collections/asr/losses/ssl_losses/mlm.py b/nemo/collections/asr/losses/ssl_losses/mlm.py
index 424374869c3d..4ed6f580bbb2 100644
--- a/nemo/collections/asr/losses/ssl_losses/mlm.py
+++ b/nemo/collections/asr/losses/ssl_losses/mlm.py
@@ -65,11 +65,14 @@ def forward(
         if masks is None:
             masks = spec_masks

-        # B,D,T -> B,T,D
-        masks = masks.transpose(1, 2)
+        if masks is None:
+            masks = torch.ones_like(decoder_outputs, dtype=torch.bool)
+        else:
+            # B,D,T -> B,T,D
+            masks = masks.transpose(1, 2)

-        masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1)
-        masks = masks.mean(-1) > self.mask_threshold
+        masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1)
+        masks = masks.mean(-1) > self.mask_threshold

         out_masked_only = decoder_outputs[masks]
         targets = F.pad(targets, (0, masks.shape[-1] - targets.shape[-1]))
diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py
index 719af4adcd3b..0011de8adc69 100644
--- a/nemo/collections/asr/metrics/wer.py
+++ b/nemo/collections/asr/metrics/wer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from copy import deepcopy
 from typing import List, Optional, Tuple, Union

 import editdistance
@@ -238,6 +239,7 @@ def on_validation_epoch_end(self):
         log_prediction: Whether to log a single decoded sample per call.
         batch_dim_index: Index corresponding to batch dimension. (For RNNT.)
         dist_dync_on_step: Whether to perform reduction on forward pass of metric.
+        return_hypotheses: Whether to return the hypotheses.
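+            Hypotheses produced by the most recent ``update()`` call are stored and
+            can be retrieved with ``get_hypotheses()``.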
Returns: res: a tuple of 3 zero dimensional float32 ``torch.Tensor` objects: a WER score, a sum of Levenstein's @@ -255,6 +257,7 @@ def __init__( batch_dim_index=0, dist_sync_on_step=False, sync_on_compute=True, + return_hypotheses=False, **kwargs, ): super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute) @@ -264,30 +267,33 @@ def __init__( self.log_prediction = log_prediction self.fold_consecutive = fold_consecutive self.batch_dim_index = batch_dim_index + self.return_hypotheses = return_hypotheses self.decode = None if isinstance(self.decoding, AbstractRNNTDecoding): self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids: self.decoding.rnnt_decoder_predictions_tensor( - encoder_output=predictions, encoded_lengths=predictions_lengths + encoder_output=predictions, encoded_lengths=predictions_lengths, return_hypotheses=return_hypotheses ) elif isinstance(self.decoding, AbstractCTCDecoding): self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids: self.decoding.ctc_decoder_predictions_tensor( decoder_outputs=predictions, decoder_lengths=predictions_lengths, fold_consecutive=self.fold_consecutive, + return_hypotheses=return_hypotheses, ) elif isinstance(self.decoding, AbstractMultiTaskDecoding): self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids: self.decoding.decode_predictions_tensor( encoder_hidden_states=predictions, encoder_input_mask=predictions_mask, decoder_input_ids=input_ids, - return_hypotheses=False, + return_hypotheses=return_hypotheses, ) else: raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}") self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.hypotheses = None def update( self, @@ -352,8 +358,22 @@ def update( self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype) self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype) + self.hypotheses = hypotheses + return None def compute(self): scores = self.scores.detach().float() words = self.words.detach().float() return scores / words, scores, words + + def reset(self): + super().reset() + self.hypotheses = None + + def get_hypotheses(self): + """ + Returns the hypotheses generated during the last call to update. + """ + if self.hypotheses is None: + raise ValueError("No hypotheses available. Please call update() first.") + return deepcopy(self.hypotheses) diff --git a/nemo/collections/asr/models/asr_eou_models.py b/nemo/collections/asr/models/asr_eou_models.py new file mode 100644 index 000000000000..4cb0b8f6076c --- /dev/null +++ b/nemo/collections/asr/models/asr_eou_models.py @@ -0,0 +1,967 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only +from omegaconf import DictConfig, OmegaConf, open_dict + +from nemo.collections.asr.data.audio_to_eou_label_lhotse import ( + EOB_LABEL, + EOB_STRING, + EOU_LABEL, + EOU_STRING, + AudioToTextEOUBatch, + LhotseSpeechToTextBpeEOUDataset, +) +from nemo.collections.asr.metrics.wer import WER +from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel, EncDecRNNTBPEModel +from nemo.collections.asr.parts.mixins import TranscribeConfig +from nemo.collections.asr.parts.utils.eou_utils import ( + EOUResult, + cal_eou_metrics_from_frame_labels, + flatten_nested_list, +) +from nemo.collections.asr.parts.utils.manifest_utils import write_manifest +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.utils import move_data_to_device +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.utils import logging + +__all__ = ['EncDecRNNTBPEEOUModel', 'EncDecHybridRNNTCTCBPEEOUModel'] + + +@dataclass +class EOUPrediction: + eou_probs: Optional[List[float]] = None + eob_probs: Optional[List[float]] = None + eou_preds: Optional[List[bool]] = None + eob_preds: Optional[List[bool]] = None + + +class ASREOUModelMixin: + def __init__(self): + if not hasattr(self, 'tokenizer'): + self.tokenizer = None + if not hasattr(self, 'eou_token'): + self.eou_token = None + if not hasattr(self, 'eob_token'): + self.eob_token = None + if not hasattr(self, 'frame_len_in_secs'): + self.frame_len_in_secs = None + + def setup_eou_mixin(self, eou_token: int, eob_token: int, frame_len_in_secs: float, tokenizer): + if getattr(self, 'eou_token', None) is None: + self.eou_token = eou_token + if getattr(self, 'eob_token', None) is None: + self.eob_token = eob_token + if getattr(self, 'frame_len_in_secs', None) is None: + self.frame_len_in_secs = frame_len_in_secs + if getattr(self, 'tokenizer', None) is None: + self.tokenizer = tokenizer + + def _patch_decoding_cfg(self, cfg: DictConfig): + """ + Patch the decoding config as needed for EOU computation + """ + with open_dict(cfg): + cfg.decoding.preserve_alignments = True + cfg.decoding.compute_timestamps = True + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ + PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + """ + batch = move_data_to_device(batch, device) + return batch + + def _get_text_from_tokens(self, tokens: torch.Tensor, tokens_len: Optional[torch.Tensor] = None) -> List[str]: + """ + Convert tokens to text. + Args: + tokens: tensor of tokens + Returns: + text: list of text + """ + text_list = [] + for i in range(len(tokens)): + tokens_i = tokens[i] + if tokens_len is not None: + tokens_i = tokens[i][: tokens_len[i]] + tokens_i = [int(x) for x in tokens_i if x < self.tokenizer.vocab_size] + text = self.tokenizer.ids_to_text(tokens_i) + text_list.append(text) + return text_list + + def _get_eou_predictions_from_hypotheses( + self, hypotheses: List[Hypothesis], batch: AudioToTextEOUBatch + ) -> List[EOUPrediction]: + """ + Get EOU predictions from the hypotheses. 
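+
+        Note: for CTC decoding, ``hyp.alignments`` is a tuple of (frame logits, frame tokens);
+        for RNNT decoding, it is a list with one entry per encoder timestep, each a list of
+        (logit, token) pairs. Both layouts are handled here.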
+ Args: + hypotheses: batch of hypotheses + Returns: + eou_predictions: list of EOU predictions + """ + eou_predictions = [] + + for hyp in hypotheses: + # Process one hypothesis at a time + eou_probs = [] + eob_probs = [] + eou_preds = [] + eob_preds = [] + if isinstance(hyp.alignments, tuple): + # CTC + probs = torch.softmax(hyp.alignments[0], dim=-1) # [time, num_classes] + tokens = hyp.alignments[1] + eou_probs = probs[:, self.eou_token].tolist() + eob_probs = probs[:, self.eob_token].tolist() + eou_preds = [int(x) == self.eou_token for x in tokens] + eob_preds = [int(x) == self.eob_token for x in tokens] + else: + # RNNT, each timestamp has a list of (prob, token) tuples + for alignment in hyp.alignments: + # Process for each timestamp + probs = torch.softmax(torch.stack([a[0] for a in alignment], dim=0), dim=-1) # unfold RNNT preds + tokens = torch.stack([a[1] for a in alignment], dim=0) # unfold RNNT preds + + # Get the max prob for eou and eob + # and check if eou and eob are predicted + max_eou_prob = probs[:, self.eou_token].max().item() + max_eob_prob = probs[:, self.eob_token].max().item() + eou_pred = torch.any(tokens == self.eou_token).item() + eob_pred = torch.any(tokens == self.eob_token).item() + + eou_probs.append(max_eou_prob) + eob_probs.append(max_eob_prob) + eou_preds.append(eou_pred) + eob_preds.append(eob_pred) + + eou_predictions.append( + EOUPrediction( + eou_probs=eou_probs, + eob_probs=eob_probs, + eou_preds=eou_preds, + eob_preds=eob_preds, + ) + ) + + return eou_predictions + + def _pad_to_same_length(self, eou_labels: List[float], eou_preds: List[float]) -> Tuple[List[float], List[float]]: + """ + Pad the EOU labels and predictions to the same length. + Args: + eou_labels: list of EOU labels + eou_preds: list of EOU predictions + Returns: + eou_labels: list of EOU labels, padded to the same length + eou_preds: list of EOU predictions, padded to the same length + """ + if len(eou_labels) < len(eou_preds): + eou_labels = eou_labels + [0] * (len(eou_preds) - len(eou_labels)) + elif len(eou_labels) > len(eou_preds): + eou_preds = eou_preds + [0] * (len(eou_labels) - len(eou_preds)) + return eou_labels, eou_preds + + def _calculate_eou_metrics( + self, eou_predictions: List[EOUPrediction], batch: AudioToTextEOUBatch + ) -> Tuple[List[EOUResult], List[EOUResult]]: + """ + Calculate EOU metrics. 
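+
+        Frame-level predictions are scored against the ground-truth EOU/EOB frame labels
+        with threshold=0.0 and collar=0.0 via ``cal_eou_metrics_from_frame_labels``.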
+        Args:
+            eou_predictions: list of EOU predictions
+            batch: batch of data
+        Returns:
+            eou_metrics_list: list of EOU metrics, each is of type EOUResult
+            eob_metrics_list: list of EOB metrics, each is of type EOUResult
+        """
+        # Get the ground truth EOU labels
+        eou_labels = batch.eou_targets
+        eou_labels_len = batch.eou_target_lengths
+
+        # Calculate EOU metrics
+        eou_metrics_list = []
+        eob_metrics_list = []
+        for i, eou_prediction in enumerate(eou_predictions):
+            eou_preds_i = [float(x) for x in eou_prediction.eou_preds]
+            eob_preds_i = [float(x) for x in eou_prediction.eob_preds]
+
+            eou_labels_i = (eou_labels[i][: eou_labels_len[i]] == EOU_LABEL).float().tolist()
+            eob_labels_i = (eou_labels[i][: eou_labels_len[i]] == EOB_LABEL).float().tolist()
+
+            # Pad the EOU labels and predictions to the same length with zeros
+            eou_labels_i, eou_preds_i = self._pad_to_same_length(eou_labels_i, eou_preds_i)
+            eob_labels_i, eob_preds_i = self._pad_to_same_length(eob_labels_i, eob_preds_i)
+
+            # Calculate EOU metrics
+            eou_metrics: EOUResult = cal_eou_metrics_from_frame_labels(
+                prediction=eou_preds_i,
+                reference=eou_labels_i,
+                threshold=0.0,
+                collar=0.0,
+                frame_len_in_secs=self.frame_len_in_secs,
+            )
+
+            eob_metrics = cal_eou_metrics_from_frame_labels(
+                prediction=eob_preds_i,
+                reference=eob_labels_i,
+                threshold=0.0,
+                collar=0.0,
+                frame_len_in_secs=self.frame_len_in_secs,
+            )
+
+            eou_metrics_list.append(eou_metrics)
+            eob_metrics_list.append(eob_metrics)
+
+        return eou_metrics_list, eob_metrics_list
+
+    def _get_percentiles(self, values: List[float], percentiles: List[float], tag: str = "") -> Dict[str, float]:
+        """
+        Get the percentiles of a list of values.
+        Args:
+            values: list of values
+            percentiles: list of percentiles
+        Returns:
+            metrics: Dict of percentiles
+        """
+        if tag:
+            tag += "_"
+        if len(values) == 0:
+            # Return zeros as a dict so the return type matches the annotation and callers can .update() it
+            return {f'{tag}p{int(p)}': 0.0 for p in percentiles}
+        results = np.percentile(values, percentiles).tolist()
+        metrics = {}
+        for i, p in enumerate(percentiles):
+            metrics[f'{tag}p{int(p)}'] = float(results[i])
+        return metrics
+
+    def _aggregate_eou_metrics(self, outputs: List[dict], mode: str, is_ctc: bool = False):
+        if f'{mode}_eou_metrics' not in outputs[0] and not is_ctc:
+            return {}
+        if f'{mode}_eou_metrics_ctc' not in outputs[0] and is_ctc:
+            return {}
+
+        # Aggregate EOU/EOB metrics
+        eou_metrics: List[EOUResult] = []
+        eob_metrics: List[EOUResult] = []
+        for x in outputs:
+            if is_ctc:
+                eou_metrics.extend(x[f'{mode}_eou_metrics_ctc'])
+                eob_metrics.extend(x[f'{mode}_eob_metrics_ctc'])
+            else:
+                eou_metrics.extend(x[f'{mode}_eou_metrics'])
+                eob_metrics.extend(x[f'{mode}_eob_metrics'])
+        num_eou_utterances = sum([x.num_utterances for x in eou_metrics])
+        eou_latency = flatten_nested_list([x.latency for x in eou_metrics])
+        eou_early_cutoff = flatten_nested_list([x.early_cutoff for x in eou_metrics])
+
+        num_eob_utterances = sum([x.num_utterances for x in eob_metrics])
+        eob_latency = flatten_nested_list([x.latency for x in eob_metrics])
+        eob_early_cutoff = flatten_nested_list([x.early_cutoff for x in eob_metrics])
+
+        eou_avg_num_early_cutoff = len(eou_early_cutoff) / num_eou_utterances if num_eou_utterances > 0 else 0.0
+        eob_avg_num_early_cutoff = len(eob_early_cutoff) / num_eob_utterances if num_eob_utterances > 0 else 0.0
+        if len(eou_latency) == 0:
+            eou_latency = [0.0]
+        if len(eou_early_cutoff) == 0:
+            eou_early_cutoff = [0.0]
+        if len(eob_latency) == 0:
+            eob_latency = [0.0]
+        if len(eob_early_cutoff) == 0:
+            eob_early_cutoff = [0.0]
+
+        eou_missing = [x.missing for x in 
eou_metrics] + eob_missing = [x.missing for x in eob_metrics] + + tensorboard_logs = {} + target_percentiles = [50, 90, 95] + eou_latency_metrics = self._get_percentiles(eou_latency, target_percentiles, tag=f'{mode}_eou_latency') + eou_early_cutoff_metrics = self._get_percentiles( + eou_early_cutoff, target_percentiles, tag=f'{mode}_eou_early_cutoff' + ) + eob_latency_metrics = self._get_percentiles(eob_latency, target_percentiles, tag=f'{mode}_eob_latency') + eob_early_cutoff_metrics = self._get_percentiles( + eob_early_cutoff, target_percentiles, tag=f'{mode}_eob_early_cutoff' + ) + + tensorboard_logs.update(eou_latency_metrics) + tensorboard_logs.update(eou_early_cutoff_metrics) + tensorboard_logs.update(eob_latency_metrics) + tensorboard_logs.update(eob_early_cutoff_metrics) + + tensorboard_logs[f'{mode}_eou_early_cutoff_avg_num'] = eou_avg_num_early_cutoff + tensorboard_logs[f'{mode}_eob_early_cutoff_avg_num'] = eob_avg_num_early_cutoff + + tensorboard_logs[f'{mode}_eou_missing'] = ( + sum(eou_missing) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + ) + tensorboard_logs[f'{mode}_eob_missing'] = ( + sum(eob_missing) / num_eob_utterances if num_eob_utterances > 0 else 0.0 + ) + + return tensorboard_logs + + @rank_zero_only + def _maybe_save_predictions( + self, outputs: List[Dict], mode: str = "val", dataloader_idx: int = 0 + ) -> Optional[Path]: + """ + Save predictions to disk. + Args: + outputs: list of outputs + mode: mode of the model, either 'val' or 'test' + Returns: + Path object if predictions are saved, None otherwise. + """ + + if not self.cfg.get('save_pred_to_file', None): + return None + + output_file = Path(self.cfg.save_pred_to_file) + output_file.parent.mkdir(parents=True, exist_ok=True) + + if getattr(self, '_validation_names', None): + output_file = output_file.with_name(f"{self._validation_names[dataloader_idx]}_{output_file.name}") + else: + output_file = output_file.with_suffix(f'.{dataloader_idx}.json') + + manifest = [] + for output in outputs: + for i in range(len(output[f'{mode}_sample_id'])): + item = { + "sample_id": output[f'{mode}_sample_id'][i], + "audio_filepath": output[f'{mode}_audio_filepath'][i], + "eou_text": output[f'{mode}_text_gt'][i], + "eou_pred_text": output[f'{mode}_text_pred'][i], + "is_backchannel": bool(str(output[f'{mode}_text_gt'][i]).endswith(EOB_STRING)), + } + if f"{mode}_text_pred_ctc" in output: + item["eou_pred_text_ctc"] = output[f"{mode}_text_pred_ctc"][i] + + eou_metrics = {f"eou_{k}": v for k, v in output[f"{mode}_eou_metrics"][i].to_dict().items()} + eob_metrics = {f"eob_{k}": v for k, v in output[f"{mode}_eob_metrics"][i].to_dict().items()} + item.update(eou_metrics) + item.update(eob_metrics) + manifest.append(item) + write_manifest(output_file, manifest) + logging.info(f"Predictions saved to {output_file}") + return output_file + + +class EncDecRNNTBPEEOUModel(EncDecRNNTBPEModel, ASREOUModelMixin): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + + self._patch_decoding_cfg(cfg) + super().__init__(cfg=cfg, trainer=trainer) + + self.eou_token = self.tokenizer.token_to_id(EOU_STRING) + self.eob_token = self.tokenizer.token_to_id(EOB_STRING) + self.frame_len_in_secs = self.cfg.preprocessor.window_stride * self.cfg.encoder.subsampling_factor + + self.setup_eou_mixin(self.eou_token, self.eob_token, self.frame_len_in_secs, self.tokenizer) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + log_prediction=self._cfg.get('log_prediction', True), + 
dist_sync_on_step=True, + return_hypotheses=True, + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + cfg = OmegaConf.create(config) if not isinstance(config, DictConfig) else config + dataset = LhotseSpeechToTextBpeEOUDataset( + cfg=cfg, tokenizer=self.tokenizer, return_cuts=config.get("do_transcribe", False) + ) + return get_lhotse_dataloader_from_config( + config, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), + dataset=dataset, + tokenizer=self.tokenizer, + ) + + def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): + if isinstance(batch, AudioToTextEOUBatch): + signal = batch.audio_signal + signal_len = batch.audio_lengths + else: + signal = batch[0] + signal_len = batch[1] + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + output = dict(encoded=encoded, encoded_len=encoded_len) + return output + + def training_step(self, batch: AudioToTextEOUBatch, batch_nb): + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + signal = batch.audio_signal + signal_len = batch.audio_lengths + transcript = batch.text_tokens + transcript_len = batch.text_token_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + # During training, loss must be computed, so decoder forward is necessary + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + + if hasattr(self, '_trainer') and self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + sample_id = self._trainer.global_step + else: + log_every_n_steps = 1 + sample_id = batch_nb + + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + # Compute full joint and loss + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if (sample_id + 1) % log_every_n_steps == 0: + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + _, scores, words = self.wer.compute() + self.wer.reset() + tensorboard_logs.update({'training_batch_wer': scores.float() / words}) + + else: + # If experimental fused Joint-Loss-WER is used + if (sample_id + 1) % log_every_n_steps == 0: + compute_wer = True + else: + compute_wer = False + + # Fused joint step + loss_value, wer, _, _ = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoder, + 
encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=transcript_len, + compute_wer=compute_wer, + ) + + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value) + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + tensorboard_logs = { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + + if compute_wer: + tensorboard_logs.update({'training_batch_wer': wer}) + + # Log items + self.log_dict(tensorboard_logs) + + # Preserve batch acoustic model T and language model U parameters if normalizing + if self._optim_normalize_joint_txu: + self._optim_normalize_txu = [encoded_len.max(), transcript_len.max()] + + return {'loss': loss_value} + + def predict_step(self, batch: AudioToTextEOUBatch, batch_idx, dataloader_idx=0): + signal = batch.audio_signal + signal_len = batch.audio_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + best_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False + ) + + return list(best_hyp_text) + + def validation_pass(self, batch: AudioToTextEOUBatch, batch_idx: int, dataloader_idx: int = 0): + signal = batch.audio_signal + signal_len = batch.audio_lengths + transcript = batch.text_tokens + transcript_len = batch.text_token_lengths + + # forward() only performs encoder forward + encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len) + del signal + + tensorboard_logs = {} + + if self.cfg.get('save_pred_to_file', None): + text_gt = self._get_text_from_tokens(transcript, transcript_len) + tensorboard_logs['val_sample_id'] = batch.sample_ids + tensorboard_logs['val_audio_filepath'] = batch.audio_filepaths + tensorboard_logs['val_text_gt'] = text_gt + # If experimental fused Joint-Loss-WER is not used + if not self.joint.fuse_loss_wer: + if self.compute_eval_loss: + decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len) + joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder) + + loss_value = self.loss( + log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length + ) + + tensorboard_logs['val_loss'] = loss_value + + self.wer.update( + predictions=encoded, + predictions_lengths=encoded_len, + targets=transcript, + targets_lengths=transcript_len, + ) + hypotheses = self.wer.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + if self.cfg.get('calculate_eou_metrics', True): + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + else: + eou_metrics_list = [] + eob_metrics_list = [] + + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + else: + # If experimental fused Joint-Loss-WER is used + compute_wer = True + + if 
self.compute_eval_loss: + decoded, target_len, states = self.decoder(targets=transcript, target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + # Fused joint step + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + keep_hypotheses=True, + ) + + hypotheses = self.joint.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + if self.cfg.get('calculate_eou_metrics', True): + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + else: + eou_metrics_list = [] + eob_metrics_list = [] + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return tensorboard_logs + + def multi_inference_epoch_end(self, outputs, dataloader_idx: int = 0, mode: str = "val"): + assert mode in ['val', 'test'], f"Invalid mode: {mode}. Must be 'val' or 'test'." + + if not outputs: + logging.warning( + f"No outputs received for {mode} dataloader {dataloader_idx}. Skipping epoch end processing." + ) + return {} + + self._maybe_save_predictions(outputs, mode=mode, dataloader_idx=dataloader_idx) + + # Aggregate WER metrics + if self.compute_eval_loss: + loss_mean = torch.stack([x[f'{mode}_loss'] for x in outputs]).mean() + loss_log = {f'{mode}_loss': loss_mean} + else: + loss_log = {} + wer_num = torch.stack([x[f'{mode}_wer_num'] for x in outputs]).sum() + wer_denom = torch.stack([x[f'{mode}_wer_denom'] for x in outputs]).sum() + tensorboard_logs = {**loss_log, f'{mode}_wer': wer_num.float() / wer_denom} + + eou_metrics = {} + if self.cfg.get('calculate_eou_metrics', True): + eou_metrics = self._aggregate_eou_metrics(outputs, mode=mode) + tensorboard_logs.update(eou_metrics) + + return {**loss_log, 'log': tensorboard_logs} + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='val') + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='test') + + @property + def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. 
+ """ + return { + "cls": AudioToTextEOUBatch, + "inputs": [ + {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input", "name": "audio_signal"}, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "input", "name": "audio_lengths"}, + { + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "name": "text_tokens", + "vocab_size": self.tokenizer.vocab_size, + }, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "output", "name": "text_token_lengths"}, + { + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "name": "eou_targets", + "vocab_size": 4, + }, + {"type": NeuralType(("B",), LengthsType()), "seq_length": "output", "name": "eou_target_lengths"}, + ], + } + + +class EncDecHybridRNNTCTCBPEEOUModel(EncDecHybridRNNTCTCBPEModel, ASREOUModelMixin): + def __init__(self, cfg: DictConfig, trainer): + self._patch_decoding_cfg(cfg) + if cfg.aux_ctc.get('decoding', None) is not None: + with open_dict(cfg): + cfg.aux_ctc.decoding.preserve_alignments = True + cfg.aux_ctc.decoding.compute_timestamps = True + + super().__init__(cfg=cfg, trainer=trainer) + + self.eou_token = self.tokenizer.token_to_id(EOU_STRING) + self.eob_token = self.tokenizer.token_to_id(EOB_STRING) + self.frame_len_in_secs = self.cfg.preprocessor.window_stride * self.cfg.encoder.subsampling_factor + self.setup_eou_mixin(self.eou_token, self.eob_token, self.frame_len_in_secs, self.tokenizer) + + self.wer = WER( + decoding=self.decoding, + batch_dim_index=0, + use_cer=self._cfg.get('use_cer', False), + log_prediction=self._cfg.get('log_prediction', True), + dist_sync_on_step=True, + return_hypotheses=True, + ) + + self.ctc_wer = WER( + decoding=self.ctc_decoding, + use_cer=self.cfg.aux_ctc.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self.cfg.get("log_prediction", False), + return_hypotheses=True, + ) + + # Setup fused Joint step if flag is set + if self.joint.fuse_loss_wer: + self.joint.set_loss(self.loss) + self.joint.set_wer(self.wer) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + cfg = OmegaConf.create(config) if not isinstance(config, DictConfig) else config + dataset = LhotseSpeechToTextBpeEOUDataset( + cfg=cfg, tokenizer=self.tokenizer, return_cuts=config.get("do_transcribe", False) + ) + return get_lhotse_dataloader_from_config( + config, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. 
+            global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+            world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
+            dataset=dataset,
+            tokenizer=self.tokenizer,
+        )
+
+    def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
+        if isinstance(batch, AudioToTextEOUBatch):
+            signal = batch.audio_signal
+            signal_len = batch.audio_lengths
+        else:
+            signal = batch[0]
+            signal_len = batch[1]
+        encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
+        output = dict(encoded=encoded, encoded_len=encoded_len)
+        return output
+
+    def training_step(self, batch: AudioToTextEOUBatch, batch_nb):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+
+        new_batch = (signal, signal_len, transcript, transcript_len)
+        return super().training_step(new_batch, batch_nb)
+
+    def predict_step(self, batch: AudioToTextEOUBatch, batch_idx, dataloader_idx=0):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+        sample_ids = batch.sample_ids
+        new_batch = (signal, signal_len, transcript, transcript_len, sample_ids)
+        return super().predict_step(new_batch, batch_idx, dataloader_idx)
+
+    def validation_pass(self, batch: AudioToTextEOUBatch, batch_idx: int, dataloader_idx: int = 0):
+        signal = batch.audio_signal
+        signal_len = batch.audio_lengths
+        transcript = batch.text_tokens
+        transcript_len = batch.text_token_lengths
+
+        encoded, encoded_len = self.forward(input_signal=signal, input_signal_length=signal_len)
+        del signal
+
+        tensorboard_logs = {}
+
+        if self.cfg.get('save_pred_to_file', None):
+            text_gt = self._get_text_from_tokens(transcript, transcript_len)
+            tensorboard_logs['val_sample_id'] = batch.sample_ids
+            tensorboard_logs['val_audio_filepath'] = batch.audio_filepaths
+            tensorboard_logs['val_text_gt'] = text_gt
+
+        loss_value = None
+
+        # If experimental fused Joint-Loss-WER is not used
+        if not self.joint.fuse_loss_wer:
+            if self.compute_eval_loss:
+                decoder, target_length, states = self.decoder(targets=transcript, target_length=transcript_len)
+                joint = self.joint(encoder_outputs=encoded, decoder_outputs=decoder)
+
+                loss_value = self.loss(
+                    log_probs=joint, targets=transcript, input_lengths=encoded_len, target_lengths=target_length
+                )
+                tensorboard_logs['val_loss'] = loss_value
+
+            self.wer.update(
+                predictions=encoded,
+                predictions_lengths=encoded_len,
+                targets=transcript,
+                targets_lengths=transcript_len,
+            )
+
+            hypotheses = self.wer.get_hypotheses()
+
+            if self.cfg.get('save_pred_to_file', None):
+                text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses])
+                tensorboard_logs['val_text_pred'] = text_pred
+
+            eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch)
+            eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch)
+
+            wer, wer_num, wer_denom = self.wer.compute()
+            self.wer.reset()
+
+            tensorboard_logs['val_wer_num'] = wer_num
+            tensorboard_logs['val_wer_denom'] = wer_denom
+            tensorboard_logs['val_wer'] = wer
+            tensorboard_logs['val_eou_metrics'] = eou_metrics_list
+            tensorboard_logs['val_eob_metrics'] = eob_metrics_list
+
+        else:
+            # If experimental fused Joint-Loss-WER is used
+            compute_wer = True
+
+            if self.compute_eval_loss:
+                decoded, target_len, states = 
self.decoder(targets=transcript, target_length=transcript_len) + else: + decoded = None + target_len = transcript_len + + # Fused joint step + loss_value, wer, wer_num, wer_denom = self.joint( + encoder_outputs=encoded, + decoder_outputs=decoded, + encoder_lengths=encoded_len, + transcripts=transcript, + transcript_lengths=target_len, + compute_wer=compute_wer, + keep_hypotheses=True, + ) + hypotheses = self.joint.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred = self._get_text_from_tokens([x.y_sequence for x in hypotheses]) + tensorboard_logs['val_text_pred'] = text_pred + + eou_predictions = self._get_eou_predictions_from_hypotheses(hypotheses, batch) + + eou_metrics_list, eob_metrics_list = self._calculate_eou_metrics(eou_predictions, batch) + + if loss_value is not None: + tensorboard_logs['val_loss'] = loss_value + + tensorboard_logs['val_wer_num'] = wer_num + tensorboard_logs['val_wer_denom'] = wer_denom + tensorboard_logs['val_wer'] = wer + tensorboard_logs['val_eou_metrics'] = eou_metrics_list + tensorboard_logs['val_eob_metrics'] = eob_metrics_list + + log_probs = self.ctc_decoder(encoder_output=encoded) + if self.compute_eval_loss: + ctc_loss = self.ctc_loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + tensorboard_logs['val_ctc_loss'] = ctc_loss + tensorboard_logs['val_rnnt_loss'] = loss_value + loss_value = (1 - self.ctc_loss_weight) * loss_value + self.ctc_loss_weight * ctc_loss + tensorboard_logs['val_loss'] = loss_value + + self.ctc_wer.update( + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, + ) + hypotheses_ctc = self.ctc_wer.get_hypotheses() + + if self.cfg.get('save_pred_to_file', None): + text_pred_ctc = self._get_text_from_tokens([x.y_sequence for x in hypotheses_ctc]) + tensorboard_logs['val_text_pred_ctc'] = text_pred_ctc + + eou_predictions_ctc = self._get_eou_predictions_from_hypotheses(hypotheses_ctc, batch) + eou_metrics_list_ctc, eob_metrics_list_ctc = self._calculate_eou_metrics(eou_predictions_ctc, batch) + + ctc_wer, ctc_wer_num, ctc_wer_denom = self.ctc_wer.compute() + self.ctc_wer.reset() + + tensorboard_logs['val_wer_num_ctc'] = ctc_wer_num + tensorboard_logs['val_wer_denom_ctc'] = ctc_wer_denom + tensorboard_logs['val_wer_ctc'] = ctc_wer + tensorboard_logs['val_eou_metrics_ctc'] = eou_metrics_list_ctc + tensorboard_logs['val_eob_metrics_ctc'] = eob_metrics_list_ctc + + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + loss_value, additional_logs = self.add_interctc_losses( + loss_value, + transcript, + transcript_len, + compute_wer=True, + compute_loss=self.compute_eval_loss, + log_wer_num_denom=True, + log_prefix="val_", + ) + if self.compute_eval_loss: + # overriding total loss value. Note that the previous + # rnnt + ctc loss is available in metrics as "val_final_loss" now + tensorboard_logs['val_loss'] = loss_value + tensorboard_logs.update(additional_logs) + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + return tensorboard_logs + + def multi_inference_epoch_end(self, outputs, dataloader_idx: int = 0, mode: str = "val"): + assert mode in ['val', 'test'], f"Invalid mode: {mode}. Must be 'val' or 'test'." 
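+
+        # Guard against an empty epoch, mirroring EncDecRNNTBPEEOUModel.multi_inference_epoch_end;
+        # without it the torch.stack calls below fail on an empty list.
+        if not outputs:
+            logging.warning(
+                f"No outputs received for {mode} dataloader {dataloader_idx}. Skipping epoch end processing."
+            )
+            return {}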
+        self._maybe_save_predictions(outputs, mode=mode, dataloader_idx=dataloader_idx)
+
+        # Aggregate WER metrics
+        if self.compute_eval_loss:
+            loss_mean = torch.stack([x[f'{mode}_loss'] for x in outputs]).mean()
+            loss_log = {f'{mode}_loss': loss_mean}
+        else:
+            loss_log = {}
+        wer_num = torch.stack([x[f'{mode}_wer_num'] for x in outputs]).sum()
+        wer_denom = torch.stack([x[f'{mode}_wer_denom'] for x in outputs]).sum()
+        tensorboard_logs = {**loss_log, f'{mode}_wer': wer_num.float() / wer_denom}
+
+        if self.ctc_loss_weight > 0:
+            ctc_wer_num = torch.stack([x[f'{mode}_wer_num_ctc'] for x in outputs]).sum()
+            ctc_wer_denom = torch.stack([x[f'{mode}_wer_denom_ctc'] for x in outputs]).sum()
+            tensorboard_logs[f'{mode}_wer_ctc'] = ctc_wer_num.float() / ctc_wer_denom
+
+        eou_metrics = self._aggregate_eou_metrics(outputs, mode)
+        tensorboard_logs.update(eou_metrics)
+
+        eou_metrics_ctc = self._aggregate_eou_metrics(outputs, mode, is_ctc=True)
+        for key, value in eou_metrics_ctc.items():
+            tensorboard_logs[f'{key}_ctc'] = value
+
+        return {**loss_log, 'log': tensorboard_logs}
+
+    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
+        return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='val')
+
+    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
+        return self.multi_inference_epoch_end(outputs, dataloader_idx, mode='test')
diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py
index 1d80636fa1be..7259d077809e 100644
--- a/nemo/collections/asr/modules/__init__.py
+++ b/nemo/collections/asr/modules/__init__.py
@@ -20,7 +20,11 @@
     SpectrogramAugmentation,
 )
 from nemo.collections.asr.modules.beam_search_decoder import BeamSearchDecoderWithLM  # noqa: F401
-from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder, ConformerEncoderAdapter  # noqa: F401
+from nemo.collections.asr.modules.conformer_encoder import (  # noqa: F401
+    ConformerEncoder,
+    ConformerEncoderAdapter,
+    ConformerMultiLayerFeatureExtractor,
+)
 from nemo.collections.asr.modules.conv_asr import (  # noqa: F401
     ConvASRDecoder,
     ConvASRDecoderClassification,
@@ -41,11 +45,42 @@
     SampledRNNTJoint,
     StatelessTransducerDecoder,
 )
-from nemo.collections.asr.modules.ssl_modules import (  # noqa: F401
-    ConformerMultiLayerFeatureExtractor,
+from nemo.collections.asr.modules.ssl_modules import (
     ConformerMultiLayerFeaturePreprocessor,
     ConvFeatureMaksingWrapper,
     MultiSoftmaxDecoder,
     RandomBlockMasking,
     RandomProjectionVectorQuantizer,
 )
+
+__all__ = [
+    'AudioToMelSpectrogramPreprocessor',
+    'AudioToMFCCPreprocessor',
+    'CropOrPadSpectrogramAugmentation',
+    'MaskedPatchAugmentation',
+    'SpectrogramAugmentation',
+    'BeamSearchDecoderWithLM',
+    'ConformerEncoder',
+    'ConformerEncoderAdapter',
+    'ConformerMultiLayerFeatureExtractor',
+    'ConvASRDecoder',
+    'ConvASRDecoderClassification',
+    'ConvASRDecoderReconstruction',
+    'ConvASREncoder',
+    'ConvASREncoderAdapter',
+    'ECAPAEncoder',
+    'ParallelConvASREncoder',
+    'SpeakerDecoder',
+    'HATJoint',
+    'LSTMDecoder',
+    'RNNTDecoder',
+    'RNNTDecoderJointSSL',
+    'RNNTJoint',
+    'SampledRNNTJoint',
+    'StatelessTransducerDecoder',
+    'ConformerMultiLayerFeaturePreprocessor',
+    'ConvFeatureMaksingWrapper',
+    'MultiSoftmaxDecoder',
+    'RandomBlockMasking',
+    'RandomProjectionVectorQuantizer',
+]
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index 0e25ea7767f4..aa63f588cd63 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ 
b/nemo/collections/asr/modules/conformer_encoder.py @@ -56,7 +56,7 @@ ) from nemo.utils import logging -__all__ = ['ConformerEncoder'] +__all__ = ['ConformerEncoder', 'ConformerMultiLayerFeatureExtractor'] class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): @@ -1273,17 +1273,34 @@ class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable, AccessMixin) def __init__( self, encoder: ConformerEncoder, - layer_idx_list: List[int], - aggregator: NeuralModule = None, + layer_idx_list: Optional[List[int]] = None, + aggregator: Optional[NeuralModule] = None, detach: bool = False, convert_to_cpu: bool = False, ): + """ + This class is used to extract features from different layers of the ConformerEncoder. + Args: + encoder: ConformerEncoder instance. + layer_idx_list: List of layer indices to extract features from. If None, all layers are extracted. + aggregator: Aggregator instance. If None, the features are returned as a list. + detach: If True, the features are detached from the graph. + convert_to_cpu: If True, the features are converted to CPU. + """ super().__init__() self.encoder = encoder - self.layer_idx_list = [int(lyr_idx) for lyr_idx in layer_idx_list] - for x in self.layer_idx_list: - if x < 0 or x >= len(encoder.layers): - raise ValueError(f"layer index {x} out of range [0, {len(encoder.layers)})") + self.num_layers = len(encoder.layers) + self.layer_idx_list = [] + if not layer_idx_list: + layer_idx_list = list(range(self.num_layers)) + for lid in layer_idx_list: + if lid < -self.num_layers or lid >= self.num_layers: + raise ValueError(f"Invalid layer index {lid} for ConformerEncoder with {self.num_layers} layers.") + if lid < 0: + lid = self.num_layers + lid + self.layer_idx_list.append(lid) + self.layer_idx_list.sort() + logging.info(f"Extracting ConformerEncoder features from layers: {self.layer_idx_list}") self.enc_access_cfg = { "interctc": { "capture_layers": self.layer_idx_list, @@ -1296,7 +1313,13 @@ def __init__( def forward( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ) -> Tuple[torch.Tensor, torch.Tensor]: - # pylint: disable=missing-function-docstring + """ + Args: + same interface as ConformerEncoder.forward() + Returns: + - Tuple[List[Tensor[B,D,T]], List[Tensor[B]]] if aggregator is None + - Tuple[Tensor[B,H,T], Tensor[B]] if aggregator is not None, where H is the hidden size of the aggregator + """ old_access_flag = self.is_access_enabled(guid=getattr(self, "model_guid", None)) self.update_access_cfg(self.enc_access_cfg, guid=getattr(self, "model_guid", None)) self.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) @@ -1338,7 +1361,7 @@ def forward( # End of the adapted chunk if self.aggregator is not None: - return self.aggregator(encoded_list, encoded_len_list) # Tensor[B,D*L,T], Tensor[B] + return self.aggregator(encoded_list, encoded_len_list) # Tensor[B,H,T], Tensor[B] else: return encoded_list, encoded_len_list # List[Tensor[B,D,T]], List[Tensor[B]] diff --git a/nemo/collections/asr/modules/lstm_decoder.py b/nemo/collections/asr/modules/lstm_decoder.py index 9bb60e2fabca..03f6cf6aa875 100644 --- a/nemo/collections/asr/modules/lstm_decoder.py +++ b/nemo/collections/asr/modules/lstm_decoder.py @@ -35,6 +35,7 @@ class LSTMDecoder(NeuralModule, Exportable): vocabulary (vocab): The vocabulary bidirectional (bool): default is False. Whether LSTMs are bidirectional or not num_layers (int): default is 1. 
Number of LSTM layers stacked + add_blank (bool): default is True. Whether to add a blank token to the vocabulary. """ @property @@ -45,7 +46,16 @@ def input_types(self): def output_types(self): return OrderedDict({"logprobs": NeuralType(('B', 'T', 'D'), LogprobsType())}) - def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidirectional=False, num_layers=1): + def __init__( + self, + feat_in, + num_classes, + lstm_hidden_size, + vocabulary=None, + bidirectional=False, + num_layers=1, + add_blank=True, + ): super().__init__() if vocabulary is not None: @@ -57,7 +67,7 @@ def __init__(self, feat_in, num_classes, lstm_hidden_size, vocabulary=None, bidi self.__vocabulary = vocabulary self._feat_in = feat_in # Add 1 for blank char - self._num_classes = num_classes + 1 + self._num_classes = num_classes + 1 if add_blank else num_classes self.lstm_layer = nn.LSTM( input_size=feat_in, diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 9f5c234977e7..feaf0edfca92 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1332,6 +1332,8 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin) - compute_wer (bool, default false). Whether to compute WER or not for the fused batch. + - keep_hypotheses (bool, default false). Whether to keep the hypotheses of the decoded outputs. + Output - instead of the usual `joint` log prob tensor, the following results can be returned. - loss (optional). Returned if decoder_outputs, transcripts and transript_lengths are not None. @@ -1357,6 +1359,7 @@ def input_types(self): "transcripts": NeuralType(('B', 'T'), LabelsType(), optional=True), "transcript_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "compute_wer": NeuralType(optional=True), + "keep_hypotheses": NeuralType(optional=True), } @property @@ -1469,6 +1472,8 @@ def __init__( # to change, requires running ``model.temperature = T`` explicitly self.temperature = 1.0 + self.hypotheses = None + @typecheck() def forward( self, @@ -1478,6 +1483,7 @@ def forward( transcripts: Optional[torch.Tensor] = None, transcript_lengths: Optional[torch.Tensor] = None, compute_wer: bool = False, + keep_hypotheses: bool = False, ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]: # encoder = (B, D, T) # decoder = (B, D, U) if passed, else None @@ -1515,6 +1521,7 @@ def forward( wers, wer_nums, wer_denoms = [], [], [] target_lengths = [] batch_size = int(encoder_outputs.size(0)) # actual batch size + hypotheses = [] # Iterate over batch using fused_batch_size steps for batch_idx in range(0, batch_size, self._fused_batch_size): @@ -1599,6 +1606,9 @@ def forward( targets=sub_transcripts, targets_lengths=sub_transcript_lens, ) + + hyp = self.wer.get_hypotheses() if keep_hypotheses else [] + # Sync and all_reduce on all processes, compute global WER wer, wer_num, wer_denom = self.wer.compute() self.wer.reset() @@ -1609,6 +1619,7 @@ def forward( wers.append(wer) wer_nums.append(wer_num) wer_denoms.append(wer_denom) + hypotheses.extend(hyp) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens @@ -1626,8 +1637,19 @@ def forward( wer_num = None wer_denom = None + self.hypotheses = hypotheses if keep_hypotheses else None return losses, wer, wer_num, wer_denom + def get_hypotheses(self): + """ + Returns the hypotheses generated during the last forward pass. + """ + if self.hypotheses is None: + raise ValueError( + "No hypotheses were generated during the last forward pass. 
Did you set keep_hypotheses=True in forward()?" + ) + return self.hypotheses + def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor: """ Project the encoder output to the joint hidden dimension. diff --git a/nemo/collections/asr/modules/ssl_modules/__init__.py b/nemo/collections/asr/modules/ssl_modules/__init__.py index dcfefd54fa73..f33127bd7d85 100644 --- a/nemo/collections/asr/modules/ssl_modules/__init__.py +++ b/nemo/collections/asr/modules/ssl_modules/__init__.py @@ -17,9 +17,16 @@ SpeakerNoiseAugmentation, ) from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper, RandomBlockMasking -from nemo.collections.asr.modules.ssl_modules.multi_layer_feat import ( - ConformerMultiLayerFeatureExtractor, - ConformerMultiLayerFeaturePreprocessor, -) +from nemo.collections.asr.modules.ssl_modules.multi_layer_feat import ConformerMultiLayerFeaturePreprocessor from nemo.collections.asr.modules.ssl_modules.multi_softmax_decoder import MultiSoftmaxDecoder from nemo.collections.asr.modules.ssl_modules.quantizers import RandomProjectionVectorQuantizer + +__all__ = [ + 'MultiSpeakerNoiseAugmentation', + 'SpeakerNoiseAugmentation', + 'ConvFeatureMaksingWrapper', + 'RandomBlockMasking', + 'ConformerMultiLayerFeaturePreprocessor', + 'MultiSoftmaxDecoder', + 'RandomProjectionVectorQuantizer', +] diff --git a/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py index 490d68c52f04..73ca41438437 100644 --- a/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py +++ b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed import torch.nn as nn -from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor, ConformerEncoder +from nemo.collections.asr.modules import ( + AudioToMelSpectrogramPreprocessor, + ConformerEncoder, + ConformerMultiLayerFeatureExtractor, +) from nemo.core.classes import Exportable, NeuralModule from nemo.core.classes.mixins import AccessMixin -from nemo.utils import logging class Aggregator(nn.Module): @@ -81,85 +84,12 @@ def forward( raise ValueError(f"Unknown mode {self.mode}") -class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable): - def __init__(self, encoder, aggregator: Optional[Callable] = None, layer_idx_list: Optional[List[int]] = None): - """ - Args: - encoder: ConformerEncoder instance. - aggregator: Aggregator instance. - layer_idx_list: List of layer indices to extract features from. 
- """ - super().__init__() - self.encoder = encoder - self.aggregator = aggregator - self.layer_idx_list = ( - [int(l) for l in layer_idx_list] - if layer_idx_list is not None - else [i for i in range(len(self.encoder.layers))] - ) - for x in self.layer_idx_list: - if x < 0 or x >= len(self.encoder.layers): - raise ValueError(f"layer index {x} out of range [0, {len(self.encoder.layers)})") - logging.info(f"Extracting features from layers {self.layer_idx_list}") - self.access_cfg = { - "interctc": { - "capture_layers": self.layer_idx_list, - }, - "detach": False, - "convert_to_cpu": False, - } - self._is_access_enabled = False - - def forward( - self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - same interface as ConformerEncoder.forward() - Returns: - tuple of aggregated features of shape [B, D, T] and lengths of shape [B] - """ - self.encoder.update_access_cfg(self.access_cfg, guid=getattr(self, "model_guid", None)) - self.encoder.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) - - _ = self.encoder( - audio_signal=audio_signal, - length=length, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - ) - - total_registry = {} - for module_registry in self.encoder.get_module_registry(self.encoder).values(): - for key in module_registry: - if key.startswith("interctc/") and key in total_registry: - raise RuntimeError(f"layer {key} has been logged multiple times!") - total_registry.update(module_registry) - - encoded_list = [] - encoded_len_list = [] - for layer_idx in self.layer_idx_list: - try: - layer_outputs = total_registry[f"interctc/layer_output_{layer_idx}"] - layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] - except KeyError: - raise RuntimeError( - f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of " - "ConformerEncoder layers." - ) - if len(layer_outputs) > 1 or len(layer_lengths) > 1: - raise RuntimeError("Make sure encoder.forward is called exactly one time") - encoded_list.append(layer_outputs[0]) # [B, D, T] - encoded_len_list.append(layer_lengths[0]) # [B] - - self.encoder.reset_registry() - if self.aggregator is None: - return encoded_list, encoded_len_list - return self.aggregator(encoded_list, encoded_len_list) - - class ConformerMultiLayerFeaturePreprocessor(NeuralModule, Exportable, AccessMixin): + """ + This class is used to replace the AudioToMelSpectrogramPreprocessor such that + the input to the actual model encoder is the multi-layer features from a pre-trained ConformerEncoder. + """ + def __init__( self, aggregator: nn.Module, diff --git a/nemo/collections/asr/parts/utils/eou_utils.py b/nemo/collections/asr/parts/utils/eou_utils.py new file mode 100644 index 000000000000..478fc44df3a9 --- /dev/null +++ b/nemo/collections/asr/parts/utils/eou_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List
+
+import numpy as np
+
+
+@dataclass
+class EOUResult:
+    """
+    A dataclass to store the EOU results.
+    Args:
+        latency: List of latencies in seconds.
+        early_cutoff: List of early cutoffs in seconds.
+        true_positives: Number of true positives.
+        false_negatives: Number of false negatives.
+        false_positives: Number of false positives.
+        num_utterances: Number of utterances.
+        num_predictions: Number of predictions.
+        missing: Number of missing predictions.
+    """
+
+    latency: list
+    early_cutoff: list
+    true_positives: int
+    false_negatives: int
+    false_positives: int
+    num_utterances: int
+    num_predictions: int
+    missing: int
+
+    def to_dict(self) -> Dict[str, float]:
+        """
+        Convert the EOUResult dataclass to a dictionary.
+        Returns:
+            Dict: A dictionary representation of the EOUResult.
+        """
+        return {
+            'latency': self.latency,
+            'early_cutoff': self.early_cutoff,
+            'true_positives': self.true_positives,
+            'false_negatives': self.false_negatives,
+            'false_positives': self.false_positives,
+            'num_utterances': self.num_utterances,
+            'num_predictions': self.num_predictions,
+            'missing': self.missing,
+        }
+
+
+def flatten_nested_list(nested_list: List[List[float]]) -> List[float]:
+    """
+    Flatten a nested list into a single list.
+    Args:
+        nested_list (List[List]): A nested list to be flattened.
+    Returns:
+        List: A flattened list.
+    """
+    return [item for sublist in nested_list for item in sublist]
+
+
+def evaluate_eou(
+    *, prediction: List[dict], reference: List[dict], threshold: float, collar: float, do_sorting: bool = True
+) -> EOUResult:
+    """
+    Evaluate end of utterance predictions against reference labels.
+    Each item in prediction/reference is a dictionary in SegLST format containing:
+    {
+        "session_id": str,
+        "start_time": float,  # start time in seconds
+        "end_time": float,  # end time in seconds
+        "words": str,  # transcription of the utterance
+        "audio_filepath": str,  # only in prediction
+        "eou_prob": float,  # only in prediction, probability of EOU in range [0, 1]
+        "eou_pred": bool,  # only in prediction
+        "full_text": str,  # only in prediction, which is the full transcription up to the end_time
+    }
+
+    Args:
+        prediction (List[dict]): List of dictionaries containing predictions.
+        reference (List[dict]): List of dictionaries containing reference labels.
+        threshold (float): Threshold for considering a prediction as EOU.
+        collar (float): Collar time in seconds for matching predictions to references.
+        do_sorting (bool): Whether to sort the predictions and references by start time.
+    Returns:
+        EOUResult: A dataclass containing the evaluation results.
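+
+    Example (illustrative numbers): with collar=0.25, a reference utterance ending at t=5.0s
+    and a predicted EOU at t=5.2s count as a true positive with latency 0.2s, while a
+    predicted EOU at t=4.0s inside that utterance is an early cutoff of 1.0s and counts
+    as a false positive.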
+    """
+
+    latency = []
+    early_cutoff = []
+    true_positives = 0
+    false_negatives = 0
+    false_positives = 0
+    num_utterances = len(reference)
+    num_predictions = len(prediction)
+    missing = 0
+    earlycut_ids = set()
+    predicted_eou = prediction
+    if threshold is not None and threshold > 0:
+        predicted_eou = [p for p in prediction if p["eou_prob"] > threshold]
+    elif all(["eou_pred" in p for p in prediction]):
+        # If eou_pred is available, use it (membership test, since the items are dicts)
+        predicted_eou = [p for p in prediction if p["eou_pred"]]
+
+    if do_sorting:
+        predicted_eou = sorted(predicted_eou, key=lambda x: x["start_time"])
+        reference = sorted(reference, key=lambda x: x["start_time"])
+
+    r_idx = 0
+    for p_idx in range(len(predicted_eou)):
+        p = predicted_eou[p_idx]
+        p_start = p["start_time"]
+        p_end = p["end_time"]
+
+        while r_idx < len(reference) and reference[r_idx]["end_time"] < p_start:
+            # Current reference ends before the current predicted utterance starts, find the next reference
+            r_idx += 1
+
+        if r_idx >= len(reference):
+            # No more references to compare against
+            false_positives += 1
+            continue
+
+        r = reference[r_idx]
+        r_start = r["start_time"]
+        r_end = r["end_time"]
+
+        if np.abs(p_end - r_end) <= collar:
+            # Correctly predicted EOU
+            true_positives += 1
+            latency.append(p_end - r_end)
+            if r_idx in earlycut_ids:
+                # If this reference was already missed due to early cutoff, we do not count it again
+                earlycut_ids.remove(r_idx)
+            r_idx += 1
+        elif r_start <= p_end < r_end - collar:
+            # Early cutoff
+            # current predicted EOU is within the current reference utterance
+            false_positives += 1
+            early_cutoff.append(r_end - p_end)
+            earlycut_ids.add(r_idx)
+        elif r_end + collar < p_end:
+            # Late EOU
+            # Current predicted EOU is after the current reference ends
+            false_negatives += 1
+            latency.append(p_end - r_end)
+            if r_idx in earlycut_ids:
+                # If this reference was already missed due to early cutoff, we do not count it again
+                earlycut_ids.remove(r_idx)
+            r_idx += 1
+        else:
+            # p_end <= r_start
+            # Current predicted EOU is before the current reference starts
+            false_positives += 1
+
+    if r_idx < len(reference):
+        # There are remaining references that were not matched
+        false_negatives += len(reference) - r_idx
+        missing += len(reference) - r_idx
+
+    missing -= len(earlycut_ids)  # Remove the references that were missed due to early cutoff
+    false_negatives -= len(earlycut_ids)  # Remove the references that were missed due to early cutoff
+    return EOUResult(
+        latency=latency,
+        early_cutoff=early_cutoff,
+        true_positives=true_positives,
+        false_negatives=false_negatives,
+        false_positives=false_positives,
+        num_utterances=num_utterances,
+        num_predictions=num_predictions,
+        missing=missing,
+    )
+
+
+def get_SegLST_from_frame_labels(frame_labels: List[int], frame_len_in_secs: float = 0.08) -> List[dict]:
+    """
+    Convert frame labels to SegLST format.
+    Args:
+        frame_labels (List[int]): List of frame labels.
+        frame_len_in_secs (float): Length of each frame in seconds.
+    Returns:
+        List[dict]: List of dictionaries in SegLST format.
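+
+    Example (assuming frame_len_in_secs=0.08): frame_labels=[0, 0, 1, 0, 1] yields
+    [{"start_time": 0.0, "end_time": 0.16, "eou_prob": 1},
+     {"start_time": 0.16, "end_time": 0.32, "eou_prob": 1}].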
+    """
+    seg_lst = []
+    start_time = 0.0
+    for i, label in enumerate(frame_labels):
+        if label > 0:
+            # Use the absolute time of frame i, independent of previous segments
+            end_time = frame_len_in_secs * i
+            seg_lst.append({"start_time": start_time, "end_time": end_time, "eou_prob": label})
+            start_time = end_time
+    return seg_lst
+
+
+def cal_eou_metrics_from_frame_labels(
+    *,
+    prediction: List[float],
+    reference: List[float],
+    threshold: float = 0.5,
+    collar: float = 0,
+    frame_len_in_secs: float = 0.08,
+) -> EOUResult:
+    """
+    Calculate EOU metrics from lists of predictions and references.
+    Args:
+        prediction (List): List of floats containing predicted EOU probabilities.
+        reference (List): List of binary floats containing reference EOU probabilities.
+        threshold (float): Threshold for considering a prediction as EOU.
+        collar (float): Collar time in seconds for matching predictions to references.
+        frame_len_in_secs (float): Length of each frame in seconds.
+    """
+
+    if len(prediction) != len(reference):
+        raise ValueError(
+            f"Prediction ({len(prediction)}) and reference ({len(reference)}) lists must have the same length."
+        )
+
+    pred_seg_lst = get_SegLST_from_frame_labels(prediction, frame_len_in_secs)
+    ref_seg_lst = get_SegLST_from_frame_labels(reference, frame_len_in_secs)
+    eou_metrics = evaluate_eou(
+        prediction=pred_seg_lst, reference=ref_seg_lst, threshold=threshold, collar=collar, do_sorting=False
+    )
+    return eou_metrics
+
+
+def get_percentiles(values: List[float], percentiles: List[float], tag: str = "") -> Dict[str, float]:
+    """
+    Get the percentiles of a list of values.
+    Args:
+        values: list of values
+        percentiles: list of percentiles
+    Returns:
+        metrics: Dict of percentiles
+    """
+    if tag:
+        tag += "_"
+    if len(values) == 0:
+        # Return zeros as a dict so the return type matches the annotation and callers can .update() it
+        return {f'{tag}p{int(p)}': 0.0 for p in percentiles}
+    results = np.percentile(values, percentiles).tolist()
+    metrics = {}
+    for i, p in enumerate(percentiles):
+        metrics[f'{tag}p{int(p)}'] = float(results[i])
+    return metrics
+
+
+def aggregate_eou_metrics(eou_metrics: List[EOUResult], target_percentiles: List = [50, 90, 95]) -> Dict[str, float]:
+    """
+    Aggregate EOU metrics to produce metrics for logging.
+    Args:
+        eou_metrics: List of EOUResult objects.
+        target_percentiles: List of target percentiles.
+    Returns:
+        Dict: A dictionary containing the aggregated EOU metrics.
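+
+    The returned dictionary contains 'latency_p{N}' and 'early_cutoff_p{N}' for each
+    requested percentile N, plus 'early_cutoff_rate' and 'miss_rate'.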
+ """ + num_eou_utterances = sum([x.num_utterances for x in eou_metrics]) + eou_latency = flatten_nested_list([x.latency for x in eou_metrics]) + eou_early_cutoff = flatten_nested_list([x.early_cutoff for x in eou_metrics]) + + eou_avg_num_early_cutoff = len(eou_early_cutoff) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + if len(eou_latency) == 0: + eou_latency = [0.0] + if len(eou_early_cutoff) == 0: + eou_early_cutoff = [0.0] + + eou_missing = [x.missing for x in eou_metrics] + + metrics = {} + eou_latency_metrics = get_percentiles(eou_latency, target_percentiles, tag='latency') + eou_early_cutoff_metrics = get_percentiles(eou_early_cutoff, target_percentiles, tag='early_cutoff') + + metrics.update(eou_latency_metrics) + metrics.update(eou_early_cutoff_metrics) + + metrics['early_cutoff_rate'] = eou_avg_num_early_cutoff + metrics['miss_rate'] = sum(eou_missing) / num_eou_utterances if num_eou_utterances > 0 else 0.0 + + return metrics diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 78d0b28fc8d4..f7838aa35d62 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -565,6 +565,8 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # 2.a. Noise mixing. if config.noise_path is not None: noise = guess_parse_cutset(config.noise_path) + # make sure the noise is resampled to the same sample rate as the audio cuts + noise = noise.resample(config.sample_rate) cuts = cuts.mix( cuts=noise, snr=tuple(config.noise_snr), diff --git a/scripts/asr_eou/add_eob_labels.py b/scripts/asr_eou/add_eob_labels.py new file mode 100644 index 000000000000..f1ff061c1b45 --- /dev/null +++ b/scripts/asr_eou/add_eob_labels.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A script to add EOB labels to a manifest file. + +Example usage: + +```bash +python add_eob_labels.py /path/to/manifest/dir +``` +where output will be saved in the same directory with `-eob` suffix added to the filename. 
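+
+For example (illustrative), the input line
+
+    {"audio_filepath": "a.wav", "text": "Uh-huh.", "duration": 0.5}
+
+is written out unchanged except for the added flag:
+
+    {"audio_filepath": "a.wav", "text": "Uh-huh.", "duration": 0.5, "is_backchannel": true}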
+""" + +import argparse +import json +from pathlib import Path +from string import punctuation + +from tqdm import tqdm + +parser = argparse.ArgumentParser(description="Add `is_backchannel` labels to manifest files.") +parser.add_argument( + "input_manifest", + type=str, + help="Path to the input manifest file to be cleaned.", +) +parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Path to the output manifest file after cleaning.", +) +parser.add_argument( + "-p", + "--pattern", + type=str, + default="*.json", + help="Pattern to match files in the input directory.", +) + + +def read_manifest(manifest_path): + manifest = [] + with open(manifest_path, 'r') as f: + for line in f.readlines(): + line = line.strip() + if line: + manifest.append(json.loads(line)) + return manifest + + +def write_manifest(manifest_path, manifest): + with open(manifest_path, 'w') as f: + for item in manifest: + f.write(json.dumps(item) + '\n') + + +def clean_text(text): + text = text.translate(str.maketrans('', '', punctuation)).lower().strip() + valid_chars = "abcdefghijklmnopqrstuvwxyz'" + text = ''.join([c for c in text if c in valid_chars or c.isspace() or c == "'"]) + return " ".join(text.split()).strip() + + +backchannel_phrases = [ + 'absolutely', + 'ah', + 'all right', + 'alright', + 'but yeah', + 'definitely', + 'exactly', + 'go ahead', + 'good', + 'great', + 'great thanks', + 'ha ha', + 'hi', + 'i know', + 'i know right', + 'i see', + 'indeed', + 'interesting', + 'mhmm', + 'mhmm mhmm', + 'mhmm right', + 'mhmm yeah', + 'mhmm yes', + 'nice', + 'of course', + 'oh', + 'oh dear', + 'oh man', + 'oh okay', + 'oh wow', + 'oh yes', + 'ok', + 'ok thanks', + 'okay', + 'okay okay', + 'okay thanks', + 'perfect', + 'really', + 'right', + 'right exactly', + 'right right', + 'right yeah', + 'so yeah', + 'sounds good', + 'sure', + 'thank you', + 'thanks', + "that's awesome", + 'thats right', + 'thats true', + 'true', + 'uh-huh', + 'uh-huh yeah', + 'uhhuh', + 'um-humm', + 'well', + 'what', + 'wow', + 'yeah', + 'yeah i know', + 'yeah i see', + 'yeah mhmm', + 'yeah okay', + 'yeah right', + 'yeah uh-huh', + 'yeah yeah', + 'yep', + 'yes', + 'yes please', + 'yes yes', + 'you know', + "you're right", +] + +backchannel_phrases_nopc = [clean_text(phrase) for phrase in backchannel_phrases] + + +def check_if_backchannel(text): + """ + Check if the text is a backchannel phrase. + """ + # Remove punctuation and convert to lowercase + text = clean_text(text) + # Check if the text is in the list of backchannel phrases + return text in backchannel_phrases_nopc + + +def add_eob_labels(manifest_path): + """ + Add EOB labels to a manifest file. + Args: + manifest_path: Path to the manifest file. + + Returns: + manifest: List of dictionaries with the EOB label added. + num_eob: Number of EOB labels added. 
+ """ + num_eob = 0 + manifest = read_manifest(manifest_path) + for i, item in enumerate(manifest): + text = item['text'] + # Check if the text is a backchannel phrase + is_backchannel = check_if_backchannel(text) + # Add the EOB label to the text + if is_backchannel: + item['is_backchannel'] = True + num_eob += 1 + else: + item['is_backchannel'] = False + manifest[i] = item + return manifest, num_eob + + +def main(): + args = parser.parse_args() + input_manifest = Path(args.input_manifest) + + if input_manifest.is_dir(): + manifest_list = list(input_manifest.glob(args.pattern)) + if not manifest_list: + raise ValueError(f"No files found in {input_manifest} matching pattern `{args.pattern}`") + else: + manifest_list = [input_manifest] + + if args.output is None: + output_dir = input_manifest if input_manifest.is_dir() else input_manifest.parent + else: + output_dir = Path(args.output) + output_dir.mkdir(parents=True, exist_ok=True) + + total_num_eob = 0 + print(f"Processing {len(manifest_list)} manifest files...") + for manifest_path in tqdm(manifest_list, total=len(manifest_list)): + output_file = output_dir / f"{manifest_path.stem}-eob.json" + new_manifest, num_eob = add_eob_labels(manifest_path) + total_num_eob += num_eob + write_manifest(output_file, new_manifest) + print(f"Processed {manifest_path} and saved to {output_file}. Number of EOB labels added: {num_eob}") + + print(f"Total number of EOB labels added: {total_num_eob}") + + +if __name__ == "__main__": + main() diff --git a/scripts/asr_eou/clean_manifest.py b/scripts/asr_eou/clean_manifest.py new file mode 100644 index 000000000000..bf6723cda596 --- /dev/null +++ b/scripts/asr_eou/clean_manifest.py @@ -0,0 +1,648 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A rule-based text cleaning script for preparing text for ASR-EOU model training. 
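+
+Depending on the flags, cleaning may include Unicode-to-ASCII conversion, Whisper text
+normalization, spoken-form expansion of dates, times, and metric units, number-to-word
+replacement, lowercasing, punctuation removal, and automatic capitalization and punctuation.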
+ +Example usage: + +```bash +python clean_manifest.py \ + /path/to/manifest/dir \ + -o /path/to/output/dir +``` + +""" + +import argparse +import re +import unicodedata +from pathlib import Path +from string import punctuation + +import dateutil.parser as date_parser +from num2words import num2words +from whisper_normalizer.english import EnglishTextNormalizer + +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest + +punctuations = punctuation.replace("'", "") + +text_normalizer = EnglishTextNormalizer() + +parser = argparse.ArgumentParser(description="Clean manifest file") +parser.add_argument( + "input_manifest", + type=str, + help="Path to the input manifest file to be cleaned.", +) +parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Path to the output manifest file after cleaning.", +) +parser.add_argument( + "-lower", + "--lowercase", + type=bool, + default=False, + help="Whether to convert the text to lowercase.", +) +parser.add_argument( + "-drop", + "--remove_punc", + type=bool, + default=False, + help="Whether to remove punctuation from the text.", +) +parser.add_argument( + "--normalize", + type=bool, + default=False, + help="Whether to normalize the text using Whisper text normalizer.", +) +parser.add_argument( + "-n2w", + "--replace_numbers", + type=bool, + default=True, + help="Whether to replace numbers with words.", +) +parser.add_argument( + "-p", + "--pattern", + type=str, + default="**/*.json", + help="Pattern to match files in the input directory.", +) +parser.add_argument( + "-t", + "--text_field", + type=str, + default="text", + help="Field in the manifest to clean. Default is 'text'.", +) +parser.add_argument( + "--auto_pc", + action="store_true", + help="If set, will add auto capitalization and punctuation at the end of the text.", +) +parser.add_argument( + "--format", + default="asr", + choices=["asr", "conv"], + help="Format of the manifest. 
Default is 'asr'.", +) +parser.add_argument( + "--keep_name", + action="store_true", + help="If set, will keep the original name of the manifest file.", +) + +# Spoken representations + +MONTHS = [ + "", + "January", + "February", + "March", + "April", + "May", + "June", + "July", + "August", + "September", + "October", + "November", + "December", +] + +ORDINALS = { + 1: "first", + 2: "second", + 3: "third", + 4: "fourth", + 5: "fifth", + 6: "sixth", + 7: "seventh", + 8: "eighth", + 9: "ninth", + 10: "tenth", + 11: "eleventh", + 12: "twelfth", + 13: "thirteenth", + 14: "fourteenth", + 15: "fifteenth", + 16: "sixteenth", + 17: "seventeenth", + 18: "eighteenth", + 19: "nineteenth", + 20: "twentieth", + 21: "twenty first", + 22: "twenty second", + 23: "twenty third", + 24: "twenty fourth", + 25: "twenty fifth", + 26: "twenty sixth", + 27: "twenty seventh", + 28: "twenty eighth", + 29: "twenty ninth", + 30: "thirtieth", + 31: "thirty first", +} + + +def speak_year(year: int) -> str: + if 2000 <= year <= 2099: + return f"twenty {speak_number(year % 100)}" + elif 1900 <= year <= 1999: + return f"nineteen {speak_number(year % 100)}" + else: + return str(year) + + +def speak_number(n: int) -> str: + num_words = { + 0: "zero", + 1: "one", + 2: "two", + 3: "three", + 4: "four", + 5: "five", + 6: "six", + 7: "seven", + 8: "eight", + 9: "nine", + 10: "ten", + 11: "eleven", + 12: "twelve", + 13: "thirteen", + 14: "fourteen", + 15: "fifteen", + 16: "sixteen", + 17: "seventeen", + 18: "eighteen", + 19: "nineteen", + 20: "twenty", + 30: "thirty", + 40: "forty", + 50: "fifty", + 60: "sixty", + 70: "seventy", + 80: "eighty", + 90: "ninety", + } + if n <= 20: + return num_words[n] + elif n < 100: + tens, ones = divmod(n, 10) + return f"{num_words[tens * 10]} {num_words[ones]}" if ones else num_words[tens * 10] + else: + return str(n) + + +def parse_with_auto_dayfirst(date_str: str): + try: + # Try both ways + parsed_us = date_parser.parse(date_str, dayfirst=False) + parsed_eu = date_parser.parse(date_str, dayfirst=True) + + # If one of the parses clearly makes more sense, return it + if parsed_us.month > 12: + return parsed_eu + if parsed_eu.month > 12: + return parsed_us + + # If day is greater than 12, it's probably day-first + if parsed_us.day > 12 and parsed_eu.day <= 12: + return parsed_eu + elif parsed_eu.day > 12 and parsed_us.day <= 12: + return parsed_us + + # Default fallback (assumes US style) + return parsed_us + except Exception: + return None + + +def date_to_spoken_string(date_str: str) -> str: + parsed = parse_with_auto_dayfirst(date_str) + if not parsed: + return None + + month = MONTHS[parsed.month] + day = ORDINALS[parsed.day] + spoken = f"{month} {day} {speak_year(parsed.year)}" + + return spoken + + +def replace_dates_in_text(text: str) -> str: + # Regex pattern to match common date formats like: + # 5/22, 05/22/2025, 22/05/2025, 2025-05-22 + date_pattern = r'\b(?:\d{1,4}[-/])?\d{1,2}[-/]\d{1,4}\b' + + def replace_match(match): + date_str = match.group(0) + spoken = date_to_spoken_string(date_str) + return spoken if spoken else date_str + + return re.sub(date_pattern, replace_match, text) + + +def convert_to_spoken(text: str) -> str: + + text = replace_dates_in_text(text) # Convert dates to spoken form + + # Mapping of metric units to spoken forms + unit_map = { + "kg": "kilograms", + "g": "grams", + "mg": "milligrams", + "l": "liters", + "ml": "milliliters", + "cm": "centimeters", + "mm": "millimeters", + "m": "meters", + "km": "kilometers", + "°c": "degrees celsius", + "°f": "degrees 
fahrenheit", + "oz": "ounces", + "lb": "pounds", + "lbs": "pounds", + } + + # Replace metric units like "12kg" or "5 ml" + def replace_metric(match): + number = match.group(1) + unit = match.group(2).lower() + spoken_unit = unit_map.get(unit, unit) + return f"{number} {spoken_unit}" + + # Replace time like "5am" or "6PM" + def replace_ampm(match): + hour = match.group(1) + meridiem = match.group(2).lower() + return f"{hour} {'a m' if meridiem == 'am' else 'p m'}" + + # Replace time like "1:30pm" + def replace_colon_time(match): + hour = match.group(1) + minute = match.group(2) + meridiem = match.group(3).lower() + return f"{hour} {minute} {'a m' if meridiem == 'am' else 'p m'}" + + # Convert feet and inches like 5'11" to "5 feet 11 inches" + def replace_feet_inches(match): + feet = match.group(1) + inches = match.group(2) + return f"{feet} feet {inches} inches" + + # Convert just feet (e.g., 6') to "6 feet" + def replace_feet_only(match): + feet = match.group(1) + return f"{feet} feet" + + # Convert just inches (e.g., 10") to "10 inches" + def replace_inches_only(match): + inches = match.group(1) + return f"{inches} inches" + + # Apply replacements + # First: time with colon (e.g., 1:30pm) + text = re.sub(r'\b(\d{1,2}):(\d{2})(am|pm)\b', replace_colon_time, text, flags=re.IGNORECASE) + + # Then: basic am/pm (e.g., 5am) + text = re.sub(r'\b(\d{1,2})(am|pm)\b', replace_ampm, text, flags=re.IGNORECASE) + + # Then: replace 1st, 2nd, 3rd, etc + text = text.replace("1st", "first") + text = text.replace("2nd", "second") + text = text.replace("3rd", "third") + text = text.replace("@", " at ") + + # Finally: metric units + text = re.sub( + r'\b(\d+(?:\.\d+)?)\s?(kg|g|mg|l|ml|cm|mm|m|km|°c|°f|oz|lbs?|LB|LBS?)\b', + replace_metric, + text, + flags=re.IGNORECASE, + ) + text = re.sub(r'\b(\d+)\'(\d+)"', replace_feet_inches, text) # e.g., 5'11" + text = re.sub(r'\b(\d+)\'', replace_feet_only, text) # e.g., 6' + text = re.sub(r'(\d+)"', replace_inches_only, text) # e.g., 10" + + return text + + +def replace_numbers_with_words(text): + def convert_number(match): + num_str = match.group() + original = num_str + + # Remove dollar sign + is_dollar = False + if num_str.startswith('$'): + is_dollar = True + num_str = num_str[1:] + elif num_str.endswith('$'): + is_dollar = True + num_str = num_str[:-1] + + # Remove commas + num_str = num_str.replace(',', '') + + try: + if '.' 
in num_str: + # Convert decimal number + integer_part, decimal_part = num_str.split('.') + words = num2words(int(integer_part)) + ' point ' + ' '.join(num2words(int(d)) for d in decimal_part) + else: + words = num2words(int(num_str)) + if is_dollar: + words += ' dollars' + return words + " " + except Exception: + return original # Return original if conversion fails + + # Pattern matches: $3,000 or 3,000.45 or 1234 + pattern = re.compile(r'\$?\d{1,3}(?:,\d{3})*(?:\.\d+)?|\$?\d+(?:\.\d+)?') + result = pattern.sub(convert_number, text) + result = result.replace("$", " dollars ") # Handle dollar sign separately + + def merge_th(text: str) -> str: + # merge th with the preceding digit + candidates = ["four th ", "five th ", "six th ", "seven th ", "eight th ", "nine th "] + for key in candidates: + if key in text: + if "five" in key: + target = "fifth " + else: + target = f"{key.split(' ')[0]}th " + text = text.replace(key, target) + elif text.endswith(key.strip()): + if "five" in key: + target = "fifth" + else: + target = f"{key.split(' ')[0]}th" + text = text.replace(key.strip(), target) + return text + + result = merge_th(result) + result = " ".join(result.split()) # Remove extra spaces + return result + + +def unicode_to_ascii(text: str) -> str: + """ + Converts text with accented or special Latin characters (e.g., ó, ñ, ū, ō) + into their closest ASCII equivalents. + """ + # Normalize the string to NFKD to separate base characters from diacritics + normalized = unicodedata.normalize('NFKD', text) + + # Encode to ASCII bytes, ignoring characters that can't be converted + ascii_bytes = normalized.encode('ascii', 'ignore') + + # Decode back to string + ascii_text = ascii_bytes.decode('ascii') + + return ascii_text + + +def drop_punctuations(text: str) -> str: + """ + Clean the text by removing invalid characters and converting to lowercase. + + :param text: Input text. + :return: Cleaned text. 
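+
+    Example (illustrative):
+        drop_punctuations("Héllo, World!")  # -> "hello world"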
+    """
+    valid_chars = "abcdefghijklmnopqrstuvwxyz'"
+    text = text.lower()
+    text = unicode_to_ascii(text)
+    text = text.replace(":", " ")
+    text = text.replace("-", " ")
+    text = ''.join([c for c in text if c in valid_chars or c.isspace()])
+    text = ' '.join(text.split()).strip()
+    return text
+
+
+def clean_label(_str: str) -> str:
+    """
+    Remove unauthorized characters from a string and remove unneeded spaces.
+    """
+    # replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→']
+    replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”']
+    replace_with_apos = [char for char in '‘’ʻ‘’‘'] + ["\u2019"]
+    _str = _str.strip()
+    # Convert apostrophe-like characters first, otherwise the blank list below would delete them
+    for i in replace_with_apos:
+        _str = _str.replace(i, "'")
+    for i in replace_with_blank:
+        _str = _str.replace(i, "")
+
+    text = _str
+    text = text.replace("\u2103", "celsius")
+    text = text.replace("\u2109", "fahrenheit")
+    text = text.replace("\u00b0", "degrees")
+    text = text.replace("\u2019", "'")
+    text = text.replace("\\", ".")
+    text = text.replace("\n", " ")
+    text = text.replace("\r", " ")
+    text = text.replace("\t", " ")
+
+    ret = " ".join(text.split())
+    return ret
+
+
+def ends_with_punctuation(s: str) -> bool:
+    # Strip trailing whitespace
+    s = s.rstrip()
+
+    # Consider this set to be punctuation that's acceptable to end a sentence with
+    punctuation_chars = [",", ".", ":", ";", "?", "!", "-", "—", "–", "…"]
+
+    # If string is empty after stripping, return False
+    if not s:
+        return False
+
+    # Get the last character
+    last_char = s[-1]
+
+    # Return True if the last character is punctuation, otherwise False
+    return last_char in punctuation_chars
+
+
+def add_period_if_needed(text: str) -> str:
+    """
+    Add a period at the end of the text if it does not already end with punctuation.
+    """
+    if not ends_with_punctuation(text):
+        text += "."
+    return text.strip()
+
+
+def capitalize_self_i(text):
+    # Replace standalone lowercase "i" with "I"
+    # Handles "i", "i.", "i?", "i'll", "i'm", etc.
+    return re.sub(r'\b(i)(?=[\s.,!?;:\'\"-]|$)', r'I', text)
+
+
+def add_space_after_punctuation(text):
+    # Add a space after punctuation if it's not already followed by one or by the end of the string
+    return re.sub(r'([,\.?;:])(?=\S)', r'\1 ', text)
+
+
+def add_auto_capitalization(text):
+    if text.lower() != text:
+        # If the text is not all lowercase, we assume it already has some capitalization
+        return text
+
+    # Remove space before punctuation (.,!?;:)
+    text = re.sub(r'\s+([.,!?;:])', r'\1', text)
+
+    # Capitalize the first letter of each sentence
+    def capitalize_sentences(match):
+        return match.group(1) + match.group(2).upper()
+
+    # Ensure first character is capitalized
+    text = text.strip()
+    if text:
+        text = text[0].upper() + text[1:]
+
+    text = capitalize_self_i(text)
+    text = add_space_after_punctuation(text)
+    # Capitalize after sentence-ending punctuation followed by space(s)
+    text = re.sub(r'([.!?]\s+)([a-z])', capitalize_sentences, text)
+    return text
+
+
+def clean_text(text: str, args) -> str:
+    """
+    Clean the text based on the provided arguments.
+    """
+    text = unicode_to_ascii(text)
+    if args.normalize:
+        text = text_normalizer(text)
+    if args.replace_numbers:
+        text = convert_to_spoken(text)
+        text = replace_numbers_with_words(text)
+    if args.lowercase:
+        text = text.lower()
+    if args.remove_punc:
+        text = text.replace("-", " ")
+        text = text.replace("_", " ")
+        text = text.translate(str.maketrans("", "", punctuations))
+        text = drop_punctuations(text)
+    if args.auto_pc:
+        text = add_auto_capitalization(text)
+    return clean_label(text)
+
+
+def clean_asr_manifest(manifest, text_field, args):
+    for i, item in enumerate(manifest):
+        text = str(item[text_field])
+        manifest[i][f"origin_{text_field}"] = text
+        manifest[i][text_field] = clean_text(text, args)
+    return manifest
+
+
+def clean_conv_manifest(manifest, text_field, args):
+    new_manifest = []
+    for i, item in enumerate(manifest):
+        conversations = []
+        for turn in item["conversations"]:
+            conversations.append(
+                {
+                    "role": turn["role"],
+                    "value": clean_text(turn["value"], args),
+                    "type": turn.get("type", "text"),
+                }
+            )
+        item["conversations"] = conversations
+        new_manifest.append(item)
+    return new_manifest
+
+
+def main(args):
+    text_field = args.text_field
+    manifest_files = Path(args.input_manifest)
+    if manifest_files.is_dir():
+        manifest_files = list(manifest_files.glob(args.pattern))
+    elif manifest_files.is_file():
+        manifest_files = [manifest_files]
+    else:
+        raise ValueError(f"Invalid input manifest path: {args.input_manifest}")
+
+    for manifest_file in manifest_files:
+        print(f"Processing manifest file: {manifest_file}")
+        postfix = "-cleaned"
+        postfix += "_norm" if args.normalize else ""
+        postfix += "_n2w" if args.replace_numbers else ""
+        if args.lowercase and args.remove_punc:
+            postfix += "_noPC"
+        else:
+            postfix += "_lc" if args.lowercase else ""
+            postfix += "_np" if args.remove_punc else ""
+        postfix += "_aPC" if args.auto_pc else ""
+
+        output_manifest = manifest_file.with_name(f"{manifest_file.stem}{postfix}{manifest_file.suffix}")
+
+        if args.output:
+            if args.output.endswith(".json"):
+                if len(manifest_files) > 1:
+                    raise ValueError("Output path must be a directory when processing multiple manifest files.")
+                output_manifest = Path(args.output)
+            else:
+                output_dir = Path(args.output)
+                output_dir.mkdir(parents=True, exist_ok=True)
+                if args.keep_name:
+                    output_manifest = output_dir / manifest_file.name
+                else:
+                    output_manifest = output_dir / output_manifest.name
+
+        manifest = read_manifest(str(manifest_file))
+
+        if args.format == "asr":
+            manifest = clean_asr_manifest(manifest, text_field, args)
+        elif args.format == "conv":
+            manifest = clean_conv_manifest(manifest, text_field, args)
+        else:
+            raise ValueError(f"Unsupported manifest format: {args.format}")
+
+        write_manifest(str(output_manifest), manifest)
+        print(f"Cleaned manifest saved to {output_manifest}")
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    main(args)
diff --git a/scripts/asr_eou/conf/data.yaml b/scripts/asr_eou/conf/data.yaml
new file mode 100644
index 000000000000..c27bf129fa1e
--- /dev/null
+++ 
b/scripts/asr_eou/conf/data.yaml @@ -0,0 +1,46 @@ + +output_dir: ??? + +data: + pattern: "*.json" + manifest_filepath: ??? + tarred_audio_filepaths: null + sample_rate: 16000 + max_duration: 30 # you may need to update it for your dataset + min_duration: 0.1 + batch_duration: 300 # you may disable batch_duration by setting it to `null` + batch_size: null + shuffle: false + seed: 42 + num_workers: 8 + pin_memory: true + quadratic_duration: 30 + num_buckets: 30 + num_cuts_for_bins_estimate: 10000 + bucket_buffer_size: 10000 + shuffle_buffer_size: 10000 + + random_padding: + prob: 1.0 + min_pad_duration: 0.0 # minimum duration of pre/post padding in seconds + max_pad_duration: 3.0 # maximum duration of pre/post padding in seconds + max_total_duration: 40.0 # maximum total duration of the padded audio in seconds + pad_distribution: 'constant' # distribution of padding duration, 'uniform' or 'normal' or 'constant' + pre_pad_duration: 0.2 + post_pad_duration: 3.0 + + augmentor: + white_noise: + prob: 0.0 + min_level: -90 + max_level: -40 + gain: + prob: 0.0 + min_gain_dbfs: -10.0 + max_gain_dbfs: 10.0 + noise: + prob: 1.0 + manifest_path: ??? + min_snr_db: 0 + max_snr_db: 20 + max_gain_db: 300.0 \ No newline at end of file diff --git a/scripts/asr_eou/eval_eou_metrics.py b/scripts/asr_eou/eval_eou_metrics.py new file mode 100644 index 000000000000..9919f73af504 --- /dev/null +++ b/scripts/asr_eou/eval_eou_metrics.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script calculates the EOU metrics using predictions and references in SegLST format. 
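+Reported metrics include latency percentiles, early-cutoff percentiles and rate, and
+miss rate (see aggregate_eou_metrics), aggregated per dataset.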
+
+Example usage:
+
+The PREDICTION_ROOT and REFERENCE_ROOT directories should have the following structure:
+
+PREDICTION_ROOT:
+  -> dataset1/
+      -> sample1.json
+      -> sample2.json
+  -> dataset2/
+      -> sample1.json
+      -> sample2.json
+
+REFERENCE_ROOT:
+  -> dataset1/
+      -> sample1.json
+      -> sample2.json
+  -> dataset2/
+      -> sample1.json
+      -> sample2.json
+
+
+Each sample.json should contain a list of dictionaries with the following fields:
+{
+    "session_id": str,
+    "start_time": float,  # start time in seconds
+    "end_time": float,  # end time in seconds
+    "words": str,  # transcription of the utterance
+    "audio_filepath": str,  # only in prediction
+    "eou_prob": float,  # only in prediction, probability of EOU in range [0, 1]
+    "eou_pred": bool,  # only in prediction
+    "full_text": str,  # only in prediction, the full transcription up to end_time
+}
+
+```bash
+python eval_eou_metrics.py \
+    --prediction $PREDICTION_ROOT \
+    --reference $REFERENCE_ROOT \
+    --multiple
+```
+"""
+
+
+import argparse
+import json
+from pathlib import Path
+from typing import List
+
+from nemo.collections.asr.parts.utils.eou_utils import EOUResult, aggregate_eou_metrics, evaluate_eou
+
+parser = argparse.ArgumentParser(description="Evaluate end of utterance predictions against reference labels.")
+parser.add_argument(
+    "-p",
+    "--prediction",
+    type=str,
+    required=True,
+    help="Path to the directory containing the predictions.",
+)
+parser.add_argument(
+    "-r",
+    "--reference",
+    type=str,
+    required=True,
+    help="Path to the directory containing the ground-truth references.",
+)
+parser.add_argument(
+    "--eob",
+    action="store_true",
+    help="Whether to evaluate end of backchannel predictions.",
+)
+parser.add_argument(
+    "--ignore_eob",
+    action="store_true",
+    help="Whether to ignore end of backchannel predictions.",
+)
+parser.add_argument(
+    "--multiple",
+    action="store_true",
+    help="Whether to evaluate multiple datasets.",
+)
+
+
+def load_segLST(directory: str, use_eob: bool = False, ignore_eob: bool = False) -> dict:
+    json_files = list(Path(directory).glob("*.json"))
+    segLST = {}
+    for json_file in json_files:
+        key = json_file.stem
+        with open(json_file, 'r') as f:
+            data = json.load(f)
+            assert isinstance(data, list), f"Data in {json_file} is not a list."
+ if not ignore_eob: + # get the data with the correct eob label + data = [x for x in data if (x.get("is_backchannel", False) == use_eob)] + segLST[key] = data + return segLST + + +def evaluate_eou_predictions( + prediction_dir: str, reference_dir: str, use_eob: bool = False, ignore_eob: bool = False +) -> List[EOUResult]: + prediction_segLST = load_segLST(prediction_dir, use_eob, ignore_eob) + reference_segLST = load_segLST(reference_dir, use_eob, ignore_eob) + + eou_metrics = [] + for key, reference in reference_segLST.items(): + if key not in prediction_segLST: + raise ValueError(f"Key {key} in reference not found in predictions.") + prediction = prediction_segLST[key] + eou_result = evaluate_eou( + prediction=prediction, reference=reference, threshold=None, collar=0.0, do_sorting=True + ) + eou_metrics.append(eou_result) + + results = aggregate_eou_metrics(eou_metrics) + + # add prefix to the keys of the results + prefix = Path(reference_dir).stem + prefix += "_eob" if use_eob else "_eou" + results = {f"{prefix}_{k}": v for k, v in results.items()} + + return results + + +if __name__ == "__main__": + args = parser.parse_args() + + prediction_dir = Path(args.prediction) + reference_dir = Path(args.reference) + + if not prediction_dir.is_dir(): + raise ValueError(f"Prediction directory {prediction_dir} does not exist or is not a directory.") + if not reference_dir.is_dir(): + raise ValueError(f"Reference directory {reference_dir} does not exist or is not a directory.") + + if args.multiple: + # get all subdirectories in the prediction and reference directories + prediction_dirs = sorted([x for x in prediction_dir.glob("*/") if x.is_dir()]) + reference_dirs = sorted([x for x in reference_dir.glob("*/") if x.is_dir()]) + if len(prediction_dirs) != len(reference_dirs): + raise ValueError( + f"Number of prediction directories {len(prediction_dirs)} must match number of reference directories {len(reference_dirs)}." + ) + else: + prediction_dirs = [prediction_dir] + reference_dirs = [reference_dir] + + for ref_dir, pred_dir in zip(reference_dirs, prediction_dirs): + if args.multiple and ref_dir.stem != pred_dir.stem: + raise ValueError( + f"Reference directory {ref_dir} and prediction directory {pred_dir} must have the same name." + ) + results = evaluate_eou_predictions( + prediction_dir=str(pred_dir), reference_dir=str(ref_dir), use_eob=args.eob, ignore_eob=args.ignore_eob + ) + # Print the results + print("==========================================") + print(f"Evaluation Results for: {pred_dir} against {ref_dir}") + for key, value in results.items(): + print(f"{key}: {value:.4f}") + print("==========================================") diff --git a/scripts/asr_eou/generate_noisy_eval_data.py b/scripts/asr_eou/generate_noisy_eval_data.py new file mode 100644 index 000000000000..19a9dcf1cd3c --- /dev/null +++ b/scripts/asr_eou/generate_noisy_eval_data.py @@ -0,0 +1,224 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+"""
+This script is used to generate noisy evaluation data for ASR and end of utterance detection.
+
+Example usage with a single manifest input:
+python generate_noisy_eval_data.py \
+    --config-path conf/ \
+    --config-name data \
+    output_dir=/path/to/output \
+    data.manifest_filepath=/path/to/manifest.json \
+    data.seed=42 \
+    data.augmentor.noise.manifest_path=/path/to/noise_manifest.json
+
+
+Example usage with multiple manifests matching a pattern:
+python generate_noisy_eval_data.py \
+    --config-path conf/ \
+    --config-name data \
+    output_dir=/path/to/output/dir \
+    data.manifest_filepath=/path/to/manifest/dir/ \
+    data.pattern="*.json" \
+    data.seed=42 \
+    data.augmentor.noise.manifest_path=/path/to/noise_manifest.json
+
+"""
+
+from copy import deepcopy
+from pathlib import Path
+from shutil import rmtree
+
+import librosa
+import lightning.pytorch as pl
+import numpy as np
+import soundfile as sf
+import torch
+import yaml
+from lhotse.cut import MixedCut
+from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict
+from tqdm import tqdm
+
+from nemo.collections.asr.data.audio_to_eou_label_lhotse import LhotseSpeechToTextBpeEOUDataset
+from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
+from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
+from nemo.collections.common.parts.preprocessing import parsers
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+
+
+@hydra_runner(config_path="conf/", config_name="data")
+def main(cfg: DictConfig):
+    """
+    Generate noisy evaluation data for ASR and end of utterance detection.
+
+    Args:
+        cfg: DictConfig object containing the configuration.
+    """
+    # Seed everything for reproducibility
+    seed = cfg.data.get('seed', None)
+    if seed is None:
+        seed = np.random.randint(0, 2**32 - 1)
+        logging.info(f'No seed provided, using random seed: {seed}')
+    logging.info(f'Setting random seed to {seed}')
+    with open_dict(cfg):
+        cfg.data.seed = seed
+    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+    pl.seed_everything(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    # Patch data config
+    with open_dict(cfg.data):
+        cfg.data.force_finite = True
+        cfg.data.force_map_dataset = True
+        cfg.data.shuffle = False
+        cfg.data.check_tokenizer = False  # No need to check tokenizer in LhotseSpeechToTextBpeEOUDataset
+
+    # Make output directory
+    output_dir = Path(cfg.output_dir)
+    if output_dir.exists() and cfg.get('overwrite', False):
+        logging.info(f'Removing existing output directory: {output_dir}')
+        rmtree(output_dir)
+    if not output_dir.exists():
+        logging.info(f'Creating output directory: {output_dir}')
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Dump the config to the output directory
+    config = OmegaConf.to_container(cfg, resolve=True)
+    with open(output_dir / 'config.yaml', 'w') as f:
+        yaml.dump(config, f)
+    logging.info(f'Config dumped to {output_dir / "config.yaml"}')
+
+    if isinstance(cfg.data.manifest_filepath, (list, ListConfig)):
+        manifest_list = [Path(x) for x in cfg.data.manifest_filepath]
+    else:
+        input_manifest_file = Path(cfg.data.manifest_filepath)
+        if input_manifest_file.is_dir():
+            pattern = cfg.data.get('pattern', '*.json')
+            manifest_list = list(input_manifest_file.glob(pattern))
+            if not manifest_list:
+                raise ValueError(f"No files found in {input_manifest_file} matching pattern 
`{pattern}`") + else: + manifest_list = [Path(x) for x in str(input_manifest_file).split(",")] + + logging.info(f'Found {len(manifest_list)} manifest files to process...') + + for i, manifest_file in enumerate(manifest_list): + logging.info(f'[{i+1}/{len(manifest_list)}] Processing {manifest_file}...') + data_cfg = deepcopy(cfg.data) + data_cfg.manifest_filepath = str(manifest_file) + process_manifest(data_cfg, output_dir) + + +def process_manifest(data_cfg: DictConfig, output_dir: Path): + """ + Process a manifest file and generate noisy evaluation data. + + Args: + data_cfg: Configuration. + output_dir: Output directory. + """ + # Load the input manifest + input_manifest = read_manifest(data_cfg.manifest_filepath) + logging.info(f'Found {len(input_manifest)} items in input manifest: {data_cfg.manifest_filepath}') + manifest_parent_dir = Path(data_cfg.manifest_filepath).parent + if Path(input_manifest[0]["audio_filepath"]).is_absolute(): + output_audio_dir = output_dir / 'wav' + flatten_audio_path = True + else: + output_audio_dir = output_dir + flatten_audio_path = False + + if "random_padding" in data_cfg and data_cfg.random_padding.pad_distribution == "constant": + is_constant_padding = True + pre_pad_dur = data_cfg.random_padding.pre_pad_duration + else: + is_constant_padding = False + pre_pad_dur = None + + # Load the dataset + tokenizer = parsers.make_parser() # dummy tokenizer + dataset = LhotseSpeechToTextBpeEOUDataset(cfg=data_cfg, tokenizer=tokenizer, return_cuts=True) + + dataloader = get_lhotse_dataloader_from_config( + config=data_cfg, + global_rank=0, + world_size=1, + dataset=dataset, + tokenizer=tokenizer, + ) + + # Generate noisy evaluation data + manifest = [] + for i, batch in enumerate(tqdm(dataloader, desc="Generating noisy evaluation data")): + audio_batch, audio_len_batch, cuts_batch = batch + audio_batch = audio_batch.cpu().numpy() + audio_len_batch = audio_len_batch.cpu().numpy() + + for j in range(len(cuts_batch)): + cut = cuts_batch[j] + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + + manifest_item = {} + for k, v in cut.custom.items(): + if k == "dataloading_info": + continue + manifest_item[k] = v + audio = audio_batch[j][: audio_len_batch[j]] + audio_file = cut.recording.sources[0].source + + if flatten_audio_path: + output_audio_file = output_audio_dir / str(audio_file).replace('/', '_')[:255] # type: Path + else: + output_audio_file = output_audio_dir / Path(audio_file).relative_to(manifest_parent_dir) # type: Path + + output_audio_file.parent.mkdir(parents=True, exist_ok=True) + sf.write(output_audio_file, audio, dataset.sample_rate) + + manifest_item["audio_filepath"] = str(output_audio_file.relative_to(output_audio_dir)) + manifest_item["offset"] = 0 + manifest_item["duration"] = audio.shape[0] / dataset.sample_rate + + if is_constant_padding: + # Adjust the sou_time and eou_time for constant padding + if 'sou_time' in manifest_item and 'eou_time' in manifest_item: + if not isinstance(manifest_item['sou_time'], list): + manifest_item['sou_time'] = manifest_item['sou_time'] + pre_pad_dur + manifest_item['eou_time'] = manifest_item['eou_time'] + pre_pad_dur + else: + manifest_item['sou_time'] = [x + pre_pad_dur for x in manifest_item['sou_time']] + manifest_item['eou_time'] = [x + pre_pad_dur for x in manifest_item['eou_time']] + else: + # add sou_time and eou_time to the manifest item + manifest_item['sou_time'] = pre_pad_dur + manifest_item['eou_time'] = pre_pad_dur + librosa.get_duration(filename=audio_file) + + 
manifest.append(manifest_item) + + # Write the output manifest + output_manifest_file = output_dir / Path(data_cfg.manifest_filepath).name + write_manifest(output_manifest_file, manifest) + logging.info(f'Output manifest written to {output_manifest_file}') + + +if __name__ == "__main__": + main() diff --git a/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py b/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py new file mode 100644 index 000000000000..c85af1516866 --- /dev/null +++ b/scripts/asr_eou/tokenizers/add_special_tokens_to_sentencepiece.py @@ -0,0 +1,214 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +import logging +import sys +import tempfile +from argparse import ArgumentParser +from pathlib import Path + +import sentencepiece as spm + +from nemo.collections.asr.data.audio_to_eou_label_lhotse import EOB_STRING, EOU_STRING +from nemo.core.connectors.save_restore_connector import SaveRestoreConnector + +try: + import sentencepiece_model_pb2 as spt +except (ImportError, ModuleNotFoundError): + raise Exception("Ensure that sentencepiece_model_pb2.py has been generated from the protoc compiler") + + +SPECIAL_TOKENS = [EOU_STRING, EOB_STRING] + +"""Utility to add special tokens to existing sentencepiece models. + +Generate sentencepiece_model_pb2.py in the directory of this script before running +To generate run `protoc --python_out=/scripts/asr_end_of_utterance/tokenizers sentencepiece_model.proto` +inside the src folder in sentencepiece repo +Refer: https://github.com/google/sentencepiece/issues/121 + +Usage: +python add_special_tokens_to_sentencepiece.py \ + --input_file your_model.nemo \ + --output_dir /path/to/new/tokenizer_dir/ +""" + + +parser = ArgumentParser(description="Add special tokens to sentencepiece model") +parser.add_argument( + "-i", + "--input_file", + type=str, + required=True, + help="Path to nemo model file, or sentencepiece model file", +) +parser.add_argument( + "-o", + "--output_dir", + type=str, + required=True, + help="Path to output directory for new tokenizer", +) +parser.add_argument( + "--tokens", + type=str, + nargs='+', + help="Special tokens to add to tokenizer", + default=SPECIAL_TOKENS, +) +parser.add_argument( + "--extract_only", + action="store_true", + help="Extract tokenizer without adding special tokens", +) + + +def extract_nemo_tokenizer(nemo_filepath: str, output_dir: Path) -> str: + """ + Extract a tokenizer from a Nemo file. + Args: + nemo_filepath: Path to the Nemo file. + output_dir: Path to the output directory. + + Returns: + tokenizer: Path to the tokenizer file. 
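+
+    Example (illustrative):
+        tokenizer_path = extract_nemo_tokenizer("asr_model.nemo", Path("/tmp/tokenizer"))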
+    """
+    SaveRestoreConnector._unpack_nemo_file(path2file=nemo_filepath, out_folder=output_dir)
+    tokenizer = None
+    for file in Path(output_dir).glob("**/*"):
+        if file.is_file() and file.name.endswith("tokenizer.model"):
+            tokenizer = file
+            break
+    if tokenizer is None:
+        raise ValueError(f"Tokenizer not found in {output_dir}: {os.listdir(output_dir)}")
+    return str(tokenizer.absolute())
+
+
+def edit_spt_model(input_file, output_dir, tokens, is_userdefined, extract_only=False):
+    """
+    Edit a sentencepiece model to add special tokens.
+    Args:
+        input_file: Path to the input sentencepiece model file.
+        output_dir: Path to the output directory.
+        tokens: List of special tokens to add.
+        is_userdefined: Whether the special tokens are user-defined.
+        extract_only: Whether to extract the tokenizer only.
+    """
+
+    if extract_only:
+        logging.info("Extracting tokenizer only, no special tokens will be added.")
+
+    output_dir = Path(output_dir)
+
+    if output_dir.exists():
+        logging.warning(f"Output directory {output_dir} already exists. Overwriting it.")
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    output_file = str(output_dir / "tokenizer.model")
+
+    token_type = 3
+    if is_userdefined:
+        token_type = 4
+
+    model = spt.ModelProto()
+    with open(input_file, 'rb') as f:
+        model.ParseFromString(f.read())
+
+    if not extract_only:
+        for token in tokens:
+            piece = model.SentencePiece(piece=token, score=0.0, type=token_type)
+            if piece in model.pieces:
+                logging.error(f"Special Token '{token}' already exists in the input model!")
+                sys.exit(1)
+            model.pieces.append(piece)
+
+    sp = spm.SentencePieceProcessor()
+    sp.LoadFromSerializedProto(model.SerializeToString())
+
+    if not extract_only:
+        try:
+            for token in tokens:
+                id = sp.piece_to_id(token)
+                logging.info(f"Created token '{token}' at ID {id}")
+            logging.info(f"New tokenizer vocab size: {sp.get_piece_size()}")
+        except Exception:
+            logging.error(
+                "Could not appropriately configure new tokenizer. Verify if the special tokens already exist."
+            )
+            sys.exit(1)
+
+    with open(output_file, 'wb') as outf:
+        outf.write(model.SerializeToString())
+    logging.info(f"Created new tokenizer at: {output_file}")
+
+    # Write the vocab to file
+    vocab_file = str(output_dir / "tokenizer.vocab")
+    with open(vocab_file, "w", encoding="utf-8") as f:
+        for i in range(sp.get_piece_size()):
+            piece = sp.id_to_piece(i)
+            score = sp.get_score(i)  # Optional: only available if using newer SentencePiece versions
+            f.write(f"{piece}\t{score}\n")  # Format follows the original vocab format
+    logging.info(f"Created new tokenizer vocab at: {vocab_file}")
+
+    special_tokens = ["<unk>", "<s>", "</s>", "<pad>"]
+    special_tokens.extend(tokens)
+    vocab_txt_file = str(output_dir / "vocab.txt")
+    with open(vocab_txt_file, "w", encoding="utf-8") as f:
+        for i in range(sp.get_piece_size()):
+            piece = sp.id_to_piece(i)
+            if piece in special_tokens:
+                # skip special tokens
+                continue
+            token = piece[1:] if piece.startswith("▁") else f"##{piece}"
+            if len(token) > 0:
+                f.write(f"{token}\n")  # Format follows the original vocab format
+    logging.info(f"Created new tokenizer vocab at: {vocab_txt_file}")
+
+
+def inject_special_tokens(input_file, output_dir, tokens, is_userdefined=True, extract_only=False):
+    """
+    Inject special tokens into a sentencepiece model.
+    NOTE: is_userdefined should be set to True in order for ASR model to work with the new special tokens properly.
+
+    Args:
+        input_file: Path to the input sentencepiece model file.
+        output_dir: Path to the output directory.
+ tokens: List of special tokens to add. + is_userdefined: Whether the special tokens are user-defined. + extract_only: Whether to extract the tokenizer only. + """ + + if not os.path.exists(input_file): + raise ValueError(f"Input file {input_file} does not exist") + + with tempfile.TemporaryDirectory() as temp_dir: + # Check if input file is a Nemo file + if input_file.endswith(".nemo"): + input_file = extract_nemo_tokenizer(input_file, temp_dir) + logging.info(f"Extracted tokenizer from Nemo file: {input_file}") + else: + input_file = os.path.abspath(input_file) + logging.info(f"Using input file: {input_file}") + + edit_spt_model(input_file, output_dir, tokens, is_userdefined, extract_only=extract_only) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + args = parser.parse_args() + inject_special_tokens(args.input_file, args.output_dir, args.tokens, extract_only=args.extract_only) diff --git a/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py b/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py new file mode 100644 index 000000000000..cb97411349aa --- /dev/null +++ b/scripts/asr_eou/tokenizers/sentencepiece_model_pb2.py @@ -0,0 +1,1442 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: sentencepiece_model.proto + +import sys + +_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='sentencepiece_model.proto', + package='sentencepiece', + syntax='proto2', + serialized_options=_b('H\003'), + serialized_pb=_b( + '\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\xa4\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18seed_sentencepieces_file\x18\x36 \x01(\t:\x00\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03' + ), +) + + +_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor( + name='ModelType', + full_name='sentencepiece.TrainerSpec.ModelType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor(name='UNIGRAM', index=0, number=1, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='BPE', index=1, number=2, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='WORD', index=2, number=3, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='CHAR', index=3, number=4, serialized_options=None, type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=1553, + serialized_end=1606, +) +_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE) + +_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='sentencepiece.ModelProto.SentencePiece.Type', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor(name='NORMAL', index=0, number=1, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='UNKNOWN', index=1, number=2, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='CONTROL', index=2, number=3, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='USER_DEFINED', index=3, number=4, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='BYTE', index=4, number=6, serialized_options=None, type=None), + _descriptor.EnumValueDescriptor(name='UNUSED', index=5, number=5, 
serialized_options=None, type=None), + ], + containing_type=None, + serialized_options=None, + serialized_start=2359, + serialized_end=2443, +) +_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE) + + +_TRAINERSPEC = _descriptor.Descriptor( + name='TrainerSpec', + full_name='sentencepiece.TrainerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', + full_name='sentencepiece.TrainerSpec.input', + index=0, + number=1, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='input_format', + full_name='sentencepiece.TrainerSpec.input_format', + index=1, + number=7, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='model_prefix', + full_name='sentencepiece.TrainerSpec.model_prefix', + index=2, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='model_type', + full_name='sentencepiece.TrainerSpec.model_type', + index=3, + number=3, + type=14, + cpp_type=8, + label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='vocab_size', + full_name='sentencepiece.TrainerSpec.vocab_size', + index=4, + number=4, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=8000, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='accept_language', + full_name='sentencepiece.TrainerSpec.accept_language', + index=5, + number=5, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='self_test_sample_size', + full_name='sentencepiece.TrainerSpec.self_test_sample_size', + index=6, + number=6, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='enable_differential_privacy', + full_name='sentencepiece.TrainerSpec.enable_differential_privacy', + index=7, + number=50, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='differential_privacy_noise_level', + 
full_name='sentencepiece.TrainerSpec.differential_privacy_noise_level', + index=8, + number=51, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='differential_privacy_clipping_threshold', + full_name='sentencepiece.TrainerSpec.differential_privacy_clipping_threshold', + index=9, + number=52, + type=4, + cpp_type=4, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='character_coverage', + full_name='sentencepiece.TrainerSpec.character_coverage', + index=10, + number=10, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0.9995), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='input_sentence_size', + full_name='sentencepiece.TrainerSpec.input_sentence_size', + index=11, + number=11, + type=4, + cpp_type=4, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='shuffle_input_sentence', + full_name='sentencepiece.TrainerSpec.shuffle_input_sentence', + index=12, + number=19, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='mining_sentence_size', + full_name='sentencepiece.TrainerSpec.mining_sentence_size', + index=13, + number=12, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=_b('\030\001'), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='training_sentence_size', + full_name='sentencepiece.TrainerSpec.training_sentence_size', + index=14, + number=13, + type=5, + cpp_type=1, + label=1, + has_default_value=False, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=_b('\030\001'), + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='seed_sentencepiece_size', + full_name='sentencepiece.TrainerSpec.seed_sentencepiece_size', + index=15, + number=14, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=1000000, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='shrinking_factor', + full_name='sentencepiece.TrainerSpec.shrinking_factor', + index=16, + number=15, + type=2, + cpp_type=6, + label=1, + has_default_value=True, + default_value=float(0.75), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + 
_descriptor.FieldDescriptor( + name='max_sentence_length', + full_name='sentencepiece.TrainerSpec.max_sentence_length', + index=17, + number=18, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=4192, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='num_threads', + full_name='sentencepiece.TrainerSpec.num_threads', + index=18, + number=16, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=16, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='num_sub_iterations', + full_name='sentencepiece.TrainerSpec.num_sub_iterations', + index=19, + number=17, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=2, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='max_sentencepiece_length', + full_name='sentencepiece.TrainerSpec.max_sentencepiece_length', + index=20, + number=20, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=16, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_unicode_script', + full_name='sentencepiece.TrainerSpec.split_by_unicode_script', + index=21, + number=21, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_number', + full_name='sentencepiece.TrainerSpec.split_by_number', + index=22, + number=23, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='split_by_whitespace', + full_name='sentencepiece.TrainerSpec.split_by_whitespace', + index=23, + number=22, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='treat_whitespace_as_suffix', + full_name='sentencepiece.TrainerSpec.treat_whitespace_as_suffix', + index=24, + number=24, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='allow_whitespace_only_pieces', + full_name='sentencepiece.TrainerSpec.allow_whitespace_only_pieces', + index=25, + number=26, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + 
_descriptor.FieldDescriptor( + name='split_digits', + full_name='sentencepiece.TrainerSpec.split_digits', + index=26, + number=25, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pretokenization_delimiter', + full_name='sentencepiece.TrainerSpec.pretokenization_delimiter', + index=27, + number=53, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='control_symbols', + full_name='sentencepiece.TrainerSpec.control_symbols', + index=28, + number=30, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='user_defined_symbols', + full_name='sentencepiece.TrainerSpec.user_defined_symbols', + index=29, + number=31, + type=9, + cpp_type=9, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='required_chars', + full_name='sentencepiece.TrainerSpec.required_chars', + index=30, + number=36, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='byte_fallback', + full_name='sentencepiece.TrainerSpec.byte_fallback', + index=31, + number=35, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='vocabulary_output_piece_score', + full_name='sentencepiece.TrainerSpec.vocabulary_output_piece_score', + index=32, + number=32, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='hard_vocab_limit', + full_name='sentencepiece.TrainerSpec.hard_vocab_limit', + index=33, + number=33, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='use_all_vocab', + full_name='sentencepiece.TrainerSpec.use_all_vocab', + index=34, + number=34, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + 
name='unk_id', + full_name='sentencepiece.TrainerSpec.unk_id', + index=35, + number=40, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=0, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='bos_id', + full_name='sentencepiece.TrainerSpec.bos_id', + index=36, + number=41, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='eos_id', + full_name='sentencepiece.TrainerSpec.eos_id', + index=37, + number=42, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=2, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pad_id', + full_name='sentencepiece.TrainerSpec.pad_id', + index=38, + number=43, + type=5, + cpp_type=1, + label=1, + has_default_value=True, + default_value=-1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='unk_piece', + full_name='sentencepiece.TrainerSpec.unk_piece', + index=39, + number=45, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='bos_piece', + full_name='sentencepiece.TrainerSpec.bos_piece', + index=40, + number=46, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='eos_piece', + full_name='sentencepiece.TrainerSpec.eos_piece', + index=41, + number=47, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='pad_piece', + full_name='sentencepiece.TrainerSpec.pad_piece', + index=42, + number=48, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='unk_surface', + full_name='sentencepiece.TrainerSpec.unk_surface', + index=43, + number=44, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b(" \342\201\207 ").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='train_extremely_large_corpus', + full_name='sentencepiece.TrainerSpec.train_extremely_large_corpus', + index=44, + number=49, + type=8, + 
cpp_type=7, + label=1, + has_default_value=True, + default_value=False, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='seed_sentencepieces_file', + full_name='sentencepiece.TrainerSpec.seed_sentencepieces_file', + index=45, + number=54, + type=9, + cpp_type=9, + label=1, + has_default_value=True, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + _TRAINERSPEC_MODELTYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=45, + serialized_end=1617, +) + + +_NORMALIZERSPEC = _descriptor.Descriptor( + name='NormalizerSpec', + full_name='sentencepiece.NormalizerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', + full_name='sentencepiece.NormalizerSpec.name', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='precompiled_charsmap', + full_name='sentencepiece.NormalizerSpec.precompiled_charsmap', + index=1, + number=2, + type=12, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b(""), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='add_dummy_prefix', + full_name='sentencepiece.NormalizerSpec.add_dummy_prefix', + index=2, + number=3, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='remove_extra_whitespaces', + full_name='sentencepiece.NormalizerSpec.remove_extra_whitespaces', + index=3, + number=4, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='escape_whitespaces', + full_name='sentencepiece.NormalizerSpec.escape_whitespaces', + index=4, + number=5, + type=8, + cpp_type=7, + label=1, + has_default_value=True, + default_value=True, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='normalization_rule_tsv', + full_name='sentencepiece.NormalizerSpec.normalization_rule_tsv', + index=5, + number=6, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + 
is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1620, + serialized_end=1829, +) + + +_SELFTESTDATA_SAMPLE = _descriptor.Descriptor( + name='Sample', + full_name='sentencepiece.SelfTestData.Sample', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', + full_name='sentencepiece.SelfTestData.Sample.input', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='expected', + full_name='sentencepiece.SelfTestData.Sample.expected', + index=1, + number=2, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[], + serialized_start=1900, + serialized_end=1941, +) + +_SELFTESTDATA = _descriptor.Descriptor( + name='SelfTestData', + full_name='sentencepiece.SelfTestData', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='samples', + full_name='sentencepiece.SelfTestData.samples', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[ + _SELFTESTDATA_SAMPLE, + ], + enum_types=[], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1831, + serialized_end=1952, +) + + +_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor( + name='SentencePiece', + full_name='sentencepiece.ModelProto.SentencePiece', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='piece', + full_name='sentencepiece.ModelProto.SentencePiece.piece', + index=0, + number=1, + type=9, + cpp_type=9, + label=1, + has_default_value=False, + default_value=_b("").decode('utf-8'), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='score', + full_name='sentencepiece.ModelProto.SentencePiece.score', + index=1, + number=2, + type=2, + cpp_type=6, + label=1, + has_default_value=False, + default_value=float(0), + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='type', + full_name='sentencepiece.ModelProto.SentencePiece.type', + index=2, + number=3, + type=14, + cpp_type=8, + label=1, + has_default_value=True, + default_value=1, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[], + enum_types=[ + 
_MODELPROTO_SENTENCEPIECE_TYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=2244, + serialized_end=2454, +) + +_MODELPROTO = _descriptor.Descriptor( + name='ModelProto', + full_name='sentencepiece.ModelProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='pieces', + full_name='sentencepiece.ModelProto.pieces', + index=0, + number=1, + type=11, + cpp_type=10, + label=3, + has_default_value=False, + default_value=[], + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='trainer_spec', + full_name='sentencepiece.ModelProto.trainer_spec', + index=1, + number=2, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='normalizer_spec', + full_name='sentencepiece.ModelProto.normalizer_spec', + index=2, + number=3, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='self_test_data', + full_name='sentencepiece.ModelProto.self_test_data', + index=3, + number=4, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + _descriptor.FieldDescriptor( + name='denormalizer_spec', + full_name='sentencepiece.ModelProto.denormalizer_spec', + index=4, + number=5, + type=11, + cpp_type=10, + label=1, + has_default_value=False, + default_value=None, + message_type=None, + enum_type=None, + containing_type=None, + is_extension=False, + extension_scope=None, + serialized_options=None, + file=DESCRIPTOR, + ), + ], + extensions=[], + nested_types=[ + _MODELPROTO_SENTENCEPIECE, + ], + enum_types=[], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[ + (200, 536870912), + ], + oneofs=[], + serialized_start=1955, + serialized_end=2465, +) + +_TRAINERSPEC.fields_by_name['model_type'].enum_type = _TRAINERSPEC_MODELTYPE +_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC +_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA +_SELFTESTDATA.fields_by_name['samples'].message_type = _SELFTESTDATA_SAMPLE +_MODELPROTO_SENTENCEPIECE.fields_by_name['type'].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE +_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO +_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['pieces'].message_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['trainer_spec'].message_type = _TRAINERSPEC +_MODELPROTO.fields_by_name['normalizer_spec'].message_type = _NORMALIZERSPEC +_MODELPROTO.fields_by_name['self_test_data'].message_type = _SELFTESTDATA +_MODELPROTO.fields_by_name['denormalizer_spec'].message_type = _NORMALIZERSPEC +DESCRIPTOR.message_types_by_name['TrainerSpec'] = _TRAINERSPEC +DESCRIPTOR.message_types_by_name['NormalizerSpec'] = _NORMALIZERSPEC 
+DESCRIPTOR.message_types_by_name['SelfTestData'] = _SELFTESTDATA +DESCRIPTOR.message_types_by_name['ModelProto'] = _MODELPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TrainerSpec = _reflection.GeneratedProtocolMessageType( + 'TrainerSpec', + (_message.Message,), + dict( + DESCRIPTOR=_TRAINERSPEC, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec) + ), +) +_sym_db.RegisterMessage(TrainerSpec) + +NormalizerSpec = _reflection.GeneratedProtocolMessageType( + 'NormalizerSpec', + (_message.Message,), + dict( + DESCRIPTOR=_NORMALIZERSPEC, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec) + ), +) +_sym_db.RegisterMessage(NormalizerSpec) + +SelfTestData = _reflection.GeneratedProtocolMessageType( + 'SelfTestData', + (_message.Message,), + dict( + Sample=_reflection.GeneratedProtocolMessageType( + 'Sample', + (_message.Message,), + dict( + DESCRIPTOR=_SELFTESTDATA_SAMPLE, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample) + ), + ), + DESCRIPTOR=_SELFTESTDATA, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData) + ), +) +_sym_db.RegisterMessage(SelfTestData) +_sym_db.RegisterMessage(SelfTestData.Sample) + +ModelProto = _reflection.GeneratedProtocolMessageType( + 'ModelProto', + (_message.Message,), + dict( + SentencePiece=_reflection.GeneratedProtocolMessageType( + 'SentencePiece', + (_message.Message,), + dict( + DESCRIPTOR=_MODELPROTO_SENTENCEPIECE, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece) + ), + ), + DESCRIPTOR=_MODELPROTO, + __module__='sentencepiece_model_pb2', + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto) + ), +) +_sym_db.RegisterMessage(ModelProto) +_sym_db.RegisterMessage(ModelProto.SentencePiece) + + +DESCRIPTOR._options = None +_TRAINERSPEC.fields_by_name['mining_sentence_size']._options = None +_TRAINERSPEC.fields_by_name['training_sentence_size']._options = None +# @@protoc_insertion_point(module_scope) diff --git a/scripts/asr_eou/transcribe_speech_sharded.py b/scripts/asr_eou/transcribe_speech_sharded.py new file mode 100644 index 000000000000..d3a827ec5edc --- /dev/null +++ b/scripts/asr_eou/transcribe_speech_sharded.py @@ -0,0 +1,343 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+# Add the examples/asr directory to the Python path so that we can import transcribe_speech.py
+import sys
+from pathlib import Path
+
+nemo_root = Path(__file__).parent.parent.parent
+asr_examples_dir = nemo_root / "examples" / "asr"
+sys.path.insert(0, str(asr_examples_dir))
+
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass
+from math import ceil
+from typing import List
+
+from omegaconf import ListConfig
+from tqdm import tqdm
+from transcribe_speech import TranscriptionConfig as SingleTranscribeConfig  # type: ignore
+from transcribe_speech import main as single_transcribe_main  # type: ignore
+
+from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
+from nemo.core.config import hydra_runner
+from nemo.utils import logging
+
+"""
+Transcribe audio manifests on distributed GPUs. This script is a modified version of `transcribe_speech.py`
+that only takes manifest files as input. It supports splitting the manifests into chunks and merging the
+results back together, which is useful for transcribing amounts of audio data that do not fit into a single job.
+
+# Arguments
+  model_path: path to .nemo ASR checkpoint
+  pretrained_name: name of pretrained ASR model (from NGC registry)
+  dataset_manifest: path to dataset JSON manifest file (in NeMo format); can be a comma-separated list of manifest
+    files or a directory containing manifest files
+  pattern: pattern to glob the manifest files if `dataset_manifest` is a directory
+  output_dir: directory to write the transcriptions
+
+  compute_langs: Bool to request language ID information (if the model supports it)
+  timestamps: Bool to request greedy timestamp information (if the model supports it); defaults to None
+
+  (Optionally, you can limit the type of timestamp computations using the overrides below)
+    ctc_decoding.ctc_timestamp_type="all"  # (default all, can be [all, char, word, segment])
+    rnnt_decoding.rnnt_timestamp_type="all"  # (default all, can be [all, char, word, segment])
+
+  output_filename: output filename where the transcriptions will be written
+  batch_size: batch size during inference
+  presort_manifest: sorts the provided manifest by audio length for faster inference (default: True)
+
+  cuda: Optional int to enable or disable execution of the model on a certain CUDA device
+  allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available
+  amp: Bool to decide if Automatic Mixed Precision should be used during inference
+  audio_type: Str filetype of the audio. Supported = wav, flac, mp3
+
+  overwrite_transcripts: Bool which, when set, allows repeated transcriptions to overwrite previous results
+
+  ctc_decoding: Decoding sub-config for CTC. Refer to documentation for specific values.
+  rnnt_decoding: Decoding sub-config for RNNT. Refer to documentation for specific values.
+
+  calculate_wer: Bool to decide whether to calculate WER/CER at the end of this script
+  clean_groundtruth_text: Bool to clean the ground-truth text
+  langid: Str used for convert_num_to_words during ground-truth cleaning
+  use_cer: Bool to use Character Error Rate (CER) instead of Word Error Rate (WER)
+
+  calculate_rtfx: Bool to calculate the RTFx throughput of transcribing the input dataset
+
+# Usage
+The ASR model can be specified by either "model_path" or "pretrained_name".
+append_pred - optional. Allows you to add more than one prediction to an existing .json manifest.
+pred_name_postfix - optional. The name you want to be written for the current model.
+Results are returned in a JSON manifest file.
+
+```bash
+CUDA_VISIBLE_DEVICES=1 python transcribe_speech_sharded.py \
+    model_path=/path/to/model.nemo \
+    dataset_manifest="/path/to/manifest_or_directory" \
+    output_dir="/path/to/output_dir" \
+    output_filename="/path/to/output_manifest.json" \
+    clean_groundtruth_text=True \
+    langid='en' \
+    batch_size=32 \
+    timestamps=False \
+    compute_langs=False \
+    amp=True \
+    append_pred=False \
+    pred_name_postfix="" \
+    split_size=10000 \
+    num_nodes=1 \
+    node_idx=0 \
+    num_gpus_per_node=1 \
+    gpu_idx=0
+```
+
+If you use Slurm, you can use these parameters to configure the script:
+```bash
+    gpu_idx=\$SLURM_LOCALID \
+    num_gpus_per_node=\$SLURM_GPUS_ON_NODE \
+    num_nodes=\$SLURM_JOB_NUM_NODES \
+    node_idx=\$SLURM_NODEID
+```
+
+"""
+
+
+@dataclass
+class TranscriptionConfig(SingleTranscribeConfig):
+    """
+    Transcription configuration for audio-to-text transcription.
+    """
+
+    # General configs
+    pattern: str = "*.json"
+    output_dir: str = "transcribe_output/"
+
+    # Distributed config
+    num_nodes: int = 1  # total number of nodes
+    node_idx: int = 0  # index of the current node
+    num_gpus_per_node: int = 1  # number of GPUs per node
+    gpu_idx: int = 0  # index of the current GPU
+    bind_gpu_to_cuda: bool = (
+        False  # If False, the script will just do .cuda() on the model, otherwise it will do .to(f"cuda:{gpu_idx}")
+    )
+
+    # handle long manifests
+    split_size: int = -1  # -1 means no split, otherwise split each manifest into chunks of this size
+
+
+def get_unfinished_manifest(manifest_list: List[Path], output_dir: Path):
+    """
+    Get the manifest files that have not finished processing yet, including those that are partly processed.
+
+    Args:
+        manifest_list: List of manifest files to process.
+        output_dir: Directory to write the transcriptions.
+
+    Returns:
+        List of manifest files that have not finished processing yet.
+    """
+    unfinished = []
+    for manifest_file in manifest_list:
+        output_manifest_file = output_dir / manifest_file.name
+        if not output_manifest_file.exists():
+            unfinished.append(manifest_file)
+    return sorted(unfinished)
+
+
+def get_manifest_for_current_rank(
+    manifest_list: List[Path], gpu_id: int = 0, num_gpu: int = 1, node_idx: int = 0, num_node: int = 1
+):
+    """
+    Get the manifest files for the current rank.
+
+    Args:
+        manifest_list: List of manifest files to process.
+        gpu_id: ID of the current GPU.
+        num_gpu: Number of GPUs per node.
+        node_idx: Index of the current node.
+        num_node: Total number of nodes.
+
+    Returns:
+        List of manifest files for the current rank.
+    """
+    node_manifest_list = []
+    assert num_node > 0, f"num_node ({num_node}) must be greater than 0"
+    assert num_gpu > 0, f"num_gpu ({num_gpu}) must be greater than 0"
+    assert 0 <= gpu_id < num_gpu, f"gpu_id ({gpu_id}) must be in range [0, {num_gpu})"
+    assert 0 <= node_idx < num_node, f"node_idx ({node_idx}) must be in range [0, {num_node})"
+    for i, manifest_file in enumerate(manifest_list):
+        if (i + node_idx) % num_node == 0:
+            node_manifest_list.append(manifest_file)
+
+    gpu_manifest_list = []
+    for i, manifest_file in enumerate(node_manifest_list):
+        if (i + gpu_id) % num_gpu == 0:
+            gpu_manifest_list.append(manifest_file)
+    return gpu_manifest_list
+
+
+def maybe_split_manifest(manifest_list: List[Path], cfg: TranscriptionConfig) -> List[Path]:
+    """
+    Split the manifest files into chunks of the specified size.
+
+    Args:
+        manifest_list: List of manifest files to process.
+        cfg: Configuration.
+
+    Returns:
+        List of sharded manifest files.
+    """
+    if cfg.split_size is None or cfg.split_size <= 0:
+        return manifest_list
+
+    all_sharded_manifest_files = []
+    sharded_manifest_dir = Path(cfg.output_dir) / "sharded_manifest_todo"
+    sharded_manifest_dir.mkdir(parents=True, exist_ok=True)
+
+    sharded_manifest_done_dir = Path(cfg.output_dir) / "sharded_manifest_done"
+    sharded_manifest_done_dir.mkdir(parents=True, exist_ok=True)
+    # keep the config value a plain string so it stays compatible with OmegaConf
+    cfg.output_dir = str(sharded_manifest_done_dir)
+
+    logging.info(f"Splitting {len(manifest_list)} manifest files into chunks of {cfg.split_size} samples.")
+    for manifest_file in tqdm(manifest_list, total=len(manifest_list), desc="Splitting manifest files"):
+        manifest = read_manifest(manifest_file)
+
+        num_chunks = ceil(len(manifest) / cfg.split_size)
+        for i in range(num_chunks):
+            chunk_manifest = manifest[i * cfg.split_size : (i + 1) * cfg.split_size]
+            sharded_manifest_file = sharded_manifest_dir / f"{manifest_file.stem}--tpart_{i}.json"
+            write_manifest(sharded_manifest_file, chunk_manifest)
+            all_sharded_manifest_files.append(sharded_manifest_file)
+
+    return all_sharded_manifest_files
+
+
+def maybe_merge_manifest(cfg: TranscriptionConfig):
+    """
+    Merge the sharded manifest files back into the original manifest files and write them to the output directory.
+
+    Args:
+        cfg: Configuration.
+
+    Returns:
+        None.
+    """
+    if cfg.split_size is None or cfg.split_size <= 0:
+        return
+
+    # only merge manifests on the first GPU of the first node
+    if not (cfg.gpu_idx == 0 and cfg.node_idx == 0):
+        return
+
+    sharded_manifest_dir = Path(cfg.output_dir)
+    sharded_manifests = list(sharded_manifest_dir.glob("*--tpart_*.json"))
+    if not sharded_manifests:
+        logging.info(f"No sharded manifest files found in {sharded_manifest_dir}")
+        return
+
+    logging.info(f"Merging {len(sharded_manifests)} sharded manifest files.")
+    manifest_dict = defaultdict(list)
+    for sharded_manifest in sharded_manifests:
+        data_name = sharded_manifest.stem.split("--tpart_")[0]
+        manifest_dict[data_name].append(sharded_manifest)
+
+    output_dir = Path(cfg.output_dir).parent
+    for data_name, sharded_manifest_list in tqdm(
+        manifest_dict.items(), total=len(manifest_dict), desc="Merging manifest files"
+    ):
+        merged_manifest = []
+        # sort by shard index so the merged manifest preserves the original sample order
+        for sharded_manifest in sorted(sharded_manifest_list, key=lambda p: int(p.stem.split("--tpart_")[-1])):
+            manifest = read_manifest(sharded_manifest)
+            merged_manifest.extend(manifest)
+        output_manifest = output_dir / f"{data_name}.json"
+        write_manifest(output_manifest, merged_manifest)
+    logging.info(f"Merged manifest files saved to {output_dir}")
+
+
+@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
+def run_distributed_transcribe(cfg: TranscriptionConfig):
+    """
+    Run distributed transcription with the given configuration.
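+
+    The overall flow, as implemented below: (1) optionally split each input manifest into
+    `--tpart_*` shards of `split_size` entries, (2) select the unfinished shards and round-robin
+    assign them to this (node_idx, gpu_idx) rank, (3) transcribe each assigned shard with the
+    single-process `transcribe_speech` entry point, and (4) once nothing is left unfinished,
+    merge the shard outputs back into per-dataset manifests on rank 0.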
+    """
+    logging.info(f"Running distributed transcription with config: {cfg}")
+
+    if cfg.dataset_manifest is None:
+        raise ValueError("`dataset_manifest` is required")
+
+    # load the manifest paths, normalizing everything to Path objects
+    if isinstance(cfg.dataset_manifest, str) and "," in cfg.dataset_manifest:
+        manifest_list = [Path(p.strip()) for p in cfg.dataset_manifest.split(",")]
+    elif isinstance(cfg.dataset_manifest, (ListConfig, list)):
+        manifest_list = [Path(p) for p in cfg.dataset_manifest]
+    else:
+        input_manifest = Path(cfg.dataset_manifest)
+        if input_manifest.is_dir():
+            manifest_list = list(input_manifest.glob(cfg.pattern))
+        elif input_manifest.is_file():
+            manifest_list = [input_manifest]
+        else:
+            raise ValueError(f"Invalid manifest file or directory: {input_manifest}")
+
+    if not manifest_list:
+        raise ValueError(f"No manifest files found for pattern {cfg.pattern} in {cfg.dataset_manifest}")
+
+    manifest_list = maybe_split_manifest(manifest_list, cfg)
+    original_manifest_list = list(manifest_list)
+    logging.info(f"Found {len(manifest_list)} manifest files.")
+
+    output_dir = Path(cfg.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    unfinished_manifest = get_unfinished_manifest(manifest_list, output_dir=output_dir)
+    if not unfinished_manifest:
+        maybe_merge_manifest(cfg)
+        logging.info("All manifest files have been processed. Exiting.")
+        return
+    logging.info(f"Found {len(unfinished_manifest)} unfinished manifest files.")
+
+    manifest_list = get_manifest_for_current_rank(
+        unfinished_manifest,
+        gpu_id=cfg.gpu_idx,
+        num_gpu=cfg.num_gpus_per_node,
+        node_idx=cfg.node_idx,
+        num_node=cfg.num_nodes,
+    )
+    if not manifest_list:
+        logging.info(f"No manifest files found for GPU {cfg.gpu_idx} on node {cfg.node_idx}. Exiting.")
+        return
+
+    logging.info(f"Processing {len(manifest_list)} manifest files with GPU {cfg.gpu_idx} on node {cfg.node_idx}.")
+
+    cfg.cuda = cfg.gpu_idx if cfg.bind_gpu_to_cuda else None
+    for manifest_file in tqdm(manifest_list):
+        logging.info(f"Processing {manifest_file}...")
+        output_filename = output_dir / Path(manifest_file).name
+        curr_cfg = deepcopy(cfg)
+        curr_cfg.dataset_manifest = str(manifest_file)
+        curr_cfg.output_filename = str(output_filename)
+
+        single_transcribe_main(curr_cfg)
+
+    # check if all manifest files have been processed
+    unfinished_manifest = get_unfinished_manifest(original_manifest_list, output_dir=output_dir)
+    if not unfinished_manifest:
+        maybe_merge_manifest(cfg)
+        logging.info("All manifest files have been processed. 
Exiting.") + return + + +if __name__ == '__main__': + run_distributed_transcribe() # noqa pylint: disable=no-value-for-parameter diff --git a/scripts/checkpoint_averaging/checkpoint_averaging.py b/scripts/checkpoint_averaging/checkpoint_averaging.py old mode 100644 new mode 100755 diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index e8562e686671..50c0de65985b 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -85,7 +85,7 @@ from dataclasses import dataclass, field from datetime import datetime from io import BytesIO -from typing import Any, List, Optional +from typing import Any, List, Optional, Union import numpy as np import soundfile @@ -563,15 +563,42 @@ def create_concatenated_dataset( metadata_yaml = OmegaConf.structured(metadata) OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) - def _read_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): + def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): """Read and filters data from the manifest""" + entries = [] + total_duration = 0.0 + filtered_entries = [] + filtered_duration = 0.0 + + if isinstance(manifest_path, str): + manifest_paths = manifest_path.split(",") + else: + manifest_paths = manifest_path + + print(f"Found {len(manifest_paths)} manifest files to be processed") + for manifest_file in manifest_paths: + entries_i, total_dur_i, filtered_ent_i, filtered_dur_i = self._read_single_manifest( + str(manifest_file), config + ) + entries.extend(entries_i) + total_duration += total_dur_i + filtered_entries.extend(filtered_ent_i) + filtered_duration += filtered_dur_i + + return entries, total_duration, filtered_entries, filtered_duration + + def _read_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): # Read the existing manifest entries = [] total_duration = 0.0 filtered_entries = [] filtered_duration = 0.0 + print(f"Reading manifest: {manifest_path}") with open(manifest_path, 'r', encoding='utf-8') as m: for line in m: + line = line.strip() + if not line: + continue entry = json.loads(line) audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" if config.slice_with_offset and "offset" not in entry: diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index b96fbec6ac46..65e17fda7bdf 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -17,7 +17,6 @@ import math import sys from numbers import Number -from typing import Iterable, Literal import click import lightning.pytorch as pl @@ -375,7 +374,10 @@ def oomptimizer( config_path is None and module_name is None ), "--pretrained-name cannot be used together with --module-name/--config-path" click.echo(f"Intializing ASR model from pretrained checkpoint {pretrained_name}.") - model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) + if pretrained_name.endswith('.nemo'): + model = ASRModel.restore_from(pretrained_name, trainer=trainer).to(device) + else: + model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) else: assert config_path is not None, "--module-name requires --config-path to be specified as well." assert module_name is not None, "--config-path requires --module-name to be specified as well." 
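The two sharding helpers in this patch — `get_manifest_for_current_rank` above and `get_manifests_for_this_rank` in `align_eou.py` below — both deal work out with a modular round-robin, first across nodes and then across the GPUs of a node. A minimal standalone sketch of that dealing scheme, using hypothetical shard names, shows that every shard lands on exactly one (node, gpu) rank:

```python
from pathlib import Path
from typing import List


def deal_round_robin(items: List[Path], rank: int, world: int) -> List[Path]:
    # Keep every `world`-th item starting at offset `rank` (modular round-robin).
    return [item for i, item in enumerate(items) if i % world == rank]


# Hypothetical shard names, mimicking the "--tpart_" naming used by maybe_split_manifest.
shards = [Path(f"dataset--tpart_{i}.json") for i in range(10)]

for node_idx in range(2):
    node_share = deal_round_robin(shards, node_idx, 2)  # first deal shards across nodes
    for gpu_idx in range(2):
        gpu_share = deal_round_robin(node_share, gpu_idx, 2)  # then across GPUs within the node
        print(f"node {node_idx} gpu {gpu_idx}: {[s.name for s in gpu_share]}")
```

Because the assignment is a pure function of the shard index and the rank, the workers need no coordination beyond the shared output directory; a rank whose shards are all finished simply exits.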
diff --git a/tests/collections/asr/test_asr_eou.py b/tests/collections/asr/test_asr_eou.py new file mode 100644 index 000000000000..7b77e6cdf988 --- /dev/null +++ b/tests/collections/asr/test_asr_eou.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import numpy as np +import pytest + +from nemo.collections.asr.parts.utils.eou_utils import EOUResult, cal_eou_metrics_from_frame_labels + + +def make_eou_frame_labels(duration: float, eou_time: float, frame_len_in_secs: float = 0.08) -> List[float]: + """ + Make EOU frame labels. + Args: + duration (float): Duration of the audio in seconds. + eou_time (float): Time of the EOU in seconds. + frame_len_in_secs (float): Length of each frame in seconds. + Returns: + List[float]: List of EOU frame labels. + """ + if eou_time < 0 or eou_time > duration: + raise ValueError(f"EOU time ({eou_time}) is out of range for duration ({duration}).") + + labels = [0] * int(np.ceil(duration / frame_len_in_secs) + 1) + labels[int(np.ceil(eou_time / frame_len_in_secs))] = 1 + return labels + + +class TestEOUMetrics: + @pytest.mark.unit + def test_cal_eou_metrics_from_frame_labels(self): + duration = 1.6 + eou_time = 0.64 + frame_len_in_secs = 0.08 + ref_labels = make_eou_frame_labels(duration, eou_time, frame_len_in_secs) + + # Test case 1: Early cutoff + pred_eou_time = 0.32 + preds = make_eou_frame_labels(duration, pred_eou_time, frame_len_in_secs) + eou_metrics: EOUResult = cal_eou_metrics_from_frame_labels( + prediction=preds, reference=ref_labels, frame_len_in_secs=frame_len_in_secs + ) + assert eou_metrics.true_positives == 0 + assert eou_metrics.false_positives == 1 + assert eou_metrics.false_negatives == 0 + assert eou_metrics.num_utterances == 1 + assert eou_metrics.num_predictions == 1 + assert eou_metrics.missing == 0 + assert eou_metrics.latency == [] + assert np.isclose(eou_metrics.early_cutoff, [0.32]) + + # Test case 2: Latency + pred_eou_time = 0.96 + preds = make_eou_frame_labels(duration, pred_eou_time, frame_len_in_secs) + eou_metrics: EOUResult = cal_eou_metrics_from_frame_labels( + prediction=preds, reference=ref_labels, frame_len_in_secs=frame_len_in_secs + ) + assert eou_metrics.true_positives == 0 + assert eou_metrics.false_positives == 0 + assert eou_metrics.false_negatives == 1 + assert eou_metrics.num_utterances == 1 + assert eou_metrics.num_predictions == 1 + assert eou_metrics.missing == 0 + assert np.isclose(eou_metrics.latency, [0.32]) + assert eou_metrics.early_cutoff == [] + + # Test case 3: miss detection + preds = [0] * len(ref_labels) + eou_metrics: EOUResult = cal_eou_metrics_from_frame_labels( + prediction=preds, reference=ref_labels, frame_len_in_secs=frame_len_in_secs + ) + assert eou_metrics.true_positives == 0 + assert eou_metrics.false_positives == 0 + assert eou_metrics.false_negatives == 1 + assert eou_metrics.num_utterances == 1 + assert eou_metrics.num_predictions == 0 + assert eou_metrics.missing == 1 + assert 
eou_metrics.latency == [] + assert eou_metrics.early_cutoff == [] diff --git a/tools/nemo_forced_aligner/align_eou.py b/tools/nemo_forced_aligner/align_eou.py new file mode 100644 index 000000000000..f40fa7eadaec --- /dev/null +++ b/tools/nemo_forced_aligner/align_eou.py @@ -0,0 +1,582 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import json +import math +import os +import shutil +import unicodedata +import uuid +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from string import punctuation +from typing import List, Optional + +import torch +from omegaconf import OmegaConf +from utils.data_prep import ( + get_batch_starts_ends, + get_manifest_lines_batch, + is_entry_in_all_lines, + is_entry_in_any_lines, +) +from utils.make_ass_files import make_ass_files +from utils.make_ctm_files import make_ctm_files +from utils.make_output_manifest import write_manifest_out_line + +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel +from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR +from nemo.collections.asr.parts.utils.transcribe_utils import setup_model +from nemo.core.config import hydra_runner +from nemo.utils import logging + +try: + from nemo.collections.asr.parts.utils.aligner_utils import ( + add_t_start_end_to_utt_obj, + get_batch_variables, + viterbi_decoding, + ) +except ImportError: + raise ImportError( + "Missing required dependency for NFA. " + "Install NeMo with NFA utilities support:\n" + " pip install 'nemo_toolkit[all]>=2.5.0'\n" + "Or install the latest development version:\n" + " pip install git+https://github.com/NVIDIA/NeMo.git" + ) + +""" +Align the utterances in manifest_filepath. +Results are saved in ctm files in output_dir as well as json manifest in output_manifest_filepath. +If no output_manifest_filepath is specified, it will save the results in the same parent directory as +the input manifest_filepath. + +Arguments: + pretrained_name: string specifying the name of a CTC NeMo ASR model which will be automatically downloaded + from NGC and used for generating the log-probs which we will use to do alignment. + Note: NFA can only use CTC models (not Transducer models) at the moment. + model_path: string specifying the local filepath to a CTC NeMo ASR model which will be used to generate the + log-probs which we will use to do alignment. + Note: NFA can only use CTC models (not Transducer models) at the moment. + Note: if a model_path is provided, it will override the pretrained_name. + manifest_filepath: filepath to the manifest of the data you want to align, + containing 'audio_filepath' and 'text' fields. + output_dir: the folder where output CTM files and new JSON manifest will be saved. 
+    output_manifest_filepath: optional filepath for the output manifest with sou_time and eou_time fields.
+    manifest_pattern: optional pattern used in Path.glob() for finding manifests when manifest_filepath is a directory.
+
+    align_using_pred_text: if True, will transcribe the audio using the specified model and then use that transcription
+        as the reference text for the forced alignment.
+    transcribe_device: None, or a string specifying the device that will be used for generating log-probs (i.e. "transcribing").
+        The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
+        (otherwise will set it to 'cpu').
+    viterbi_device: None, or a string specifying the device that will be used for doing Viterbi decoding.
+        The string needs to be in a format recognized by torch.device(). If None, NFA will set it to 'cuda' if it is available
+        (otherwise will set it to 'cpu').
+    batch_size: int specifying the batch size that will be used for generating log-probs and doing Viterbi decoding.
+    use_local_attention: boolean flag specifying whether to try to use local attention for the ASR Model (will only
+        work if the ASR Model is a Conformer model). If local attention is used, we will set the local attention context
+        size to [64, 64].
+    additional_segment_grouping_separator: an optional string used to separate the text into smaller segments.
+        If this is not specified, then the whole text will be treated as a single segment.
+    remove_blank_tokens_from_ctm: a boolean denoting whether to remove blank tokens from token-level output CTMs.
+    audio_filepath_parts_in_utt_id: int specifying how many of the 'parts' of the audio_filepath
+        we will use (starting from the final part of the audio_filepath) to determine the
+        utt_id that will be used in the CTM files. Note also that any spaces that are present in the audio_filepath
+        will be replaced with dashes, so as not to change the number of space-separated elements in the
+        CTM files.
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 1 => utt_id will be "e1"
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 2 => utt_id will be "d_e1"
+        e.g. if audio_filepath is "/a/b/c/d/e 1.wav" and audio_filepath_parts_in_utt_id is 3 => utt_id will be "c_d_e1"
+    use_buffered_infer: False by default; if set to True, use buffered (chunked) streaming to get the logits
+        for alignment. This flag is useful when aligning large audio files.
+        However, the chunked streaming inference currently does not support batch inference, which means that
+        even if you set batch_size > 1, the files will be transcribed one by one instead of as a whole batch.
+    chunk_len_in_secs: float chunk length in seconds
+    total_buffer_in_secs: float length of the buffer (chunk + left and right padding) in seconds
+    chunk_batch_size: int batch size for buffered chunk inference, which cuts one audio file into segments
+        and runs inference on chunk_batch_size segments at a time
+
+    simulate_cache_aware_streaming: False by default; if set to True, use cache-aware streaming to get the logits
+        for alignment
+
+    save_output_file_formats: List of strings specifying what type of output files to save (default: ["ctm", "ass"])
+    ctm_file_config: CTMFileConfig to specify the configuration of the output CTM files
+    ass_file_config: ASSFileConfig to specify the configuration of the output ASS files
+"""
+
+
+@dataclass
+class CTMFileConfig:
+    remove_blank_tokens: bool = False
+    # minimum duration (in seconds) for timestamps in the CTM. If any line in the CTM has a
+    # duration lower than this, it will be enlarged from the middle outwards until it
+    # meets the minimum_timestamp_duration, or reaches the beginning or end of the audio file.
+    # Note that this may cause timestamps to overlap.
+    minimum_timestamp_duration: float = 0
+
+
+@dataclass
+class ASSFileConfig:
+    fontsize: int = 20
+    vertical_alignment: str = "center"
+    # if resegment_text_to_fill_space is True, the ASS files will use new segments
+    # such that each segment will not take up more than (approximately) max_lines_per_segment
+    # when the ASS file is applied to a video
+    resegment_text_to_fill_space: bool = False
+    max_lines_per_segment: int = 2
+    text_already_spoken_rgb: List[int] = field(default_factory=lambda: [49, 46, 61])  # dark gray
+    text_being_spoken_rgb: List[int] = field(default_factory=lambda: [57, 171, 9])  # dark green
+    text_not_yet_spoken_rgb: List[int] = field(default_factory=lambda: [194, 193, 199])  # light gray
+
+
+@dataclass
+class AlignmentConfig:
+    # Required configs
+    pretrained_name: Optional[str] = None
+    model_path: Optional[str] = None
+    manifest_filepath: Optional[str] = None  # path to manifest file or directory
+    output_dir: Optional[str] = '.tmp'  # defaults to .tmp, which is removed after alignment when remove_tmp_dir is True
+    output_manifest_filepath: Optional[str] = None  # output manifest with sou_time and eou_time
+    manifest_pattern: Optional[str] = None  # pattern used in Path.glob() for finding manifests
+
+    # General configs
+    align_using_pred_text: bool = False
+    transcribe_device: Optional[str] = None
+    viterbi_device: Optional[str] = None
+    batch_size: int = 1
+    use_local_attention: bool = True
+    additional_segment_grouping_separator: Optional[str] = None
+    audio_filepath_parts_in_utt_id: int = 4
+
+    # Buffered chunked streaming configs
+    use_buffered_chunked_streaming: bool = False
+    chunk_len_in_secs: float = 1.6
+    total_buffer_in_secs: float = 4.0
+    chunk_batch_size: int = 32
+    # encoder downsampling factor used to derive the model stride for buffered streaming;
+    # referenced as cfg.model_downsample_factor below (8 is assumed here for FastConformer-style
+    # encoders; adjust to your model)
+    model_downsample_factor: int = 8
+
+    # Cache aware streaming configs
+    simulate_cache_aware_streaming: Optional[bool] = False
+
+    # Output file configs
+    save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"])
+    ctm_file_config: CTMFileConfig = field(default_factory=lambda: CTMFileConfig())
+    ass_file_config: ASSFileConfig = field(default_factory=lambda: ASSFileConfig())
+
+    # remove tmp dir after alignment
+    remove_tmp_dir: bool = False
+    clean_text: bool = True
+
+    # For multi-node multi-gpu processing
+    num_nodes: int = 1  # total num of nodes/machines
+    num_gpus: int = 1  # num of GPUs per node/machine
+    node_idx: int = 0  # current node index
+    gpu_idx: int = 0  # current GPU index
+
+
+def unicode_to_ascii(text: str) -> str:
+    """
Converts text with accented or special Latin characters (e.g., ó, ñ, ū, ō) + into their closest ASCII equivalents. + """ + # Normalize the string to NFKD to separate base characters from diacritics + normalized = unicodedata.normalize('NFKD', text) + + # Encode to ASCII bytes, ignoring characters that can't be converted + ascii_bytes = normalized.encode('ascii', 'ignore') + + # Decode back to string + ascii_text = ascii_bytes.decode('ascii') + + return ascii_text + + +def drop_pnc(text): + """ + Clean the text by removing invalid characters and converting to lowercase. + + :param text: Input text. + :return: Cleaned text. + """ + valid_chars = "abcdefghijklmnopqrstuvwxyz'" + text = text.lower() + text = unicode_to_ascii(text) + text = text.replace(":", " ") + text = ''.join([c for c in text if c in valid_chars or c.isspace() or c == "'"]) + return " ".join(text.split()).strip() + + +def clean_text(manifest: List[dict]): + """ + Clean the text in the manifest. + Args: + manifest: List of dictionaries with the text to clean. + + Returns: + manifest: List of dictionaries with the cleaned text. + """ + punctuations = punctuation.replace("'", "") + # replace_with_space = [char for char in '/?*\",.:=?_{|}~¨«·»¡¿„…‧‹›≪≫!:;ː→'] + replace_with_blank = [char for char in '`¨´‘’“”`ʻ‘’“"‘”'] + replace_with_apos = [char for char in '‘’ʻ‘’‘'] + + for i in range(len(manifest)): + text = manifest[i]["text"].strip().lower() # type: str + text = text.translate(str.maketrans("", "", punctuations)) + text = drop_pnc(text) + for c in replace_with_blank: + text = text.replace(c, "") + for c in replace_with_apos: + text = text.replace(c, "'") + manifest[i]["text"] = text + return manifest + + +def get_manifests_for_this_rank(manifest_list, num_nodes, num_gpus, node_idx, gpu_idx): + """ + Get the manifest files for this rank. 
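+
+    Args:
+        manifest_list: List of manifest files to distribute.
+        num_nodes: Total number of nodes/machines.
+        num_gpus: Number of GPUs per node/machine.
+        node_idx: Index of the current node.
+        gpu_idx: Index of the current GPU on this node.
+
+    Returns:
+        The subset of manifest files assigned to this (node_idx, gpu_idx) rank.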
+    """
+    if len(manifest_list) == 0:
+        return manifest_list
+
+    assert num_nodes > 0, "num_nodes must be greater than 0"
+    assert num_gpus > 0, "num_gpus must be greater than 0"
+    assert 0 <= node_idx < num_nodes, f"node_idx {node_idx} must be between 0 and {num_nodes - 1}"
+    assert 0 <= gpu_idx < num_gpus, f"gpu_idx {gpu_idx} must be between 0 and {num_gpus - 1}"
+
+    # round-robin across nodes, then across GPUs within this node
+    manifests_this_node = [m for i, m in enumerate(manifest_list) if i % num_nodes == node_idx]
+    manifests_this_gpu = [m for i, m in enumerate(manifests_this_node) if i % num_gpus == gpu_idx]
+    return manifests_this_gpu
+
+
+@hydra_runner(config_name="AlignmentConfig", schema=AlignmentConfig)
+def main(cfg: AlignmentConfig):
+
+    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
+
+    if is_dataclass(cfg):
+        cfg = OmegaConf.structured(cfg)
+
+    # Validate config
+    if cfg.model_path is None and cfg.pretrained_name is None:
+        raise ValueError("cfg.model_path and cfg.pretrained_name cannot both be None")
+
+    if cfg.model_path is not None and cfg.pretrained_name is not None:
+        raise ValueError("One of cfg.model_path and cfg.pretrained_name must be None")
+
+    if cfg.manifest_filepath is None:
+        raise ValueError("cfg.manifest_filepath must be specified")
+
+    if cfg.output_dir is None and not cfg.remove_tmp_dir:
+        raise ValueError("cfg.output_dir must be specified if cfg.remove_tmp_dir is False")
+
+    if cfg.batch_size < 1:
+        raise ValueError("cfg.batch_size cannot be zero or a negative number")
+
+    if cfg.additional_segment_grouping_separator == "" or cfg.additional_segment_grouping_separator == " ":
+        raise ValueError("cfg.additional_segment_grouping_separator cannot be an empty string or space character")
+
+    if cfg.ctm_file_config.minimum_timestamp_duration < 0:
+        raise ValueError("cfg.ctm_file_config.minimum_timestamp_duration cannot be a negative number")
+
+    if cfg.ass_file_config.vertical_alignment not in ["top", "center", "bottom"]:
+        raise ValueError("cfg.ass_file_config.vertical_alignment must be one of 'top', 'center' or 'bottom'")
+
+    for rgb_list in [
+        cfg.ass_file_config.text_already_spoken_rgb,
+        cfg.ass_file_config.text_being_spoken_rgb,
+        cfg.ass_file_config.text_not_yet_spoken_rgb,
+    ]:
+        if len(rgb_list) != 3:
+            raise ValueError(
+                "cfg.ass_file_config.text_already_spoken_rgb,"
+                " cfg.ass_file_config.text_being_spoken_rgb,"
+                " and cfg.ass_file_config.text_not_yet_spoken_rgb all need to contain"
+                " exactly 3 elements."
+            )
+
+    # init devices
+    if cfg.transcribe_device is None:
+        transcribe_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        transcribe_device = torch.device(cfg.transcribe_device)
+    logging.info(f"Device to be used for transcription step (`transcribe_device`) is {transcribe_device}")
+
+    if cfg.viterbi_device is None:
+        viterbi_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    else:
+        viterbi_device = torch.device(cfg.viterbi_device)
+    logging.info(f"Device to be used for viterbi step (`viterbi_device`) is {viterbi_device}")
+
+    if transcribe_device.type == 'cuda' or viterbi_device.type == 'cuda':
+        logging.warning(
+            'One or both of transcribe_device and viterbi_device are GPUs. If you run into OOM errors '
+            'it may help to change both devices to be the CPU.'
+        )
+
+    # load model
+    model, _ = setup_model(cfg, transcribe_device)
+    model.eval()
+
+    if isinstance(model, EncDecHybridRNNTCTCModel):
+        model.change_decoding_strategy(decoder_type="ctc")
+
+    if cfg.use_local_attention:
+        logging.info(
+            "Flag use_local_attention is set to True => will try to use local attention for the model if it supports it"
+        )
+        model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=[64, 64])
+
+    if not (isinstance(model, EncDecCTCModel) or isinstance(model, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            "Model is not an instance of NeMo EncDecCTCModel or EncDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported."
+        )
+
+    if cfg.ctm_file_config.minimum_timestamp_duration > 0:
+        logging.warning(
+            f"cfg.ctm_file_config.minimum_timestamp_duration has been set to {cfg.ctm_file_config.minimum_timestamp_duration} seconds. "
+            "This may cause the alignments for some tokens/words/additional segments to be overlapping."
+        )
+
+    buffered_chunk_params = {}
+    if cfg.use_buffered_chunked_streaming:
+        model_cfg = copy.deepcopy(model._cfg)
+
+        OmegaConf.set_struct(model_cfg.preprocessor, False)
+        # disable dithering and padding for the streaming scenario
+        model_cfg.preprocessor.dither = 0.0
+        model_cfg.preprocessor.pad_to = 0
+
+        if model_cfg.preprocessor.normalize != "per_feature":
+            logging.error(
+                "Only EncDecCTCModelBPE models trained with per_feature normalization are supported currently"
+            )
+        # Disable config overwriting
+        OmegaConf.set_struct(model_cfg.preprocessor, True)
+
+        feature_stride = model_cfg.preprocessor['window_stride']
+        model_stride_in_secs = feature_stride * cfg.model_downsample_factor
+        total_buffer = cfg.total_buffer_in_secs
+        chunk_len = float(cfg.chunk_len_in_secs)
+        # tokens_per_chunk: number of encoder timesteps produced by one chunk;
+        # mid_delay: timestep index of the chunk's end within the buffer, i.e. a left
+        # padding of (total_buffer - chunk_len) / 2 seconds followed by the chunk itself
+        tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
+        mid_delay = math.ceil((chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
+        logging.info(f"tokens_per_chunk is {tokens_per_chunk}, mid_delay is {mid_delay}")
+
+        model = FrameBatchASR(
+            asr_model=model,
+            frame_len=chunk_len,
+            total_buffer=cfg.total_buffer_in_secs,
+            batch_size=cfg.chunk_batch_size,
+        )
+        buffered_chunk_params = {
+            "delay": mid_delay,
+            "model_stride_in_secs": model_stride_in_secs,
+            "tokens_per_chunk": tokens_per_chunk,
+        }
+
+    if Path(cfg.manifest_filepath).is_file():
+        manifest_list = [cfg.manifest_filepath]
+    elif Path(cfg.manifest_filepath).is_dir():
+        if cfg.manifest_pattern is not None:
+            manifest_list = list(Path(cfg.manifest_filepath).glob(cfg.manifest_pattern))
+        else:
+            manifest_list = list(Path(cfg.manifest_filepath).glob("*.json"))
+    else:
+        raise ValueError(
+            f"cfg.manifest_filepath is not a valid file or directory. "
+            f"Please check the path: {cfg.manifest_filepath}"
+        )
" + f"Please check the path: {cfg.manifest_filepath}" + ) + + origin_output_manifest_filepath = cfg.output_manifest_filepath + + manifest_list = get_manifests_for_this_rank(manifest_list, cfg.num_nodes, cfg.num_gpus, cfg.node_idx, cfg.gpu_idx) + logging.info(f"Found {len(manifest_list)} manifest files to process.") + # process each manifest file + for manifest_filepath in manifest_list: + logging.info(f"Processing manifest file: {manifest_filepath}") + cfg.manifest_filepath = str(manifest_filepath) + + if origin_output_manifest_filepath is None: + manifest_stem = Path(manifest_filepath).stem + cfg.output_manifest_filepath = str(Path(manifest_filepath).parent / f"{manifest_stem}-aligned.json") + elif len(manifest_list) > 1 and origin_output_manifest_filepath is not None: + raise ValueError( + "cfg.output_manifest_filepath must be None when processing multiple manifest files. " + "Please set it to None." + ) + + if not cfg.remove_tmp_dir and len(manifest_list) > 1: + # if keep alignment files, then we need to set output_dir to be different for each manifest + cfg.output_dir = str(Path(manifest_filepath).parent / f"{Path(manifest_filepath).stem}_alignment") + + process_single_manifest(cfg, model, buffered_chunk_params, viterbi_device) + logging.info(f"Output manifest saved to: {cfg.output_manifest_filepath}") + + logging.info("All manifest files processed successfully.") + + +def process_single_manifest(cfg: AlignmentConfig, model, buffered_chunk_params, viterbi_device): + # Validate manifest contents + if not is_entry_in_all_lines(cfg.manifest_filepath, "audio_filepath"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain an 'audio_filepath' entry. " + "All lines must contain an 'audio_filepath' entry." + ) + + if cfg.align_using_pred_text: + if is_entry_in_any_lines(cfg.manifest_filepath, "pred_text"): + raise RuntimeError( + "Cannot specify cfg.align_using_pred_text=True when the manifest at cfg.manifest_filepath " + "contains 'pred_text' entries. This is because the audio will be transcribed and may produce " + "a different 'pred_text'. This may cause confusion." + ) + else: + if not is_entry_in_all_lines(cfg.manifest_filepath, "text"): + raise RuntimeError( + "At least one line in cfg.manifest_filepath does not contain a 'text' entry. " + "NFA requires all lines to contain a 'text' entry when cfg.align_using_pred_text=False." 
+
+    # get start and end line IDs of batches
+    starts, ends = get_batch_starts_ends(cfg.manifest_filepath, cfg.batch_size)
+
+    # output_timestep_duration is calculated during the first batch and reused for later batches
+    output_timestep_duration = None
+
+    if cfg.remove_tmp_dir and cfg.output_dir is None:
+        cfg.output_dir = f"alignment-{uuid.uuid4()}"
+
+    # open the intermediate manifest that will record the paths of the output files
+    os.makedirs(cfg.output_dir, exist_ok=True)
+    tgt_manifest_name = str(Path(cfg.manifest_filepath).stem) + "_with_output_file_paths.json"
+    tgt_manifest_filepath = str(Path(cfg.output_dir) / tgt_manifest_name)
+    f_manifest_out = open(tgt_manifest_filepath, 'w')
+
+    # get alignment and save in CTM batch-by-batch
+    for start, end in zip(starts, ends):
+        manifest_lines_batch = get_manifest_lines_batch(cfg.manifest_filepath, start, end)
+
+        if cfg.clean_text:
+            manifest_lines_batch = clean_text(manifest_lines_batch)
+
+        (
+            log_probs_batch,
+            y_batch,
+            T_batch,
+            U_batch,
+            utt_obj_batch,
+            output_timestep_duration,
+        ) = get_batch_variables(
+            manifest_lines_batch,
+            model,
+            cfg.additional_segment_grouping_separator,
+            cfg.align_using_pred_text,
+            cfg.audio_filepath_parts_in_utt_id,
+            output_timestep_duration,
+            cfg.simulate_cache_aware_streaming,
+            cfg.use_buffered_chunked_streaming,
+            buffered_chunk_params,
+        )
+
+        alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, viterbi_device)
+
+        for utt_obj, alignment_utt in zip(utt_obj_batch, alignments_batch):
+
+            utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration)
+
+            if "ctm" in cfg.save_output_file_formats:
+                utt_obj = make_ctm_files(
+                    utt_obj,
+                    cfg.output_dir,
+                    cfg.ctm_file_config,
+                )
+
+            if "ass" in cfg.save_output_file_formats:
+                utt_obj = make_ass_files(utt_obj, cfg.output_dir, cfg.ass_file_config)
+
+            write_manifest_out_line(
+                f_manifest_out,
+                utt_obj,
+            )
+
+    f_manifest_out.close()
+
+    # compute start-of-utterance (sou) and end-of-utterance (eou) times from the segment-level CTMs
+    input_manifest_lines = []
+    with open(cfg.manifest_filepath, 'r') as f:
+        for line in f.readlines():
+            if line.strip():
+                input_manifest_lines.append(json.loads(line))
+
+    output_manifest_lines = []
+    with open(tgt_manifest_filepath, 'r') as f:
+        for i, line in enumerate(f.readlines()):
+            item = json.loads(line)
+            assert os.path.basename(input_manifest_lines[i]['audio_filepath']) == os.path.basename(
+                item['audio_filepath']
+            )
+
+            if 'segments_level_ctm_filepath' not in item:
+                logging.info(
+                    f"`segments_level_ctm_filepath` not found for {input_manifest_lines[i]['audio_filepath']}, skipping"
+                )
+                continue
+
+            # get sou/eou time; in each CTM line, the third and fourth whitespace-separated
+            # fields are the start time and duration of a segment
+            with open(item['segments_level_ctm_filepath']) as f_ctm:
+                ctm_lines = [ctm_line.split() for ctm_line in f_ctm]
+            start_time = min([float(ctm_line[2]) for ctm_line in ctm_lines])
+            end_time = max([float(ctm_line[2]) + float(ctm_line[3]) for ctm_line in ctm_lines])
+            input_manifest_lines[i]['sou_time'] = start_time
+            input_manifest_lines[i]['eou_time'] = end_time
+            output_manifest_lines.append(input_manifest_lines[i])
+
+    with open(cfg.output_manifest_filepath, 'w') as f:
+        for item in output_manifest_lines:
+            f.write(json.dumps(item) + '\n')
+
+    if cfg.remove_tmp_dir:  # safely remove the tmp dir after alignment
+        for file_or_folder in [
+            tgt_manifest_filepath,
+            os.path.join(cfg.output_dir, 'ctm'),
+            os.path.join(cfg.output_dir, 'ass'),
+        ]:
+            if os.path.exists(file_or_folder):
+                if os.path.isfile(file_or_folder):
+                    os.remove(file_or_folder)
+                else:
+                    shutil.rmtree(file_or_folder)
+        if os.path.exists(cfg.output_dir) and len(os.listdir(cfg.output_dir)) == 0:
+            shutil.rmtree(cfg.output_dir)
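+
+
+# A minimal sketch of sharding alignment across the GPUs of one node, one process per
+# GPU (illustrative only: the script name `align.py`, model name, and paths are
+# placeholders for your own setup):
+#
+#   for GPU_IDX in 0 1 2 3; do
+#       CUDA_VISIBLE_DEVICES=$GPU_IDX python align.py \
+#           pretrained_name=stt_en_conformer_ctc_large \
+#           manifest_filepath=/data/manifests manifest_pattern="*.json" \
+#           num_gpus=4 gpu_idx=$GPU_IDX &
+#   done; wait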
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/nemo_forced_aligner/utils/data_prep.py b/tools/nemo_forced_aligner/utils/data_prep.py
index 3386b5744108..05900899c74e 100644
--- a/tools/nemo_forced_aligner/utils/data_prep.py
+++ b/tools/nemo_forced_aligner/utils/data_prep.py
@@ -69,6 +69,7 @@ def get_manifest_lines_batch(manifest_filepath, start, end):
         for line_i, line in enumerate(f):
             if line_i >= start and line_i <= end:
                 data = json.loads(line)
+                data["audio_filepath"] = get_full_path(data["audio_filepath"], manifest_filepath)
                 if "text" in data:
                     # remove any BOM, any duplicated spaces, convert any
                     # newline chars to spaces

From 3784ec7088ca1048b0d3e4c09559cf9d3a6a0176 Mon Sep 17 00:00:00 2001
From: subhankar-ghosh
Date: Thu, 2 Apr 2026 09:15:29 -0700
Subject: [PATCH 09/10] Fix freesound url

Signed-off-by: subhankar-ghosh
---
 docs/source/asr/speech_classification/datasets.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/asr/speech_classification/datasets.rst b/docs/source/asr/speech_classification/datasets.rst
index fc16e114086d..fa3ba427ee73 100644
--- a/docs/source/asr/speech_classification/datasets.rst
+++ b/docs/source/asr/speech_classification/datasets.rst
@@ -11,7 +11,7 @@ If you have your own data and want to preprocess it to use with NeMo ASR models,
 
 Freesound
 -----------
-`Freesound `_ is a website that aims to create a huge open collaborative database of audio snippets, samples, recordings, bleeps.
+`Freesound `_ is a website that aims to create a huge open collaborative database of audio snippets, samples, recordings, and bleeps.
 Most audio samples are released under Creative Commons licenses that allow their reuse.
 Researchers and developers can access Freesound content using the Freesound API to retrieve meaningful sound information such as metadata, analysis files, and the sounds themselves.

From fee82bbfbdbbb28b4bfb0f2bca99e42cbc0f18f2 Mon Sep 17 00:00:00 2001
From: subhankar-ghosh
Date: Thu, 2 Apr 2026 09:48:19 -0700
Subject: [PATCH 10/10] Fix freesound url

Signed-off-by: subhankar-ghosh
---
 docs/source/asr/speech_classification/datasets.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/asr/speech_classification/datasets.rst b/docs/source/asr/speech_classification/datasets.rst
index fa3ba427ee73..be29552b1dee 100644
--- a/docs/source/asr/speech_classification/datasets.rst
+++ b/docs/source/asr/speech_classification/datasets.rst
@@ -11,7 +11,7 @@ If you have your own data and want to preprocess it to use with NeMo ASR models,
 
 Freesound
 -----------
-`Freesound `_ is a website that aims to create a huge open collaborative database of audio snippets, samples, recordings, and bleeps.
+`Freesound `_ is a website that aims to create a huge open collaborative database of audio snippets, samples, recordings, and bleeps.
 Most audio samples are released under Creative Commons licenses that allow their reuse.
 Researchers and developers can access Freesound content using the Freesound API to retrieve meaningful sound information such as metadata, analysis files, and the sounds themselves.