Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class TranscriptionConfig:
)
right_context_secs: float = 2 # right context

att_context_size_as_chunk: bool = (
True # whether to use the att_context_size as chunk size (important for extra-low latency)
)

# Set `cuda` to int to define CUDA device. If 'None', will look for CUDA
# device anyway, and do inference on CPU only if CUDA device is not found.
# If `cuda` is a negative number, inference will be on CPU only.
Expand Down Expand Up @@ -298,6 +302,12 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
right=context_encoder_frames.right * encoder_subsampling_factor * features_frame2audio_samples,
)

# unified ASR model: use the att_context_size as chunk size (important for extra-low latency)
if asr_model.cfg.encoder.att_context_style == 'chunked_limited_with_rc' and cfg.att_context_size_as_chunk:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this also to nemo/collections/asr/inference?

asr_model.encoder.set_default_att_context_size(
att_context_size=[context_encoder_frames.left, context_encoder_frames.chunk, context_encoder_frames.right]
)

logging.info(
"Corrected contexts (sec): "
f"Left {context_samples.left / audio_sample_rate:.2f}, "
Expand Down
9 changes: 3 additions & 6 deletions nemo/collections/asr/losses/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,7 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
class RNNTLoss(Loss):
@property
def input_types(self):
"""Input types definitions for CTCLoss.
"""
"""Input types definitions for RNNTLoss."""
return {
"log_probs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
"targets": NeuralType(('B', 'T'), LabelsType()),
Expand All @@ -344,7 +343,7 @@ def input_types(self):

@property
def output_types(self):
"""Output types definitions for CTCLoss.
"""Output types definitions for RNNTLoss.
loss:
NeuralType(None)
"""
Expand Down Expand Up @@ -395,7 +394,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
standard blank, and the standard blank is the last symbol in the vocab)
TDT: num_classes = V. Note, V here does not include any of the "duration outputs".

reduction: Type of reduction to perform on loss. Possible values are
reduction: Type of reduction to perform on loss. Possible values are
`mean_batch`, `mean_volume`, `mean`, `sum` or None.
`None` will return a torch vector comprising the individual loss values of the batch.
`mean_batch` will average the losses in the batch
Expand Down Expand Up @@ -463,9 +462,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
self._fp16_compat_checked = True

# Upcast the activation tensor and compute loss and grads in fp32
logits_orig = log_probs
log_probs = log_probs.float()
del logits_orig # save memory *before* computing the loss

# Ensure that shape mismatch does not occur due to padding
# Due to padding and subsequent downsampling, it may be possible that
Expand Down
61 changes: 59 additions & 2 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
att_context_probs (List[float]): a list of probabilities of each one of the att_context_size
when a list of them is passed. If not specified, uniform distribution is being used.
Defaults to None
att_context_style (str): 'regular' or 'chunked_limited'.
att_chunk_context_size (List[List[int]]): specifies the context sizes for unified (offline/streaming) ASR training.
It defines the range of Left, Middle, and Right context sizes for the attention mechanism.
At each streaming step, the context size is sampled from the range of Left, Middle, and Right context sizes.
Example: att_chunk_context_size=[[70],[1,2,7,13],[0,1,3,7,13]] -> sampling -> [70, 2, 3] -> attention mask generation
att_context_style (str): 'regular', 'chunked_limited', or 'chunked_limited_with_rc'.
Defaults to 'regular'
xscaling (bool): enables scaling the inputs to the multi-headed attention layers by `sqrt(d_model)`.
Defaults to True.
Expand All @@ -126,6 +130,9 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
`None` means `[(conv_kernel_size-1)//2`, `(conv_kernel_size-1)//2]`, and 'causal' means
`[(conv_kernel_size-1), 0]`.
Defaults to None.
conv_context_style (str): 'regular' or 'dcc'
DCC - Dynamic Chunked Convolution that is used for unified ASR training.
Defaults to 'regular'.
conv_dual_mode (bool): specifies if convolution should be dual mode when dual_offline mode is being used.
When enabled, the left half of the convolution kernel is masked in streaming cases.
Defaults to False.
Expand Down Expand Up @@ -305,13 +312,15 @@ def __init__(
n_heads=4,
att_context_size=None,
att_context_probs=None,
att_chunk_context_size=None,
att_context_style='regular',
xscaling=True,
untie_biases=True,
pos_emb_max_len=5000,
conv_kernel_size=31,
conv_norm_type='batch_norm',
conv_context_size=None,
conv_context_style='regular',
use_bias=True,
dropout=0.1,
dropout_pre_encoder=0.1,
Expand Down Expand Up @@ -346,6 +355,22 @@ def __init__(
self.use_pytorch_sdpa_backends = use_pytorch_sdpa_backends
self.sync_max_audio_length = sync_max_audio_length

assert conv_context_style in ["regular", "dcc"], f"Invalid conv_context_style: {conv_context_style}!"
self.conv_context_style = conv_context_style
self.conv_kernel_size = conv_kernel_size

# Setting up the att_chunk_context_size
if att_chunk_context_size is not None:
assert (
att_context_style == "chunked_limited_with_rc"
), "att_chunk_context_size is only supported for chunked_limited_with_rc attention style!"
assert (
len(att_chunk_context_size) == 3
), "att_chunk_context_size must have 3 elements: [left_context, chunk_size, right_context]"
self.att_chunk_context_size = att_chunk_context_size
else:
self.att_chunk_context_size = None

# Setting up the att_context_size
(
self.att_context_size_all,
Expand Down Expand Up @@ -716,7 +741,6 @@ def forward_internal(
offset=offset,
device=audio_signal.device,
)

# saving tensors if required for interctc loss
if self.is_access_enabled(getattr(self, "model_guid", None)):
if self.interctc_capture_at_layers is None:
Expand Down Expand Up @@ -816,6 +840,33 @@ def _create_masks(self, att_context_size, padding_length, max_audio_length, offs
torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0)
)
att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0))
elif self.att_context_style == "chunked_limited_with_rc" and sum(att_context_size) != -3:
assert (
len(att_context_size) == 3
), "att_context_size must have 3 elements: [left_context, chunk_size, right_context]"

left_context_frames = att_context_size[0]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, separate to a different function if possible

chunk_size_frames = att_context_size[1]
right_context_frames = att_context_size[2]
assert chunk_size_frames >= 1, "chunk_size_frames must be greater than 0!"
# Calculate chunk index for each frame (which processing group it belongs to)
frame_idx = torch.arange(0, max_audio_length, dtype=torch.int, device=att_mask.device)
chunk_idx = torch.div(frame_idx, chunk_size_frames, rounding_mode="trunc")

Comment on lines +848 to +855
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chunk_size_frames is used as a divisor (torch.div(frame_idx, chunk_size_frames, ...)) but there's no validation that it is >= 1. If a config or inference script produces a 0-frame chunk, this will crash with division-by-zero; add an explicit check and raise a clear error for invalid chunk sizes.

Copilot uses AI. Check for mistakes.
Comment on lines +848 to +855
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the chunked_limited_with_rc mask generation, chunk_size_frames is used as a divisor (torch.div(frame_idx, chunk_size_frames, ...)) and as a window size. If att_context_size[1] is 0 (or negative), this will raise at runtime; if left_context_frames/right_context_frames are negative, the window math becomes incorrect. Please add explicit validation (e.g., chunk >= 1, left/right >= 0) or implement the intended semantics for negative values.

Copilot uses AI. Check for mistakes.
window_start = chunk_idx * chunk_size_frames - left_context_frames
window_start = torch.maximum(window_start, torch.zeros_like(window_start))
window_end = chunk_idx * chunk_size_frames + chunk_size_frames - 1 + right_context_frames

window_end = torch.minimum(window_end, torch.full_like(window_end, max_audio_length - 1))
# Create the mask: frame i can see frame j if window_start[i] <= j <= window_end[i]
j_indices = frame_idx.unsqueeze(0) # [1, T]
window_start_expanded = window_start.unsqueeze(1) # [T, 1]
window_end_expanded = window_end.unsqueeze(1) # [T, 1]

chunked_limited_mask = torch.logical_and(
j_indices >= window_start_expanded, j_indices <= window_end_expanded
)
att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0))
Comment on lines +843 to +869
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New attention style chunked_limited_with_rc introduces non-trivial masking logic and affects streaming configuration, but there are existing unit tests for ConformerEncoder (e.g., tests/collections/asr/test_conformer_encoder.py) that don't cover this style. Please add tests that exercise _create_masks()/setup_streaming_params() for chunked_limited_with_rc (including at least one representative [left, chunk, right] setting) to prevent regressions.

Copilot uses AI. Check for mistakes.
else:
att_mask = None

Expand Down Expand Up @@ -876,6 +927,9 @@ def _calc_context_sizes(
else:
att_context_size_all = [[-1, -1]]

if att_context_style == "chunked_limited_with_rc":
att_context_size_all = [[-1, -1, -1]]
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For att_context_style == "chunked_limited_with_rc", _calc_context_sizes() overwrites any user-provided att_context_size with [[-1, -1, -1]], so config-specified context sizes are silently ignored and set_default_att_context_size() will warn even for valid sizes. Consider only defaulting to [-1, -1, -1] when att_context_size is not provided, and otherwise validate that each entry has 3 elements.

Suggested change
att_context_size_all = [[-1, -1, -1]]
if not att_context_size:
# Default to unlimited left, right, and relative context when none is provided.
att_context_size_all = [[-1, -1, -1]]
else:
# Validate that each entry has exactly 3 elements for chunked_limited_with_rc style.
for i, att_cs in enumerate(att_context_size_all):
if len(att_cs) != 3:
raise ValueError(
f"att_context_size[{i}] must have exactly 3 elements for chunked_limited_with_rc style."
)

Copilot uses AI. Check for mistakes.

Comment on lines +930 to +932
Copy link

Copilot AI Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For att_context_style == 'chunked_limited_with_rc', _calc_context_sizes() currently overwrites any provided att_context_size with [[-1, -1, -1]], which makes the encoder ignore configured context sizes and causes set_default_att_context_size() to always warn that the provided size is unsupported. Consider parsing/validating the provided att_context_size similarly to the other styles, and only defaulting to [-1, -1, -1] when no value was provided.

Copilot uses AI. Check for mistakes.
if att_context_probs:
if len(att_context_probs) != len(att_context_size_all):
raise ValueError("The size of the att_context_probs should be the same as att_context_size.")
Expand Down Expand Up @@ -955,6 +1009,9 @@ def setup_streaming_params(
elif self.att_context_style == "chunked_limited":
lookahead_steps = att_context_size[1]
streaming_cfg.cache_drop_size = 0
elif self.att_context_style == "chunked_limited_with_rc":
lookahead_steps = att_context_size[2] * self.n_layers + self.conv_context_size[1] * self.n_layers
streaming_cfg.cache_drop_size = 0
elif self.att_context_style == "regular":
lookahead_steps = att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers
streaming_cfg.cache_drop_size = lookahead_steps
Expand Down
Loading