
VoiceChat EA STT training reproducible features#15558

Draft
ankitapasad wants to merge 2 commits into NVIDIA-NeMo:main from ankitapasad:stt_vc_ea_parity

Conversation

@ankitapasad
Collaborator

What does this PR do ?

Adds the following features to the dataset class to support VoiceChat EA STT training and fine-tuning:

  1. Correct agent EOS placement
  2. Clean up the token-ID implementation and update the user BOS ID to match EA
  3. MCQ system prompt
  4. Filler responses for ASR training data
  5. Number normalization
  6. Corresponding tests

Collection: speechlm2

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 

PR Type:

  • New Feature
  • Bugfix

If you haven't finished some of the above items, you can still open a "Draft" PR.

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Ankita Pasad <apasad@nvidia.com>
…ization, clean-up token ID init, and corresponding tests

Co-authored-by: Claude <noreply@anthropic.com>

Signed-off-by: Ankita Pasad <apasad@nvidia.com>
import os

import pytest
import torch

Check notice (Code scanning / CodeQL): Unused import (test): Import of 'torch' is not used.

assert (target_tokens == eos).sum().item() == 0, "skip_eos=True should not place any EOS"

# Now collate source tokens, passing in the target channel for EOS placement
source_tokens, source_token_lens = collate_token_channel(

Check notice (Code scanning / CodeQL): Unused local variable (test): Variable source_tokens is not used.
Check notice (Code scanning / CodeQL): Unused local variable (test): Variable source_token_lens is not used.

skip_eos=True,
)

source_tokens, source_token_lens = collate_token_channel(

Check notice (Code scanning / CodeQL): Unused local variable (test): Variable source_tokens is not used.
Check notice (Code scanning / CodeQL): Unused local variable (test): Variable source_token_lens is not used.

from nemo.collections.common.tokenizers import AutoTokenizer
from nemo.collections.speechlm2.data.duplex_stt_dataset import DuplexSTTDataset
from nemo.collections.speechlm2.data.utils import get_pad_id

Check notice (Code scanning / CodeQL): Unused import (test): Import of 'get_pad_id' is not used.

train_batch = train_ds[cuts]
val_batch = val_ds[cuts]

train_targets = train_batch["audio_data"]["target_tokens"]

Check notice (Code scanning / CodeQL): Unused local variable (test): Variable train_targets is not used.

# Force aligner should be created but never called during validation
val_ds.force_aligner = MagicMock()
val_ds[cuts]

Check notice (Code scanning / CodeQL): Statement has no effect (test): This statement has no effect.

# Mock the force aligner to avoid loading wav2vec2
train_ds.force_aligner = MagicMock()
train_ds.force_aligner.batch_force_align_user_audio.side_effect = lambda cuts, **kwargs: cuts
train_ds[cuts]

Check notice (Code scanning / CodeQL): Statement has no effect (test): This statement has no effect.

- is_mcq_cut_train / is_mcq_cut_val / is_asr_cut
"""

import pytest

Check notice (Code scanning / CodeQL): Unused import (test): Import of 'pytest' is not used.
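The collation tests flagged above exercise collate_token_channel with skip_eos controlling EOS placement. A minimal pure-Python sketch of what such a collator might do is below; the signature, padding behavior, and return values are assumptions made for illustration, not the PR's actual code.

```python
def collate_token_channel(token_lists, pad_id, eos_id, skip_eos=False):
    # Hypothetical sketch of a token-channel collator: pads a batch of
    # token sequences to equal length, optionally appending EOS to each.
    # Effective lengths include one EOS slot unless skip_eos is set.
    lens = [len(t) + (0 if skip_eos else 1) for t in token_lists]
    max_len = max(lens)
    batch = []
    for toks, n in zip(token_lists, lens):
        row = list(toks)
        if not skip_eos:
            row.append(eos_id)  # EOS placed at the end of each sequence
        row.extend([pad_id] * (max_len - n))  # right-pad to max length
        batch.append(row)
    return batch, lens
```

With skip_eos=True no EOS token appears anywhere in the output, which is exactly what the test's assertion checks.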
assert tokenizer.bos is not None, "BOS support in the tokenizer is required."
assert tokenizer.eos is not None, "EOS support in the tokenizer is required."

user_bos_token = '^'
Collaborator

@kevinhu-nv Mar 30, 2026


I use the same BOS and EOS for the user and agent channels. I feel that is cleaner, and I verified that it does not impact model performance. I see you want to exactly match EA; how about we make these configurable, so one can set ^ and $? When we release the EA ckpt, we will release a config anyway to make it use ^ and $.
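Making the channel markers configurable, as suggested here, could look roughly like the sketch below. This is a hypothetical config fragment; the class and field names are assumptions, and which of ^ and $ maps to which marker is also an assumption (the diff only shows user_bos_token = '^').

```python
from dataclasses import dataclass

@dataclass
class ChannelTokenConfig:
    # Hypothetical config: defaults share one BOS/EOS pair across channels
    # (the reviewer's preference); an EA-parity config overrides the user
    # markers to '^' and '$'.
    user_bos: str = "<bos>"
    user_eos: str = "<eos>"
    agent_bos: str = "<bos>"
    agent_eos: str = "<eos>"

# Possible EA-parity override, released alongside the EA ckpt:
EA_CONFIG = ChannelTokenConfig(user_bos="^", user_eos="$")
```

This keeps the shared-token default while letting a released EA config opt into the EA-specific markers.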

Prompt selection priority:
1. Per-cut custom prompt (cut.custom['system_prompt'])
2. MCQ training cut -> THINK prompt for think-cuts, NOTHINK prompt for others
3. MCQ validation cut (when add_mcq_prompt=True) -> THINK prompt
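The priority order above can be sketched as a small helper. This is a hypothetical illustration, not the PR's code: the function name, the prompt constants, and the is_mcq/is_think keys in cut.custom are assumptions (only cut.custom['system_prompt'] and add_mcq_prompt appear in the description).

```python
THINK_PROMPT = "<think system prompt>"      # placeholder text, not the PR's actual prompt
NOTHINK_PROMPT = "<nothink system prompt>"  # placeholder text

def select_system_prompt(cut, is_train, add_mcq_prompt=False, default_prompt=None):
    # Hypothetical implementation of the three-step priority described above.
    custom = getattr(cut, "custom", None) or {}
    # 1. A per-cut custom prompt always wins.
    if custom.get("system_prompt"):
        return custom["system_prompt"]
    # 2. MCQ training cut: THINK prompt for think-cuts, NOTHINK for others.
    if is_train and custom.get("is_mcq"):
        return THINK_PROMPT if custom.get("is_think") else NOTHINK_PROMPT
    # 3. MCQ validation cut, only when add_mcq_prompt=True: THINK prompt.
    if not is_train and add_mcq_prompt and custom.get("is_mcq"):
        return THINK_PROMPT
    return default_prompt
```

The falls-through-to-default behavior for non-MCQ cuts without a custom prompt is also an assumption.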
Collaborator


Can you also add support for custom prompts? We could then easily evaluate the different demo setups we have used.

@kevinhu-nv
Collaborator

A high-level question: can you also share a training script/wandb run so we can make sure the metrics look roughly right? I think additional effort may be needed to catch up to the EA ckpt, but it is better to check at intermediate steps as well.
