From eef5587d6b8cebcb557c7a58e4dbdadad1347b44 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 19 Mar 2026 05:47:13 -0700 Subject: [PATCH 1/5] add streaming inference support for unified rnnt model Signed-off-by: andrusenkoau --- .../speech_to_text_streaming_infer_rnnt.py | 8 ++ nemo/collections/asr/losses/rnnt.py | 10 +++ .../asr/modules/conformer_encoder.py | 77 ++++++++++++++++++- 3 files changed, 94 insertions(+), 1 deletion(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 298a3d5b4b55..035b84775b73 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -133,6 +133,8 @@ class TranscriptionConfig: ) right_context_secs: float = 2 # right context + att_context_size_as_chunk: bool = True # whether to use the att_context_size as chunk size (importand for extra-low latency) + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA # device anyway, and do inference on CPU only if CUDA device is not found. # If `cuda` is a negative number, inference will be on CPU only. @@ -298,6 +300,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples, ) + # unified ASR model: use the att_context_size as chunk size (important for extra-low latency) + if asr_model.cfg.encoder.att_context_style == 'chunked_limited_with_rc' and cfg.att_context_size_as_chunk: + asr_model.encoder.set_default_att_context_size( + att_context_size=[context_encoder_frames.left,context_encoder_frames.chunk,context_encoder_frames.right] + ) + logging.info( "Corrected contexts (sec): " f"Left {context_samples.left / audio_sample_rate:.2f}, " diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 894be6319c99..5eb910424342 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -153,6 +153,14 @@ class RNNTLossConfig: is_available=True, installation_msg="Pure Pytorch implementation of TDT loss. Slow and for debugging purposes only.", ), + "rnnt_triton": RNNTLossConfig( + loss_name="rnnt_triton", # will be added later + lib_name="torch", + min_version='0.0', + is_available=True, + installation_msg="Triton RNN-T loss", + force_float32=False, + ), } RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt_numba'] @@ -322,6 +330,8 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) elif loss_name == "graph_w_transducer": loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphWTransducerLoss.__init__, ignore_params={"blank"}) loss_func = GraphWTransducerLoss(blank=blank_idx, **loss_kwargs) + elif loss_name == "rnnt_triton": + loss_func = None # will be added later else: raise ValueError( f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}" diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 0e25ea7767f4..15c52d3b459a 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -111,6 +111,10 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): att_context_probs (List[float]): a list of probabilities of each one of the att_context_size when a list of them is passed. If not specified, uniform distribution is being used. Defaults to None + att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training. + It defined the range of Left, Middle, and Right context sizes for the attention mechanism. + At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes. + Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation att_context_style (str): 'regular' or 'chunked_limited'. Defaults to 'regular' xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`. @@ -126,6 +130,12 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means `[(conv_kernel_size-1), 0]`. Defaults to None. + conv_context_style (str): 'regular' or 'dcc' + DCC - Dynamic Chunked Convolution that is used for unified ASR training. + Defaults to 'regular'. + att_zero_rc_weight (float): the weight of the right context in the attention mechanism during unified ASR training. + Only relevant if conv_context_style is 'dcc_rc'. + Defaults to None. conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. When enables, the left half of the convolution kernel would get masked in streaming cases. Defaults to False. @@ -305,13 +315,18 @@ def __init__( n_heads=4, att_context_size=None, att_context_probs=None, + att_chunk_context_size=None, att_context_style='regular', + att_zero_rc_weight=None, + dual_mode_training=False, + unified_asr_prob=None, xscaling=True, untie_biases=True, pos_emb_max_len=5000, conv_kernel_size=31, conv_norm_type='batch_norm', conv_context_size=None, + conv_context_style='regular', use_bias=True, dropout=0.1, dropout_pre_encoder=0.1, @@ -346,6 +361,30 @@ def __init__( self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends self.sync_max_audio_length = sync_max_audio_length + + assert conv_context_style in ["regular", "dcc", "dcc_rc"], f"Invalid conv_context_style: {conv_context_style}!" + self.conv_context_style = conv_context_style + self.conv_kernel_size = conv_kernel_size + + # Setting up the att_chunk_context_size + self.dual_mode_training = dual_mode_training + self.unified_asr_prob = unified_asr_prob + if att_chunk_context_size is not None: + assert att_context_style == "chunked_limited_with_rc", "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!" + assert len(att_chunk_context_size) == 3, "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]" + self.att_chunk_context_size = att_chunk_context_size + else: + self.att_chunk_context_size = None + + # setting up att_rc_weigts: + if att_zero_rc_weight is not None and self.att_chunk_context_size[2][0] == 0: + assert 0 <= att_zero_rc_weight <= 1, "att_zero_rc_weight must be between 0 and 1!" + non_zero_rc_weight = (1 - att_zero_rc_weight) / len(self.att_chunk_context_size[2][1:]) + self.att_rc_weights = [non_zero_rc_weight] * len(self.att_chunk_context_size[2]) + self.att_rc_weights[0] = att_zero_rc_weight + else: + self.att_rc_weights = None + # Setting up the att_context_size ( self.att_context_size_all, @@ -716,7 +755,6 @@ def forward_internal( offset=offset, device=audio_signal.device, ) - # saving tensors if required for interctc loss if self.is_access_enabled(getattr(self, "model_guid", None)): if self.interctc_capture_at_layers is None: @@ -816,6 +854,37 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) ) att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) + elif self.att_context_style == "chunked_limited_with_rc" and sum(att_context_size) != -3: + assert len(att_context_size) == 3, "att_context_size must have 3 elements: [left_context, chunk_size, right_context]" + + left_context_frames = att_context_size[0] + chunk_size_frames = att_context_size[1] + right_context_frames = att_context_size[2] + # Calculate chunk index for each frame (which processing group it belongs to) + frame_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) + chunk_idx = torch.div(frame_idx, chunk_size_frames, rounding_mode="trunc") + + window_start = chunk_idx * chunk_size_frames - left_context_frames + window_start = torch.maximum(window_start, torch.zeros_like(window_start)) + window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames + + if self.training and self.skip_att_chunk_rc_prob > 0.0: + chunks_num = max_audio_length // chunk_size_frames + for chunk_step in range(chunks_num): + if random.random() < self.skip_att_chunk_rc_prob: + window_end[chunk_step*chunk_size_frames:chunk_step*chunk_size_frames+chunk_size_frames] -= right_context_frames + + window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1)) + # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i] + j_indices = frame_idx.unsqueeze(0) # [1, T] + window_start_expanded = window_start.unsqueeze(1) # [T, 1] + window_end_expanded = window_end.unsqueeze(1) # [T, 1] + + chunked_limited_mask = torch.logical_and( + j_indices >= window_start_expanded, + j_indices <= window_end_expanded + ) + att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) else: att_mask = None @@ -876,6 +945,9 @@ def _calc_context_sizes( else: att_context_size_all = [[-1, -1]] + if att_context_style == "chunked_limited_with_rc": + att_context_size_all = [[-1, -1, -1]] + if att_context_probs: if len(att_context_probs) != len(att_context_size_all): raise ValueError("The size of the att_context_probs should be the same as att_context_size.") @@ -955,6 +1027,9 @@ def setup_streaming_params( elif self.att_context_style == "chunked_limited": lookahead_steps = att_context_size[1] streaming_cfg.cache_drop_size = 0 + elif self.att_context_style == "chunked_limited_with_rc": + lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers + streaming_cfg.cache_drop_size = 0 elif self.att_context_style == "regular": lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers streaming_cfg.cache_drop_size = lookahead_steps From a032051b0cdc1c91b827fc5c0a674f2e5f8d2095 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 19 Mar 2026 12:49:39 +0000 Subject: [PATCH 2/5] Apply isort and black reformatting Signed-off-by: andrusenkoau --- .../speech_to_text_streaming_infer_rnnt.py | 6 ++-- nemo/collections/asr/losses/rnnt.py | 9 +++--- .../asr/modules/conformer_encoder.py | 28 +++++++++++-------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 035b84775b73..894e9dd77ad2 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -133,7 +133,9 @@ class TranscriptionConfig: ) right_context_secs: float = 2 # right context - att_context_size_as_chunk: bool = True # whether to use the att_context_size as chunk size (importand for extra-low latency) + att_context_size_as_chunk: bool = ( + True # whether to use the att_context_size as chunk size (importand for extra-low latency) + ) # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA # device anyway, and do inference on CPU only if CUDA device is not found. @@ -303,7 +305,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: # unified ASR model: use the att_context_size as chunk size (important for extra-low latency) if asr_model.cfg.encoder.att_context_style == 'chunked_limited_with_rc' and cfg.att_context_size_as_chunk: asr_model.encoder.set_default_att_context_size( - att_context_size=[context_encoder_frames.left,context_encoder_frames.chunk,context_encoder_frames.right] + att_context_size=[context_encoder_frames.left, context_encoder_frames.chunk, context_encoder_frames.right] ) logging.info( diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 5eb910424342..0777fdf3a90e 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -154,7 +154,7 @@ class RNNTLossConfig: installation_msg="Pure Pytorch implementation of TDT loss. Slow and for debugging purposes only.", ), "rnnt_triton": RNNTLossConfig( - loss_name="rnnt_triton", # will be added later + loss_name="rnnt_triton", # will be added later lib_name="torch", min_version='0.0', is_available=True, @@ -331,7 +331,7 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphWTransducerLoss.__init__, ignore_params={"blank"}) loss_func = GraphWTransducerLoss(blank=blank_idx, **loss_kwargs) elif loss_name == "rnnt_triton": - loss_func = None # will be added later + loss_func = None # will be added later else: raise ValueError( f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}" @@ -343,8 +343,7 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) class RNNTLoss(Loss): @property def input_types(self): - """Input types definitions for CTCLoss. - """ + """Input types definitions for CTCLoss.""" return { "log_probs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "targets": NeuralType(('B', 'T'), LabelsType()), @@ -405,7 +404,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str = standard blank, and the standard blank is the last symbol in the vocab) TDT: num_classes = V. Note, V here does not include any of the "duration outputs". - reduction: Type of reduction to perform on loss. Possible values are + reduction: Type of reduction to perform on loss. Possible values are `mean_batch`, 'mean_volume`, `mean`, `sum` or None. `None` will return a torch vector comprising the individual loss values of the batch. `mean_batch` will average the losses in the batch diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 15c52d3b459a..69dfa309578b 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -361,7 +361,6 @@ def __init__( self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends self.sync_max_audio_length = sync_max_audio_length - assert conv_context_style in ["regular", "dcc", "dcc_rc"], f"Invalid conv_context_style: {conv_context_style}!" self.conv_context_style = conv_context_style self.conv_kernel_size = conv_kernel_size @@ -370,8 +369,12 @@ def __init__( self.dual_mode_training = dual_mode_training self.unified_asr_prob = unified_asr_prob if att_chunk_context_size is not None: - assert att_context_style == "chunked_limited_with_rc", "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!" - assert len(att_chunk_context_size) == 3, "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]" + assert ( + att_context_style == "chunked_limited_with_rc" + ), "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!" + assert ( + len(att_chunk_context_size) == 3 + ), "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]" self.att_chunk_context_size = att_chunk_context_size else: self.att_chunk_context_size = None @@ -855,8 +858,10 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs ) att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) elif self.att_context_style == "chunked_limited_with_rc" and sum(att_context_size) != -3: - assert len(att_context_size) == 3, "att_context_size must have 3 elements: [left_context, chunk_size, right_context]" - + assert ( + len(att_context_size) == 3 + ), "att_context_size must have 3 elements: [left_context, chunk_size, right_context]" + left_context_frames = att_context_size[0] chunk_size_frames = att_context_size[1] right_context_frames = att_context_size[2] @@ -867,22 +872,23 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs window_start = chunk_idx * chunk_size_frames - left_context_frames window_start = torch.maximum(window_start, torch.zeros_like(window_start)) window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames - + if self.training and self.skip_att_chunk_rc_prob > 0.0: chunks_num = max_audio_length // chunk_size_frames for chunk_step in range(chunks_num): if random.random() < self.skip_att_chunk_rc_prob: - window_end[chunk_step*chunk_size_frames:chunk_step*chunk_size_frames+chunk_size_frames] -= right_context_frames - + window_end[ + chunk_step * chunk_size_frames : chunk_step * chunk_size_frames + chunk_size_frames + ] -= right_context_frames + window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1)) # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i] j_indices = frame_idx.unsqueeze(0) # [1, T] window_start_expanded = window_start.unsqueeze(1) # [T, 1] - window_end_expanded = window_end.unsqueeze(1) # [T, 1] + window_end_expanded = window_end.unsqueeze(1) # [T, 1] chunked_limited_mask = torch.logical_and( - j_indices >= window_start_expanded, - j_indices <= window_end_expanded + j_indices >= window_start_expanded, j_indices <= window_end_expanded ) att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) else: From 7d8ce9f51f1ef0b1c111692dd3f6576699da5b30 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Tue, 24 Mar 2026 23:56:42 -0700 Subject: [PATCH 3/5] minor fixes Signed-off-by: andrusenkoau --- nemo/collections/asr/losses/rnnt.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 0777fdf3a90e..6c9488e77c21 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -153,14 +153,6 @@ class RNNTLossConfig: is_available=True, installation_msg="Pure Pytorch implementation of TDT loss. Slow and for debugging purposes only.", ), - "rnnt_triton": RNNTLossConfig( - loss_name="rnnt_triton", # will be added later - lib_name="torch", - min_version='0.0', - is_available=True, - installation_msg="Triton RNN-T loss", - force_float32=False, - ), } RNNT_LOSS_RESOLVER['default'] = RNNT_LOSS_RESOLVER['warprnnt_numba'] @@ -330,8 +322,6 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) elif loss_name == "graph_w_transducer": loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphWTransducerLoss.__init__, ignore_params={"blank"}) loss_func = GraphWTransducerLoss(blank=blank_idx, **loss_kwargs) - elif loss_name == "rnnt_triton": - loss_func = None # will be added later else: raise ValueError( f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}" @@ -472,9 +462,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): self._fp16_compat_checked = True # Upcast the activation tensor and compute loss and grads in fp32 - logits_orig = log_probs log_probs = log_probs.float() - del logits_orig # save memory *before* computing the loss # Ensure that shape mismatch does not occur due to padding # Due to padding and subsequent downsampling, it may be possible that From 714441d45a8be2584637139417a9f4e6074961c8 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 25 Mar 2026 02:57:31 -0700 Subject: [PATCH 4/5] fixes Signed-off-by: andrusenkoau --- .../rnnt/speech_to_text_streaming_infer_rnnt.py | 2 +- .../collections/asr/modules/conformer_encoder.py | 16 ++++------------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 894e9dd77ad2..1fda91b4ad70 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -134,7 +134,7 @@ class TranscriptionConfig: right_context_secs: float = 2 # right context att_context_size_as_chunk: bool = ( - True # whether to use the att_context_size as chunk size (importand for extra-low latency) + True # whether to use the att_context_size as chunk size (important for extra-low latency) ) # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 69dfa309578b..1f277ca44fd1 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -115,7 +115,7 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): It defined the range of Left, Middle, and Right context sizes for the attention mechanism. At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes. Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation - att_context_style (str): 'regular' or 'chunked_limited'. + att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'. Defaults to 'regular' xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`. Defaults to True. @@ -130,11 +130,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means `[(conv_kernel_size-1), 0]`. Defaults to None. - conv_context_style (str): 'regular' or 'dcc' + conv_context_style (str): 'regular', 'dcc', or 'dcc_rc' DCC - Dynamic Chunked Convolution that is used for unified ASR training. Defaults to 'regular'. att_zero_rc_weight (float): the weight of the right context in the attention mechanism during unified ASR training. - Only relevant if conv_context_style is 'dcc_rc'. + Only relevant if att_context_style is 'chunked_limited_with_rc' Defaults to None. conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. When enables, the left half of the convolution kernel would get masked in streaming cases. @@ -379,7 +379,7 @@ def __init__( else: self.att_chunk_context_size = None - # setting up att_rc_weigts: + # setting up att_rc_weights: if att_zero_rc_weight is not None and self.att_chunk_context_size[2][0] == 0: assert 0 <= att_zero_rc_weight <= 1, "att_zero_rc_weight must be between 0 and 1!" non_zero_rc_weight = (1 - att_zero_rc_weight) / len(self.att_chunk_context_size[2][1:]) @@ -873,14 +873,6 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs window_start = torch.maximum(window_start, torch.zeros_like(window_start)) window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames - if self.training and self.skip_att_chunk_rc_prob > 0.0: - chunks_num = max_audio_length // chunk_size_frames - for chunk_step in range(chunks_num): - if random.random() < self.skip_att_chunk_rc_prob: - window_end[ - chunk_step * chunk_size_frames : chunk_step * chunk_size_frames + chunk_size_frames - ] -= right_context_frames - window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1)) # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i] j_indices = frame_idx.unsqueeze(0) # [1, T] From b1de4227d9893450453d1745424720e4ba34db5a Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Mon, 30 Mar 2026 20:45:07 -0700 Subject: [PATCH 5/5] minor fixes Signed-off-by: andrusenkoau --- nemo/collections/asr/losses/rnnt.py | 4 +-- .../asr/modules/conformer_encoder.py | 26 ++++--------------- 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 6c9488e77c21..7c99eb3d86f2 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -333,7 +333,7 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) class RNNTLoss(Loss): @property def input_types(self): - """Input types definitions for CTCLoss.""" + """Input types definitions for RNNTLoss.""" return { "log_probs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "targets": NeuralType(('B', 'T'), LabelsType()), @@ -343,7 +343,7 @@ def input_types(self): @property def output_types(self): - """Output types definitions for CTCLoss. + """Output types definitions for RNNTLoss. loss: NeuralType(None) """ diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 1f277ca44fd1..2c3e355dac83 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -112,7 +112,7 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): when a list of them is passed. If not specified, uniform distribution is being used. Defaults to None att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training. - It defined the range of Left, Middle, and Right context sizes for the attention mechanism. + It defines the range of Left, Middle, and Right context sizes for the attention mechanism. At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes. Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'. @@ -130,12 +130,9 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means `[(conv_kernel_size-1), 0]`. Defaults to None. - conv_context_style (str): 'regular', 'dcc', or 'dcc_rc' + conv_context_style (str): 'regular' or 'dcc' DCC - Dynamic Chunked Convolution that is used for unified ASR training. Defaults to 'regular'. - att_zero_rc_weight (float): the weight of the right context in the attention mechanism during unified ASR training. - Only relevant if att_context_style is 'chunked_limited_with_rc' - Defaults to None. conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. When enables, the left half of the convolution kernel would get masked in streaming cases. Defaults to False. @@ -317,9 +314,6 @@ def __init__( att_context_probs=None, att_chunk_context_size=None, att_context_style='regular', - att_zero_rc_weight=None, - dual_mode_training=False, - unified_asr_prob=None, xscaling=True, untie_biases=True, pos_emb_max_len=5000, @@ -361,13 +355,11 @@ def __init__( self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends self.sync_max_audio_length = sync_max_audio_length - assert conv_context_style in ["regular", "dcc", "dcc_rc"], f"Invalid conv_context_style: {conv_context_style}!" + assert conv_context_style in ["regular", "dcc"], f"Invalid conv_context_style: {conv_context_style}!" self.conv_context_style = conv_context_style self.conv_kernel_size = conv_kernel_size # Setting up the att_chunk_context_size - self.dual_mode_training = dual_mode_training - self.unified_asr_prob = unified_asr_prob if att_chunk_context_size is not None: assert ( att_context_style == "chunked_limited_with_rc" @@ -379,15 +371,6 @@ def __init__( else: self.att_chunk_context_size = None - # setting up att_rc_weights: - if att_zero_rc_weight is not None and self.att_chunk_context_size[2][0] == 0: - assert 0 <= att_zero_rc_weight <= 1, "att_zero_rc_weight must be between 0 and 1!" - non_zero_rc_weight = (1 - att_zero_rc_weight) / len(self.att_chunk_context_size[2][1:]) - self.att_rc_weights = [non_zero_rc_weight] * len(self.att_chunk_context_size[2]) - self.att_rc_weights[0] = att_zero_rc_weight - else: - self.att_rc_weights = None - # Setting up the att_context_size ( self.att_context_size_all, @@ -865,6 +848,7 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs left_context_frames = att_context_size[0] chunk_size_frames = att_context_size[1] right_context_frames = att_context_size[2] + assert chunk_size_frames >= 1, "chunk_size_frames must be greater than 0!" # Calculate chunk index for each frame (which processing group it belongs to) frame_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) chunk_idx = torch.div(frame_idx, chunk_size_frames, rounding_mode="trunc") @@ -1026,7 +1010,7 @@ def setup_streaming_params( lookahead_steps = att_context_size[1] streaming_cfg.cache_drop_size = 0 elif self.att_context_style == "chunked_limited_with_rc": - lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers + lookahead_steps = att_context_size[2] * self.n_layers + self.conv_context_size[1] * self.n_layers streaming_cfg.cache_drop_size = 0 elif self.att_context_style == "regular": lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers