From 547846df89310e84b3427f4fa04090fa3de57d99 Mon Sep 17 00:00:00 2001 From: Bolor Date: Fri, 6 Mar 2026 13:23:36 -0800 Subject: [PATCH 1/7] adding scorer without pydub --- pyrit/score/audio_transcript_scorer.py | 127 +++++++++++++------------ 1 file changed, 67 insertions(+), 60 deletions(-) diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index b0d0ad2a92..31b6290671 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -1,8 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from ast import If import logging import os +import shutil +import subprocess import tempfile import uuid from abc import ABC @@ -16,6 +19,24 @@ logger = logging.getLogger(__name__) +def _check_ffmpeg_installed() -> bool: + """ + Check if ffmpeg is installed and available on PATH. + FFmpeg is required for scoring audio content in videos + + Returns: + bool: True if ffmpeg is installed, False otherwise. + """ + + if shutil.which("ffmpeg") is None: + # raise RuntimeError( + # "ffmpeg is required for audio processing but was not found on PATH. " + # "Install it via: apt install ffmpeg / brew install ffmpeg / " + # "https://ffmpeg.org/download.html" + # ) + return False + return True + class AudioTranscriptHelper(ABC): # noqa: B024 """ Abstract base class for audio scorers that process audio by transcribing and scoring the text. @@ -29,7 +50,6 @@ class AudioTranscriptHelper(ABC): # noqa: B024 _DEFAULT_SAMPLE_RATE = 16000 # 16kHz - Azure Speech optimal rate _DEFAULT_CHANNELS = 1 # Mono - Azure Speech prefers mono _DEFAULT_SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample) - _DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"] # 16-bit PCM for best compatibility def __init__( self, @@ -173,23 +193,30 @@ def _ensure_wav_format(self, audio_path: str) -> str: str: Path to WAV file (original if already WAV, or converted temporary file). Raises: - ModuleNotFoundError: If pydub is not installed. + RuntimeError: If ffmpeg is not installed. """ - try: - from pydub import AudioSegment - except ModuleNotFoundError as e: - logger.error("Could not import pydub. Install it via 'pip install pydub'") - raise e - - audio = AudioSegment.from_file(audio_path) - audio = ( - audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE) - .set_channels(self._DEFAULT_CHANNELS) - .set_sample_width(self._DEFAULT_SAMPLE_WIDTH) - ) + + if not _check_ffmpeg_installed(): + raise RuntimeError( + "ffmpeg is required for audio processing but was not found on PATH. " + "Install it via: apt install ffmpeg / brew install ffmpeg / " + "https://ffmpeg.org/download.html" + ) + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: - audio.export(temp_wav.name, format="wav") - return temp_wav.name + output_path = temp_wav.name + subprocess.run( + [ + "ffmpeg", "-i", audio_path, + "-ar", str(self._DEFAULT_SAMPLE_RATE), + "-ac", str(self._DEFAULT_CHANNELS), + "-acodec", "pcm_s16le", # 16-bit PCM + output_path, "-y", + ], + check=True, + capture_output=True, + ) + return output_path def _extract_audio_from_video(self, video_path: str) -> Optional[str]: """ @@ -203,7 +230,7 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]: or returns None if extraction fails. Raises: - ModuleNotFoundError: If pydub/ffmpeg is not installed. + RuntimeError: If ffmpeg is not installed. """ return AudioTranscriptHelper.extract_audio_from_video(video_path) @@ -220,55 +247,35 @@ def extract_audio_from_video(video_path: str) -> Optional[str]: or returns None if extraction fails. Raises: - ModuleNotFoundError: If pydub/ffmpeg is not installed. + RuntimeError: If ffmpeg is not installed. """ - try: - from pydub import AudioSegment - except ModuleNotFoundError as e: - logger.error("Could not import pydub. Install it via 'pip install pydub'") - raise e + if not _check_ffmpeg_installed(): + raise RuntimeError( + "ffmpeg is required for audio processing but was not found on PATH. " + "Install it via: apt install ffmpeg / brew install ffmpeg / " + "https://ffmpeg.org/download.html" + ) try: - # Extract audio from video using pydub (requires ffmpeg) logger.info(f"Extracting audio from video: {video_path}") - audio = AudioSegment.from_file(video_path) + with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio: + output_path = temp_audio.name + subprocess.run( + [ + "ffmpeg", "-i", video_path, + "-ar", str(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE), + "-ac", str(AudioTranscriptHelper._DEFAULT_CHANNELS), + "-acodec", "pcm_s16le", # 16-bit PCM + output_path, "-y", + ], + check=True, + capture_output=True, + ) logger.info( - f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, " - f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}" + f"Audio exported to: {output_path} " + f"(rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)" ) - - # Optimize for Azure Speech recognition: - # Azure Speech works best with 16kHz mono audio (same as Azure TTS output) - if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE: - logger.info( - f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz" - ) - audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE) - - # Ensure 16-bit audio - if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH: - logger.info( - f"Converting sample width from {audio.sample_width * 8}-bit" - f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit" - ) - audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH) - - # Convert to mono (Azure Speech prefers mono) - if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS: - logger.info(f"Converting from {audio.channels} channels to mono") - audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS) - - # Create temporary WAV file with PCM encoding for best compatibility - with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio: - audio.export( - temp_audio.name, - format="wav", - parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS, - ) - logger.info( - f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)" - ) - return temp_audio.name + return output_path except Exception as e: logger.warning(f"Failed to extract audio from video {video_path}: {e}") return None From 9f4d569769fb47dc7ff877600cd0115c83656aa0 Mon Sep 17 00:00:00 2001 From: Bolor Date: Sat, 7 Mar 2026 09:41:11 -0800 Subject: [PATCH 2/7] getting rid of pydub --- pyrit/score/audio_transcript_scorer.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 31b6290671..cb97be1404 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from ast import If import logging import os import shutil @@ -22,20 +21,11 @@ def _check_ffmpeg_installed() -> bool: """ Check if ffmpeg is installed and available on PATH. - FFmpeg is required for scoring audio content in videos Returns: bool: True if ffmpeg is installed, False otherwise. """ - - if shutil.which("ffmpeg") is None: - # raise RuntimeError( - # "ffmpeg is required for audio processing but was not found on PATH. " - # "Install it via: apt install ffmpeg / brew install ffmpeg / " - # "https://ffmpeg.org/download.html" - # ) - return False - return True + return shutil.which("ffmpeg") is not None class AudioTranscriptHelper(ABC): # noqa: B024 """ From 4780255ec35d42831a87ce990bb184a0e49356d3 Mon Sep 17 00:00:00 2001 From: Bolor Date: Sat, 7 Mar 2026 11:20:44 -0800 Subject: [PATCH 3/7] precommit --- pyrit/score/audio_transcript_scorer.py | 41 ++++++++++++++++---------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index cb97be1404..8e2630d20b 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -27,6 +27,7 @@ def _check_ffmpeg_installed() -> bool: """ return shutil.which("ffmpeg") is not None + class AudioTranscriptHelper(ABC): # noqa: B024 """ Abstract base class for audio scorers that process audio by transcribing and scoring the text. @@ -185,23 +186,28 @@ def _ensure_wav_format(self, audio_path: str) -> str: Raises: RuntimeError: If ffmpeg is not installed. """ - if not _check_ffmpeg_installed(): raise RuntimeError( "ffmpeg is required for audio processing but was not found on PATH. " "Install it via: apt install ffmpeg / brew install ffmpeg / " "https://ffmpeg.org/download.html" ) - + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: output_path = temp_wav.name subprocess.run( [ - "ffmpeg", "-i", audio_path, - "-ar", str(self._DEFAULT_SAMPLE_RATE), - "-ac", str(self._DEFAULT_CHANNELS), - "-acodec", "pcm_s16le", # 16-bit PCM - output_path, "-y", + "ffmpeg", + "-i", + audio_path, + "-ar", + str(self._DEFAULT_SAMPLE_RATE), + "-ac", + str(self._DEFAULT_CHANNELS), + "-acodec", + "pcm_s16le", # 16-bit PCM + output_path, + "-y", ], check=True, capture_output=True, @@ -252,19 +258,22 @@ def extract_audio_from_video(video_path: str) -> Optional[str]: output_path = temp_audio.name subprocess.run( [ - "ffmpeg", "-i", video_path, - "-ar", str(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE), - "-ac", str(AudioTranscriptHelper._DEFAULT_CHANNELS), - "-acodec", "pcm_s16le", # 16-bit PCM - output_path, "-y", + "ffmpeg", + "-i", + video_path, + "-ar", + str(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE), + "-ac", + str(AudioTranscriptHelper._DEFAULT_CHANNELS), + "-acodec", + "pcm_s16le", # 16-bit PCM + output_path, + "-y", ], check=True, capture_output=True, ) - logger.info( - f"Audio exported to: {output_path} " - f"(rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)" - ) + logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)") return output_path except Exception as e: logger.warning(f"Failed to extract audio from video {video_path}: {e}") From f80b6ba998f476ee26915af793277715697dadbc Mon Sep 17 00:00:00 2001 From: Bolor Date: Wed, 11 Mar 2026 15:10:28 -0700 Subject: [PATCH 4/7] replace with pyav --- pyproject.toml | 1 + pyrit/score/audio_transcript_scorer.py | 142 +++++++++++++------------ 2 files changed, 76 insertions(+), 67 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed9ab048ed..5e3d4ddced 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ speech = [ # all includes all functional dependencies excluding the ones from the "dev" extra all = [ "accelerate>=1.7.0", + "av>=14.0.0", "azure-ai-ml>=1.27.1", "azure-cognitiveservices-speech>=1.44.0", "azureml-mlflow>=1.60.0", diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 8e2630d20b..5d28cfb911 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -3,13 +3,13 @@ import logging import os -import shutil -import subprocess import tempfile import uuid from abc import ABC from typing import Optional +import av + from pyrit.memory import CentralMemory from pyrit.models import MessagePiece, Score from pyrit.prompt_converter import AzureSpeechAudioToTextConverter @@ -18,14 +18,74 @@ logger = logging.getLogger(__name__) -def _check_ffmpeg_installed() -> bool: +def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bool: + """ + Check if the audio file is already a compliant WAV with the target format. + + Args: + input_path (str): Path to the audio file. + sample_rate (int): Expected sample rate in Hz. + channels (int): Expected number of channels. + + Returns: + bool: True if the file is already compliant, False otherwise. + """ + try: + with av.open(input_path) as container: + if not container.streams.audio: + return False + stream = container.streams.audio[0] + codec_name = stream.codec_context.name + is_pcm_s16 = codec_name == "pcm_s16le" + is_correct_rate = stream.rate == sample_rate + is_correct_channels = stream.channels == channels + return is_pcm_s16 and is_correct_rate and is_correct_channels + except Exception: + return False + + +def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str: """ - Check if ffmpeg is installed and available on PATH. + Convert any audio or video file to a normalised PCM WAV using PyAV. + + If the input is already a compliant WAV (correct sample rate, channels, and codec), + returns the original path without re-encoding. + + Args: + input_path (str): Source audio or video file. + sample_rate (int): Target sample rate in Hz. + channels (int): Target number of channels (1 = mono). Returns: - bool: True if ffmpeg is installed, False otherwise. + str: Path to the WAV file (original if compliant, otherwise a temporary file). """ - return shutil.which("ffmpeg") is not None + # Skip conversion if already compliant + if _is_compliant_wav(input_path, sample_rate=sample_rate, channels=channels): + logger.debug(f"Audio file already compliant, skipping conversion: {input_path}") + return input_path + + layout = "mono" if channels == 1 else "stereo" + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + output_path = tmp.name + + with av.open(input_path) as in_container: + with av.open(output_path, "w", format="wav") as out_container: + out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate, layout=layout) + resampler = av.AudioResampler(format="s16", layout=layout, rate=sample_rate) + + for frame in in_container.decode(audio=0): + for out_frame in resampler.resample(frame): + for packet in out_stream.encode(out_frame): + out_container.mux(packet) + + for out_frame in resampler.resample(None): + for packet in out_stream.encode(out_frame): + out_container.mux(packet) + + for packet in out_stream.encode(None): + out_container.mux(packet) + + return output_path class AudioTranscriptHelper(ABC): # noqa: B024 @@ -160,7 +220,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: logger.info(f"Audio transcription: WAV file size = {file_size} bytes") try: - converter = AzureSpeechAudioToTextConverter() + converter = AzureSpeechAudioToTextConverter(use_entra_auth=True) logger.info("Audio transcription: Starting Azure Speech transcription...") result = await converter.convert_async(prompt=wav_path, input_type="audio_path") logger.info(f"Audio transcription: Result = '{result.output_text}'") @@ -182,37 +242,12 @@ def _ensure_wav_format(self, audio_path: str) -> str: Returns: str: Path to WAV file (original if already WAV, or converted temporary file). - - Raises: - RuntimeError: If ffmpeg is not installed. """ - if not _check_ffmpeg_installed(): - raise RuntimeError( - "ffmpeg is required for audio processing but was not found on PATH. " - "Install it via: apt install ffmpeg / brew install ffmpeg / " - "https://ffmpeg.org/download.html" - ) - - with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav: - output_path = temp_wav.name - subprocess.run( - [ - "ffmpeg", - "-i", - audio_path, - "-ar", - str(self._DEFAULT_SAMPLE_RATE), - "-ac", - str(self._DEFAULT_CHANNELS), - "-acodec", - "pcm_s16le", # 16-bit PCM - output_path, - "-y", - ], - check=True, - capture_output=True, + return _audio_to_wav( + audio_path, + sample_rate=self._DEFAULT_SAMPLE_RATE, + channels=self._DEFAULT_CHANNELS, ) - return output_path def _extract_audio_from_video(self, video_path: str) -> Optional[str]: """ @@ -224,9 +259,6 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]: Returns: str: a path to the extracted audio file (WAV format) or returns None if extraction fails. - - Raises: - RuntimeError: If ffmpeg is not installed. """ return AudioTranscriptHelper.extract_audio_from_video(video_path) @@ -241,37 +273,13 @@ def extract_audio_from_video(video_path: str) -> Optional[str]: Returns: str: a path to the extracted audio file (WAV format) or returns None if extraction fails. - - Raises: - RuntimeError: If ffmpeg is not installed. """ - if not _check_ffmpeg_installed(): - raise RuntimeError( - "ffmpeg is required for audio processing but was not found on PATH. " - "Install it via: apt install ffmpeg / brew install ffmpeg / " - "https://ffmpeg.org/download.html" - ) - try: logger.info(f"Extracting audio from video: {video_path}") - with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio: - output_path = temp_audio.name - subprocess.run( - [ - "ffmpeg", - "-i", - video_path, - "-ar", - str(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE), - "-ac", - str(AudioTranscriptHelper._DEFAULT_CHANNELS), - "-acodec", - "pcm_s16le", # 16-bit PCM - output_path, - "-y", - ], - check=True, - capture_output=True, + output_path = _audio_to_wav( + video_path, + sample_rate=AudioTranscriptHelper._DEFAULT_SAMPLE_RATE, + channels=AudioTranscriptHelper._DEFAULT_CHANNELS, ) logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)") return output_path From eea930223f93d2e009781fb1cc1a5fd0bf33d50c Mon Sep 17 00:00:00 2001 From: Bolor Date: Thu, 12 Mar 2026 09:16:42 -0700 Subject: [PATCH 5/7] adding av to dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 5e3d4ddced..5139ff444c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "aiofiles>=24,<25", "appdirs>=1.4.0", "art>=6.5.0", + "av>=14.0.0", "azure-core>=1.38.0", "azure-identity>=1.19.0", "azure-ai-contentsafety>=1.0.0", From fc3db889fd06d6f3275a26bd933d9b2ea229a1c9 Mon Sep 17 00:00:00 2001 From: Bolor Date: Thu, 12 Mar 2026 11:29:08 -0700 Subject: [PATCH 6/7] adding unit tests --- tests/unit/score/test_audio_scorer.py | 95 +++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/tests/unit/score/test_audio_scorer.py b/tests/unit/score/test_audio_scorer.py index 543323aa3d..4fdd7ce485 100644 --- a/tests/unit/score/test_audio_scorer.py +++ b/tests/unit/score/test_audio_scorer.py @@ -227,3 +227,98 @@ async def test_score_piece_empty_transcript(self, audio_message_piece): # Empty transcript returns empty list assert len(scores) == 0 + + +class TestPyAVAudioConversion: + """Tests for PyAV audio conversion functions""" + + @pytest.fixture + def compliant_wav_file(self): + """Create a compliant 16kHz mono PCM WAV file""" + import av + import numpy as np + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + output_path = tmp.name + + sample_rate = 16000 + duration = 0.5 + t = np.linspace(0, duration, int(sample_rate * duration), dtype=np.float32) + audio_data = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + with av.open(output_path, "w", format="wav") as container: + stream = container.add_stream("pcm_s16le", rate=sample_rate, layout="mono") + frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format="s16", layout="mono") + frame.rate = sample_rate + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(None): + container.mux(packet) + + yield output_path + if os.path.exists(output_path): + os.remove(output_path) + + @pytest.fixture + def non_compliant_wav_file(self): + """Create a 44100Hz mono WAV (wrong sample rate)""" + import av + import numpy as np + + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + output_path = tmp.name + + sample_rate = 44100 # Wrong sample rate + duration = 0.5 + t = np.linspace(0, duration, int(sample_rate * duration), dtype=np.float32) + audio_data = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + + with av.open(output_path, "w", format="wav") as container: + stream = container.add_stream("pcm_s16le", rate=sample_rate, layout="mono") + frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format="s16", layout="mono") + frame.rate = sample_rate + for packet in stream.encode(frame): + container.mux(packet) + for packet in stream.encode(None): + container.mux(packet) + + yield output_path + if os.path.exists(output_path): + os.remove(output_path) + + def test_is_compliant_wav_true(self, compliant_wav_file): + """Test that _is_compliant_wav returns True for compliant files""" + from pyrit.score.audio_transcript_scorer import _is_compliant_wav + + assert _is_compliant_wav(compliant_wav_file, sample_rate=16000, channels=1) is True + + def test_is_compliant_wav_false_wrong_rate(self, non_compliant_wav_file): + """Test that _is_compliant_wav returns False for wrong sample rate""" + from pyrit.score.audio_transcript_scorer import _is_compliant_wav + + assert _is_compliant_wav(non_compliant_wav_file, sample_rate=16000, channels=1) is False + + def test_is_compliant_wav_nonexistent_file(self): + """Test that _is_compliant_wav returns False for nonexistent files""" + from pyrit.score.audio_transcript_scorer import _is_compliant_wav + + assert _is_compliant_wav("/nonexistent/file.wav", sample_rate=16000, channels=1) is False + + def test_audio_to_wav_returns_original_for_compliant(self, compliant_wav_file): + """Test that _audio_to_wav returns the original path for compliant files""" + from pyrit.score.audio_transcript_scorer import _audio_to_wav + + result = _audio_to_wav(compliant_wav_file, sample_rate=16000, channels=1) + assert result == compliant_wav_file + + def test_audio_to_wav_converts_non_compliant(self, non_compliant_wav_file): + """Test that _audio_to_wav converts non-compliant files""" + from pyrit.score.audio_transcript_scorer import _audio_to_wav, _is_compliant_wav + + result = _audio_to_wav(non_compliant_wav_file, sample_rate=16000, channels=1) + try: + assert result != non_compliant_wav_file + assert _is_compliant_wav(result, sample_rate=16000, channels=1) is True + finally: + if result != non_compliant_wav_file and os.path.exists(result): + os.remove(result) From 8025d5cccfc271568d769aa3cbbd05e7961146d0 Mon Sep 17 00:00:00 2001 From: Bolor Date: Thu, 12 Mar 2026 12:09:06 -0700 Subject: [PATCH 7/7] addressing comments --- pyrit/score/audio_transcript_scorer.py | 6 +++++- pyrit/score/float_scale/audio_float_scale_scorer.py | 8 +++++++- pyrit/score/true_false/audio_true_false_scorer.py | 8 +++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/pyrit/score/audio_transcript_scorer.py b/pyrit/score/audio_transcript_scorer.py index 5d28cfb911..1395e3b968 100644 --- a/pyrit/score/audio_transcript_scorer.py +++ b/pyrit/score/audio_transcript_scorer.py @@ -106,6 +106,7 @@ def __init__( self, *, text_capable_scorer: Scorer, + use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the base audio scorer. @@ -113,12 +114,15 @@ def __init__( Args: text_capable_scorer (Scorer): A scorer capable of processing text that will be used to score the transcribed audio content. + use_entra_auth (bool, Optional): Whether to use Entra ID authentication for Azure Speech. + Defaults to True if None. Raises: ValueError: If text_capable_scorer does not support text data type. """ self._validate_text_scorer(text_capable_scorer) self.text_scorer = text_capable_scorer + self._use_entra_auth = use_entra_auth if use_entra_auth is not None else True @staticmethod def _validate_text_scorer(scorer: Scorer) -> None: @@ -220,7 +224,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str: logger.info(f"Audio transcription: WAV file size = {file_size} bytes") try: - converter = AzureSpeechAudioToTextConverter(use_entra_auth=True) + converter = AzureSpeechAudioToTextConverter(use_entra_auth=self._use_entra_auth) logger.info("Audio transcription: Starting Azure Speech transcription...") result = await converter.convert_async(prompt=wav_path, input_type="audio_path") logger.info(f"Audio transcription: Result = '{result.output_text}'") diff --git a/pyrit/score/float_scale/audio_float_scale_scorer.py b/pyrit/score/float_scale/audio_float_scale_scorer.py index a44988d5a1..528784109b 100644 --- a/pyrit/score/float_scale/audio_float_scale_scorer.py +++ b/pyrit/score/float_scale/audio_float_scale_scorer.py @@ -25,6 +25,7 @@ def __init__( *, text_capable_scorer: FloatScaleScorer, validator: Optional[ScorerPromptValidator] = None, + use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the AudioFloatScaleScorer. @@ -33,12 +34,17 @@ def __init__( text_capable_scorer: A FloatScaleScorer capable of processing text. This scorer will be used to evaluate the transcribed audio content. validator: Validator for the scorer. Defaults to audio_path data type validator. + use_entra_auth: Whether to use Entra ID authentication for Azure Speech. + Defaults to True if None. Raises: ValueError: If text_capable_scorer does not support text data type. """ super().__init__(validator=validator or self._default_validator) - self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer) + self._audio_helper = AudioTranscriptHelper( + text_capable_scorer=text_capable_scorer, + use_entra_auth=use_entra_auth, + ) def _build_identifier(self) -> ComponentIdentifier: """ diff --git a/pyrit/score/true_false/audio_true_false_scorer.py b/pyrit/score/true_false/audio_true_false_scorer.py index 1c7a5de17d..0148cc6046 100644 --- a/pyrit/score/true_false/audio_true_false_scorer.py +++ b/pyrit/score/true_false/audio_true_false_scorer.py @@ -25,6 +25,7 @@ def __init__( *, text_capable_scorer: TrueFalseScorer, validator: Optional[ScorerPromptValidator] = None, + use_entra_auth: Optional[bool] = None, ) -> None: """ Initialize the AudioTrueFalseScorer. @@ -33,12 +34,17 @@ def __init__( text_capable_scorer: A TrueFalseScorer capable of processing text. This scorer will be used to evaluate the transcribed audio content. validator: Validator for the scorer. Defaults to audio_path data type validator. + use_entra_auth: Whether to use Entra ID authentication for Azure Speech. + Defaults to True if None. Raises: ValueError: If text_capable_scorer does not support text data type. """ super().__init__(validator=validator or self._DEFAULT_VALIDATOR) - self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer) + self._audio_helper = AudioTranscriptHelper( + text_capable_scorer=text_capable_scorer, + use_entra_auth=use_entra_auth, + ) def _build_identifier(self) -> ComponentIdentifier: """