VoiceChat EA STT training reproducible features #15558

ankitapasad wants to merge 2 commits into NVIDIA-NeMo:main

Conversation

Commits:
…ization, clean-up token ID init, and corresponding tests
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Ankita Pasad <apasad@nvidia.com>
Check notice: Code scanning / CodeQL - Unused import (test)

    import os
    import pytest
    import torch
Check notice: Code scanning / CodeQL - Unused local variable (test)

    assert (target_tokens == eos).sum().item() == 0, "skip_eos=True should not place any EOS"

    # Now collate source tokens, passing in the target channel for EOS placement
    source_tokens, source_token_lens = collate_token_channel(
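The `collate_token_channel` calls flagged by these notices can be illustrated with a minimal, self-contained sketch. This is not the NeMo implementation: the pad/EOS token IDs and the exact signature below are assumptions for illustration, matching only the `skip_eos` behaviour that the test above asserts.

```python
import torch

PAD_ID, EOS_ID = 0, 2  # hypothetical token IDs, not taken from the PR

def collate_token_channel(token_lists, pad_id=PAD_ID, eos_id=EOS_ID, skip_eos=False):
    """Pad variable-length token channels into one (batch, time) tensor.

    Unless skip_eos is set, an EOS token is appended to each sequence
    before padding, so skip_eos=True must leave no EOS in the output.
    """
    extra = 0 if skip_eos else 1
    max_len = max(len(t) for t in token_lists) + extra
    out = torch.full((len(token_lists), max_len), pad_id, dtype=torch.long)
    lens = []
    for i, tokens in enumerate(token_lists):
        out[i, : len(tokens)] = torch.tensor(tokens, dtype=torch.long)
        if not skip_eos:
            out[i, len(tokens)] = eos_id  # place EOS right after the sequence
        lens.append(len(tokens) + extra)
    return out, torch.tensor(lens)
```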
Check notice: Code scanning / CodeQL - Unused local variable (test)

        skip_eos=True,
    )

    source_tokens, source_token_lens = collate_token_channel(
Check notice: Code scanning / CodeQL - Unused import (test)

    from nemo.collections.common.tokenizers import AutoTokenizer
    from nemo.collections.speechlm2.data.duplex_stt_dataset import DuplexSTTDataset
    from nemo.collections.speechlm2.data.utils import get_pad_id
Check notice: Code scanning / CodeQL - Unused local variable (test)

    train_batch = train_ds[cuts]
    val_batch = val_ds[cuts]

    train_targets = train_batch["audio_data"]["target_tokens"]
Check notice: Code scanning / CodeQL - Statement has no effect (test)

    # Force aligner should be created but never called during validation
    val_ds.force_aligner = MagicMock()
    val_ds[cuts]
Check notice: Code scanning / CodeQL - Statement has no effect (test)

    # Mock the force aligner to avoid loading wav2vec2
    train_ds.force_aligner = MagicMock()
    train_ds.force_aligner.batch_force_align_user_audio.side_effect = lambda cuts, **kwargs: cuts
    train_ds[cuts]
Check notice: Code scanning / CodeQL - Unused import (test)

    - is_mcq_cut_train / is_mcq_cut_val / is_asr_cut
    """

    import pytest
Review comment on:

    assert tokenizer.bos is not None, "BOS support in the tokenizer is required."
    assert tokenizer.eos is not None, "EOS support in the tokenizer is required."

    user_bos_token = '^'
I use the same bos and eos for the user and agent channels. I feel that is cleaner, and I verified that it does not impact model performance. I see you want to exactly match EA; let's make these configurable, so one can set ^ and
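The reviewer's suggestion (configurable channel markers instead of a hard-coded `user_bos_token = '^'`) could look like the sketch below. This is an assumed design, not code from the PR; the reviewer's message is truncated, so the second marker is unknown and the `"<eos>"` default here is a placeholder.

```python
from dataclasses import dataclass

@dataclass
class ChannelMarkerConfig:
    """Hypothetical config: per-channel BOS/EOS markers, with user and
    agent sharing one pair by default as the reviewer suggests."""
    user_bos: str = "^"        # from the diff above
    user_eos: str = "<eos>"    # placeholder default, not from the PR
    agent_bos: str = "^"
    agent_eos: str = "<eos>"

def resolve_markers(cfg: ChannelMarkerConfig) -> dict:
    # Map channel name -> (bos, eos) pair for downstream token-ID lookup
    return {
        "user": (cfg.user_bos, cfg.user_eos),
        "agent": (cfg.agent_bos, cfg.agent_eos),
    }
```

With this shape, matching EA exactly is just a matter of overriding the defaults when constructing the config.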
    Prompt selection priority:
    1. Per-cut custom prompt (cut.custom['system_prompt'])
    2. MCQ training cut -> THINK prompt for think-cuts, NOTHINK prompt for others
    3. MCQ validation cut (when add_mcq_prompt=True) -> THINK prompt
Can you also add support for a custom prompt? We could then easily evaluate the different demo setups we have used.
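The documented priority could be sketched as below. Everything except `cut.custom['system_prompt']` is an assumption: the helper name, the flag names, and the prompt strings are placeholders for illustration, not the PR's actual code.

```python
from types import SimpleNamespace

# Hypothetical prompt strings; the real THINK/NOTHINK prompts live in the PR.
THINK_PROMPT = "Think step by step before answering."
NOTHINK_PROMPT = "Answer directly."
DEFAULT_PROMPT = "You are a helpful assistant."

def select_system_prompt(cut, is_train, add_mcq_prompt=False):
    """Sketch of the prompt-selection priority from the docstring above."""
    custom = getattr(cut, "custom", None) or {}
    # 1. Per-cut custom prompt takes precedence
    if custom.get("system_prompt"):
        return custom["system_prompt"]
    # 2. MCQ training cut: THINK for think-cuts, NOTHINK for the rest
    if is_train and custom.get("is_mcq"):
        return THINK_PROMPT if custom.get("is_think") else NOTHINK_PROMPT
    # 3. MCQ validation cut, only when add_mcq_prompt is enabled
    if not is_train and add_mcq_prompt and custom.get("is_mcq"):
        return THINK_PROMPT
    return DEFAULT_PROMPT
```

A per-cut custom prompt would also cover the reviewer's request, since any demo-specific system prompt can be attached to the cut itself.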
A high-level question: can you also share a training script/wandb run so we can make sure the metrics look roughly good? I think additional effort may be needed to catch the EA ckpt, but it is better to check at intermediate steps as well.
What does this PR do?

Adds the following features to the dataset class to support VoiceChat EA STT training and fine-tuning.

Collection: speechlm2

Usage

    # Add a code snippet demonstrating how to use this

PR Type:

If you haven't finished some of the above items, you can still open a "Draft" PR.