diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 298a3d5b4b55..1fda91b4ad70 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -133,6 +133,10 @@ class TranscriptionConfig: ) right_context_secs: float = 2 # right context + att_context_size_as_chunk: bool = ( + True # whether to use the att_context_size as chunk size (important for extra-low latency) + ) + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA # device anyway, and do inference on CPU only if CUDA device is not found. # If `cuda` is a negative number, inference will be on CPU only. @@ -298,6 +302,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples, ) + # unified ASR model: use the att_context_size as chunk size (important for extra-low latency) + if asr_model.cfg.encoder.get('att_context_style') == 'chunked_limited_with_rc' and cfg.att_context_size_as_chunk: + asr_model.encoder.set_default_att_context_size( + att_context_size=[context_encoder_frames.left, context_encoder_frames.chunk, context_encoder_frames.right] + ) + logging.info( "Corrected contexts (sec): " f"Left {context_samples.left / audio_sample_rate:.2f}, " diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py index 894be6319c99..7c99eb3d86f2 100644 --- a/nemo/collections/asr/losses/rnnt.py +++ b/nemo/collections/asr/losses/rnnt.py @@ -333,8 +333,7 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None) class RNNTLoss(Loss): @property def input_types(self): - """Input types definitions for CTCLoss. 
- """ + """Input types definitions for RNNTLoss.""" return { "log_probs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), "targets": NeuralType(('B', 'T'), LabelsType()), @@ -344,7 +343,7 @@ def input_types(self): @property def output_types(self): - """Output types definitions for CTCLoss. + """Output types definitions for RNNTLoss. loss: NeuralType(None) """ @@ -395,7 +394,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str = standard blank, and the standard blank is the last symbol in the vocab) TDT: num_classes = V. Note, V here does not include any of the "duration outputs". - reduction: Type of reduction to perform on loss. Possible values are + reduction: Type of reduction to perform on loss. Possible values are `mean_batch`, 'mean_volume`, `mean`, `sum` or None. `None` will return a torch vector comprising the individual loss values of the batch. `mean_batch` will average the losses in the batch @@ -463,9 +462,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths): self._fp16_compat_checked = True # Upcast the activation tensor and compute loss and grads in fp32 - logits_orig = log_probs log_probs = log_probs.float() - del logits_orig # save memory *before* computing the loss # Ensure that shape mismatch does not occur due to padding # Due to padding and subsequent downsampling, it may be possible that diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index 0e25ea7767f4..2c3e355dac83 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -111,7 +111,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): att_context_probs (List[float]): a list of probabilities of each one of the att_context_size when a list of them is passed. If not specified, uniform distribution is being used. 
Defaults to None - att_context_style (str): 'regular' or 'chunked_limited'. + att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training. + It is a list of three lists holding the candidate Left, Middle (chunk), and Right context sizes for the attention mechanism. + At each streaming step, one Left, one Middle (chunk), and one Right context size is sampled from these candidates. + Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation. Defaults to None. + att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'. Defaults to 'regular' xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`. Defaults to True. @@ -126,6 +130,9 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): `None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means `[(conv_kernel_size-1), 0]`. Defaults to None. + conv_context_style (str): 'regular' or 'dcc' + DCC - Dynamic Chunked Convolution that is used for unified ASR training. + Defaults to 'regular'. conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used. When enables, the left half of the convolution kernel would get masked in streaming cases. Defaults to False. @@ -305,6 +312,7 @@ def __init__( n_heads=4, att_context_size=None, att_context_probs=None, + att_chunk_context_size=None, att_context_style='regular', xscaling=True, untie_biases=True, @@ -312,6 +320,7 @@ def __init__( conv_kernel_size=31, conv_norm_type='batch_norm', conv_context_size=None, + conv_context_style='regular', use_bias=True, dropout=0.1, dropout_pre_encoder=0.1, @@ -346,6 +355,22 @@ def __init__( self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends self.sync_max_audio_length = sync_max_audio_length + assert conv_context_style in ["regular", "dcc"], f"Invalid conv_context_style: {conv_context_style}!" 
+ self.conv_context_style = conv_context_style + self.conv_kernel_size = conv_kernel_size + + # Setting up the att_chunk_context_size + if att_chunk_context_size is not None: + assert ( + att_context_style == "chunked_limited_with_rc" + ), "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!" + assert ( + len(att_chunk_context_size) == 3 + ), "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]" + self.att_chunk_context_size = att_chunk_context_size + else: + self.att_chunk_context_size = None + # Setting up the att_context_size ( self.att_context_size_all, @@ -716,7 +741,6 @@ def forward_internal( offset=offset, device=audio_signal.device, ) - # saving tensors if required for interctc loss if self.is_access_enabled(getattr(self, "model_guid", None)): if self.interctc_capture_at_layers is None: @@ -816,6 +840,33 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) ) att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) + elif self.att_context_style == "chunked_limited_with_rc" and sum(att_context_size) != -3: + assert ( + len(att_context_size) == 3 + ), "att_context_size must have 3 elements: [left_context, chunk_size, right_context]" + + left_context_frames = att_context_size[0] + chunk_size_frames = att_context_size[1] + right_context_frames = att_context_size[2] + assert chunk_size_frames >= 1, "chunk_size_frames must be greater than 0!" 
+ # Calculate chunk index for each frame (which processing group it belongs to) + frame_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device) + chunk_idx = torch.div(frame_idx, chunk_size_frames, rounding_mode="trunc") + + window_start = chunk_idx * chunk_size_frames - left_context_frames + window_start = torch.maximum(window_start, torch.zeros_like(window_start)) + window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames + + window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1)) + # Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i] + j_indices = frame_idx.unsqueeze(0) # [1, T] + window_start_expanded = window_start.unsqueeze(1) # [T, 1] + window_end_expanded = window_end.unsqueeze(1) # [T, 1] + + chunked_limited_mask = torch.logical_and( + j_indices >= window_start_expanded, j_indices <= window_end_expanded + ) + att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) else: att_mask = None @@ -876,6 +927,9 @@ def _calc_context_sizes( else: att_context_size_all = [[-1, -1]] + if att_context_style == "chunked_limited_with_rc": + att_context_size_all = [[-1, -1, -1]] + if att_context_probs: if len(att_context_probs) != len(att_context_size_all): raise ValueError("The size of the att_context_probs should be the same as att_context_size.") @@ -955,6 +1009,9 @@ def setup_streaming_params( elif self.att_context_style == "chunked_limited": lookahead_steps = att_context_size[1] streaming_cfg.cache_drop_size = 0 + elif self.att_context_style == "chunked_limited_with_rc": + lookahead_steps = att_context_size[2] * self.n_layers + self.conv_context_size[1] * self.n_layers + streaming_cfg.cache_drop_size = 0 elif self.att_context_style == "regular": lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers streaming_cfg.cache_drop_size = lookahead_steps