diff --git a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py index 291705a54f1e..2c2f4c06cd05 100644 --- a/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/multispk_transcribe_utils.py @@ -1416,6 +1416,8 @@ def __init__( self._uppercase_first_letter = uppercase_first_letter self._speaker_wise_sentences = {} self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] + self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)] + self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)] self.seglsts = [] @@ -1425,6 +1427,8 @@ def _reset_speaker_wise_sentences(self): """ self._speaker_wise_sentences = {} self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] + self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)] + self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)] def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): """ @@ -1443,6 +1447,8 @@ def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] self.seglsts = [] self._speaker_wise_sentences = {} self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)] + self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)] + self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)] def update_asr_state( self, @@ -1515,7 +1521,10 @@ def _update_last_sentence(self, spk_idx: int, end_time: float, diff_text: str): diff_text (str): The difference text. """ if end_time is not None: - self._speaker_wise_sentences[spk_idx][-1]['end_time'] = end_time + current_start = self._speaker_wise_sentences[spk_idx][-1]['start_time'] + self._speaker_wise_sentences[spk_idx][-1]['end_time'] = max( + end_time, current_start + self._frame_len_sec + ) new_words = self._speaker_wise_sentences[spk_idx][-1]['words'] + diff_text self._speaker_wise_sentences[spk_idx][-1]['words'] = new_words.strip() @@ -1536,18 +1545,27 @@ def _is_new_text(self, spk_idx: int, text: str): else: return text.strip() - def _compute_hypothesis_timestamps(self, hypothesis: Hypothesis, offset: float) -> Tuple[float, float, bool]: + def _compute_hypothesis_timestamps( + self, + hypothesis: Hypothesis, + offset: float, + prev_token_count: int = 0, + decoded_length_before: int = None, + ) -> Tuple[float, float, bool]: """ Compute start and end timestamps for a hypothesis based on available timing information. This method calculates the temporal boundaries of a speech hypothesis, prioritizing - frame-level timestamps when available. When timestamps are not available, it falls - back to computing timing based on the hypothesis length. + frame-level timestamps and decoder state when available. When timestamps are not available, + it falls back to computing timing based on the hypothesis length. Args: - hypothesis (Hypothesis): The ASR hypothesis object containing either frame-level + hypothesis (Hypothesis): The ASR hypothesis object containing frame-level + timestamps and decoder state. offset (float): The time offset (in seconds) to add to the computed timestamps, typically representing the start time of the current audio chunk. + prev_token_count (int): The number of timestamp entries already processed for this speaker. + decoded_length_before (int): The decoded length before the current chunk. Returns: Tuple[float, float, bool]: A tuple containing: @@ -1561,12 +1579,15 @@ def _compute_hypothesis_timestamps(self, hypothesis: Hypothesis, offset: float) for the full duration of the final frame. """ sep_flag = False - if len(hypothesis.timestamp) > 0: - start_time = offset + (hypothesis.timestamp[0]) * self._frame_len_sec - end_time = offset + (hypothesis.timestamp[-1] + 1) * self._frame_len_sec + new_timestamp_count = len(hypothesis.timestamp) - prev_token_count + if hypothesis.dec_state is not None and new_timestamp_count > 0 and decoded_length_before is not None: + start_local = hypothesis.timestamp[prev_token_count].item() - decoded_length_before + end_local = hypothesis.timestamp[-1].item() - decoded_length_before + start_time = offset + start_local * self._frame_len_sec + end_time = offset + (end_local + 1) * self._frame_len_sec else: start_time = offset - end_time = offset + hypothesis.length.item() * self._frame_len_sec + end_time = offset + max(0, hypothesis.length.item() - prev_token_count) * self._frame_len_sec sep_flag = True return start_time, end_time, sep_flag @@ -1594,9 +1615,16 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float): if diff_text is not None: start_time, end_time, sep_flag = self._compute_hypothesis_timestamps( - hypothesis=hypothesis, offset=offset + hypothesis=hypothesis, + offset=offset, + prev_token_count=self._prev_token_counts[spk_idx], + decoded_length_before=self._prev_decoded_lengths[spk_idx], ) + # Update the stored decoded_length for this speaker + if hypothesis.dec_state is not None: + self._prev_decoded_lengths[spk_idx] = hypothesis.dec_state.decoded_length.item() + # Get the last end time of the previous sentence or None if no sentences are present if len(self._speaker_wise_sentences[spk_idx]) > 0: last_end_time = self._speaker_wise_sentences[spk_idx][-1]['end_time'] @@ -1628,6 +1656,7 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float): # Update the previous history of the speaker text if hypothesis.text is not None: self._prev_history_speaker_texts[spk_idx] = hypothesis.text + self._prev_token_counts[spk_idx] = len(hypothesis.timestamp) self.seglsts = []