-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Add streaming inference support for Unified RNNT model #15522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
eef5587
a032051
86189c2
7d8ce9f
714441d
b1de422
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -111,7 +111,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): | |||||||||||||||||||||||
| att_context_probs (List[float]): a list of probabilities of each one of the att_context_size | ||||||||||||||||||||||||
| when a list of them is passed. If not specified, uniform distribution is being used. | ||||||||||||||||||||||||
| Defaults to None | ||||||||||||||||||||||||
| att_context_style (str): 'regular' or 'chunked_limited'. | ||||||||||||||||||||||||
| att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training. | ||||||||||||||||||||||||
| It defines the range of Left, Middle, and Right context sizes for the attention mechanism. | ||||||||||||||||||||||||
| At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes. | ||||||||||||||||||||||||
| Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation | ||||||||||||||||||||||||
| att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'. | ||||||||||||||||||||||||
| Defaults to 'regular' | ||||||||||||||||||||||||
| xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`. | ||||||||||||||||||||||||
| Defaults to True. | ||||||||||||||||||||||||
|
|
@@ -126,6 +130,9 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): | |||||||||||||||||||||||
| `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means | ||||||||||||||||||||||||
| `[(conv_kernel_size-1), 0]`. | ||||||||||||||||||||||||
| Defaults to None. | ||||||||||||||||||||||||
| conv_context_style (str): 'regular' or 'dcc' | ||||||||||||||||||||||||
| DCC - Dynamic Chunked Convolution that is used for unified ASR training. | ||||||||||||||||||||||||
| Defaults to 'regular'. | ||||||||||||||||||||||||
| conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. | ||||||||||||||||||||||||
| When enables, the left half of the convolution kernel would get masked in streaming cases. | ||||||||||||||||||||||||
| Defaults to False. | ||||||||||||||||||||||||
|
|
@@ -305,13 +312,15 @@ def __init__( | |||||||||||||||||||||||
| n_heads=4, | ||||||||||||||||||||||||
| att_context_size=None, | ||||||||||||||||||||||||
| att_context_probs=None, | ||||||||||||||||||||||||
| att_chunk_context_size=None, | ||||||||||||||||||||||||
| att_context_style='regular', | ||||||||||||||||||||||||
| xscaling=True, | ||||||||||||||||||||||||
| untie_biases=True, | ||||||||||||||||||||||||
| pos_emb_max_len=5000, | ||||||||||||||||||||||||
| conv_kernel_size=31, | ||||||||||||||||||||||||
| conv_norm_type='batch_norm', | ||||||||||||||||||||||||
| conv_context_size=None, | ||||||||||||||||||||||||
| conv_context_style='regular', | ||||||||||||||||||||||||
| use_bias=True, | ||||||||||||||||||||||||
| dropout=0.1, | ||||||||||||||||||||||||
| dropout_pre_encoder=0.1, | ||||||||||||||||||||||||
|
|
@@ -346,6 +355,22 @@ def __init__( | |||||||||||||||||||||||
| self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends | ||||||||||||||||||||||||
| self.sync_max_audio_length = sync_max_audio_length | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| assert conv_context_style in ["regular", "dcc"], f"Invalid conv_context_style: {conv_context_style}!" | ||||||||||||||||||||||||
| self.conv_context_style = conv_context_style | ||||||||||||||||||||||||
| self.conv_kernel_size = conv_kernel_size | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Setting up the att_chunk_context_size | ||||||||||||||||||||||||
| if att_chunk_context_size is not None: | ||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||
| att_context_style == "chunked_limited_with_rc" | ||||||||||||||||||||||||
| ), "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!" | ||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||
| len(att_chunk_context_size) == 3 | ||||||||||||||||||||||||
| ), "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]" | ||||||||||||||||||||||||
| self.att_chunk_context_size = att_chunk_context_size | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| self.att_chunk_context_size = None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Setting up the att_context_size | ||||||||||||||||||||||||
| ( | ||||||||||||||||||||||||
| self.att_context_size_all, | ||||||||||||||||||||||||
|
|
@@ -716,7 +741,6 @@ def forward_internal( | |||||||||||||||||||||||
| offset=offset, | ||||||||||||||||||||||||
| device=audio_signal.device, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # saving tensors if required for interctc loss | ||||||||||||||||||||||||
| if self.is_access_enabled(getattr(self, "model_guid", None)): | ||||||||||||||||||||||||
| if self.interctc_capture_at_layers is None: | ||||||||||||||||||||||||
|
|
@@ -816,6 +840,33 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs | |||||||||||||||||||||||
| torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) | ||||||||||||||||||||||||
| elif self.att_context_style == "chunked_limited_with_rc" and sum(att_context_size) != -3: | ||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||
| len(att_context_size) == 3 | ||||||||||||||||||||||||
| ), "att_context_size must have 3 elements: [left_context, chunk_size, right_context]" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| left_context_frames = att_context_size[0] | ||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please, separate to a different function if possible |
||||||||||||||||||||||||
| chunk_size_frames = att_context_size[1] | ||||||||||||||||||||||||
| right_context_frames = att_context_size[2] | ||||||||||||||||||||||||
| assert chunk_size_frames >= 1, "chunk_size_frames must be greater than 0!" | ||||||||||||||||||||||||
| # Calculate chunk index for each frame (which processing group it belongs to) | ||||||||||||||||||||||||
| frame_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) | ||||||||||||||||||||||||
| chunk_idx = torch.div(frame_idx, chunk_size_frames, rounding_mode="trunc") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
Comment on lines
+848
to
+855
|
||||||||||||||||||||||||
| window_start = chunk_idx * chunk_size_frames - left_context_frames | ||||||||||||||||||||||||
| window_start = torch.maximum(window_start, torch.zeros_like(window_start)) | ||||||||||||||||||||||||
| window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1)) | ||||||||||||||||||||||||
| # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i] | ||||||||||||||||||||||||
| j_indices = frame_idx.unsqueeze(0) # [1, T] | ||||||||||||||||||||||||
| window_start_expanded = window_start.unsqueeze(1) # [T, 1] | ||||||||||||||||||||||||
| window_end_expanded = window_end.unsqueeze(1) # [T, 1] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| chunked_limited_mask = torch.logical_and( | ||||||||||||||||||||||||
| j_indices >= window_start_expanded, j_indices <= window_end_expanded | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) | ||||||||||||||||||||||||
|
Comment on lines
+843
to
+869
|
||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| att_mask = None | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
@@ -876,6 +927,9 @@ def _calc_context_sizes( | |||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
| att_context_size_all = [[-1, -1]] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| if att_context_style == "chunked_limited_with_rc": | ||||||||||||||||||||||||
| att_context_size_all = [[-1, -1, -1]] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
| att_context_size_all = [[-1, -1, -1]] | |
| if not att_context_size: | |
| # Default to unlimited left, right, and relative context when none is provided. | |
| att_context_size_all = [[-1, -1, -1]] | |
| else: | |
| # Validate that each entry has exactly 3 elements for chunked_limited_with_rc style. | |
| for i, att_cs in enumerate(att_context_size_all): | |
| if len(att_cs) != 3: | |
| raise ValueError( | |
| f"att_context_size[{i}] must have exactly 3 elements for chunked_limited_with_rc style." | |
| ) |
Copilot
AI
Mar 25, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For att_context_style == 'chunked_limited_with_rc', _calc_context_sizes() currently overwrites any provided att_context_size with [[-1, -1, -1]], which makes the encoder ignore configured context sizes and causes set_default_att_context_size() to always warn that the provided size is unsupported. Consider parsing/validating the provided att_context_size similarly to the other styles, and only defaulting to [-1, -1, -1] when no value was provided.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add this also to
nemo/collections/asr/inference?