Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"aiofiles>=24,<25",
"appdirs>=1.4.0",
"art>=6.5.0",
"av>=14.0.0",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you work with Hannah and Foundry to see if this should go into default or just extras?

"azure-core>=1.38.0",
"azure-identity>=1.19.0",
"azure-ai-contentsafety>=1.0.0",
Expand Down Expand Up @@ -135,6 +136,7 @@ speech = [
# all includes all functional dependencies excluding the ones from the "dev" extra
all = [
"accelerate>=1.7.0",
"av>=14.0.0",
"azure-ai-ml>=1.27.1",
"azure-cognitiveservices-speech>=1.44.0",
"azureml-mlflow>=1.60.0",
Expand Down
156 changes: 87 additions & 69 deletions pyrit/score/audio_transcript_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from abc import ABC
from typing import Optional

import av

from pyrit.memory import CentralMemory
from pyrit.models import MessagePiece, Score
from pyrit.prompt_converter import AzureSpeechAudioToTextConverter
Expand All @@ -16,6 +18,76 @@
logger = logging.getLogger(__name__)


def _is_compliant_wav(input_path: str, *, sample_rate: int, channels: int) -> bool:
    """
    Check whether a file is already a WAV in the target PCM format.

    A file is compliant only when it is a WAV container whose first audio stream is
    16-bit little-endian PCM ("pcm_s16le") with the expected sample rate and channel
    count. The container check is required because the codec alone does not identify
    a WAV file: other containers (for example video files) can carry a pcm_s16le
    track and must still be converted rather than passed through unchanged.

    Args:
        input_path (str): Path to the audio file.
        sample_rate (int): Expected sample rate in Hz.
        channels (int): Expected number of channels.

    Returns:
        bool: True if the file is already compliant. False otherwise, including when
            the file cannot be opened or probed.
    """
    try:
        with av.open(input_path) as container:
            # Codec, rate and channels can all match inside a non-WAV container.
            if container.format.name != "wav":
                return False
            if not container.streams.audio:
                return False
            stream = container.streams.audio[0]
            return (
                stream.codec_context.name == "pcm_s16le"
                and stream.rate == sample_rate
                and stream.channels == channels
            )
    except Exception:
        # Deliberate best-effort probe: an unreadable or unsupported file is simply
        # "not compliant", so the caller falls back to a full conversion (which will
        # surface the real error if the file is genuinely unusable).
        logger.debug(f"Could not probe audio file, treating as non-compliant: {input_path}", exc_info=True)
        return False


def _audio_to_wav(input_path: str, *, sample_rate: int, channels: int) -> str:
    """
    Convert any audio or video file to a normalised PCM WAV using PyAV.

    If the input is already a compliant WAV (correct sample rate, channels, and codec),
    returns the original path without re-encoding. Otherwise the first audio stream is
    decoded, resampled to 16-bit PCM and written to a new temporary file. The caller
    owns that temporary file and is responsible for deleting it.

    Args:
        input_path (str): Source audio or video file.
        sample_rate (int): Target sample rate in Hz.
        channels (int): Target number of channels (1 = mono). Any other value is
            encoded with a stereo layout.

    Returns:
        str: Path to the WAV file (original if compliant, otherwise a temporary file).

    Raises:
        Exception: Any error raised by PyAV while opening, decoding or encoding is
            propagated after the partially written temporary file has been removed.
    """
    # Local import so this function does not depend on the module's import block.
    import os

    # Skip conversion if already compliant
    if _is_compliant_wav(input_path, sample_rate=sample_rate, channels=channels):
        logger.debug(f"Audio file already compliant, skipping conversion: {input_path}")
        return input_path

    layout = "mono" if channels == 1 else "stereo"

    # Reserve a path only; PyAV reopens it for writing below, so the handle is closed first.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        output_path = tmp.name

    try:
        with av.open(input_path) as in_container:
            with av.open(output_path, "w", format="wav") as out_container:
                out_stream = out_container.add_stream("pcm_s16le", rate=sample_rate, layout=layout)
                resampler = av.AudioResampler(format="s16", layout=layout, rate=sample_rate)

                def _encode_and_mux(frames) -> None:
                    """Encode each resampled frame and write the resulting packets."""
                    for out_frame in frames:
                        for packet in out_stream.encode(out_frame):
                            out_container.mux(packet)

                for frame in in_container.decode(audio=0):
                    _encode_and_mux(resampler.resample(frame))

                # Passing None flushes samples still buffered in the resampler ...
                _encode_and_mux(resampler.resample(None))

                # ... and then those still buffered in the encoder.
                for packet in out_stream.encode(None):
                    out_container.mux(packet)
    except Exception:
        # delete=False means nothing else will clean this up; do not leave a
        # truncated WAV behind when the conversion fails.
        try:
            os.remove(output_path)
        except OSError:
            pass
        raise

    return output_path


class AudioTranscriptHelper(ABC): # noqa: B024
"""
Abstract base class for audio scorers that process audio by transcribing and scoring the text.
Expand All @@ -29,25 +101,28 @@ class AudioTranscriptHelper(ABC): # noqa: B024
_DEFAULT_SAMPLE_RATE = 16000 # 16kHz - Azure Speech optimal rate
_DEFAULT_CHANNELS = 1 # Mono - Azure Speech prefers mono
_DEFAULT_SAMPLE_WIDTH = 2 # 16-bit audio (2 bytes per sample)
_DEFAULT_EXPORT_PARAMS = ["-acodec", "pcm_s16le"] # 16-bit PCM for best compatibility

def __init__(
    self,
    *,
    text_capable_scorer: Scorer,
    use_entra_auth: Optional[bool] = None,
) -> None:
    """
    Set up the audio helper around a text scorer.

    Args:
        text_capable_scorer (Scorer): Scorer applied to the transcript of the audio;
            it is validated before being stored.
        use_entra_auth (bool, Optional): Whether Azure Speech should authenticate with
            Entra ID. None is treated as True.

    Raises:
        ValueError: If text_capable_scorer does not support text data type.
    """
    # Validate first so an unsuitable scorer is rejected before any state is stored.
    self._validate_text_scorer(text_capable_scorer)
    self.text_scorer = text_capable_scorer

    if use_entra_auth is None:
        self._use_entra_auth = True
    else:
        self._use_entra_auth = use_entra_auth

@staticmethod
def _validate_text_scorer(scorer: Scorer) -> None:
Expand Down Expand Up @@ -149,7 +224,7 @@ async def _transcribe_audio_async(self, audio_path: str) -> str:
logger.info(f"Audio transcription: WAV file size = {file_size} bytes")

try:
converter = AzureSpeechAudioToTextConverter()
converter = AzureSpeechAudioToTextConverter(use_entra_auth=self._use_entra_auth)
logger.info("Audio transcription: Starting Azure Speech transcription...")
result = await converter.convert_async(prompt=wav_path, input_type="audio_path")
logger.info(f"Audio transcription: Result = '{result.output_text}'")
Expand All @@ -171,25 +246,12 @@ def _ensure_wav_format(self, audio_path: str) -> str:

Returns:
str: Path to WAV file (original if already WAV, or converted temporary file).

Raises:
ModuleNotFoundError: If pydub is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e

audio = AudioSegment.from_file(audio_path)
audio = (
audio.set_frame_rate(self._DEFAULT_SAMPLE_RATE)
.set_channels(self._DEFAULT_CHANNELS)
.set_sample_width(self._DEFAULT_SAMPLE_WIDTH)
return _audio_to_wav(
audio_path,
sample_rate=self._DEFAULT_SAMPLE_RATE,
channels=self._DEFAULT_CHANNELS,
)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
audio.export(temp_wav.name, format="wav")
return temp_wav.name

def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
"""
Expand All @@ -201,9 +263,6 @@ def _extract_audio_from_video(self, video_path: str) -> Optional[str]:
Returns:
str: a path to the extracted audio file (WAV format)
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
"""
return AudioTranscriptHelper.extract_audio_from_video(video_path)

Expand All @@ -218,57 +277,16 @@ def extract_audio_from_video(video_path: str) -> Optional[str]:
Returns:
str: a path to the extracted audio file (WAV format)
or returns None if extraction fails.

Raises:
ModuleNotFoundError: If pydub/ffmpeg is not installed.
"""
try:
from pydub import AudioSegment
except ModuleNotFoundError as e:
logger.error("Could not import pydub. Install it via 'pip install pydub'")
raise e

try:
# Extract audio from video using pydub (requires ffmpeg)
logger.info(f"Extracting audio from video: {video_path}")
audio = AudioSegment.from_file(video_path)
logger.info(
f"Audio extracted: duration={len(audio)}ms, channels={audio.channels}, "
f"sample_width={audio.sample_width}, frame_rate={audio.frame_rate}"
output_path = _audio_to_wav(
video_path,
sample_rate=AudioTranscriptHelper._DEFAULT_SAMPLE_RATE,
channels=AudioTranscriptHelper._DEFAULT_CHANNELS,
)

# Optimize for Azure Speech recognition:
# Azure Speech works best with 16kHz mono audio (same as Azure TTS output)
if audio.frame_rate != AudioTranscriptHelper._DEFAULT_SAMPLE_RATE:
logger.info(
f"Resampling audio from {audio.frame_rate}Hz to {AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz"
)
audio = audio.set_frame_rate(AudioTranscriptHelper._DEFAULT_SAMPLE_RATE)

# Ensure 16-bit audio
if audio.sample_width != AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH:
logger.info(
f"Converting sample width from {audio.sample_width * 8}-bit"
f" to {AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH * 8}-bit"
)
audio = audio.set_sample_width(AudioTranscriptHelper._DEFAULT_SAMPLE_WIDTH)

# Convert to mono (Azure Speech prefers mono)
if audio.channels > AudioTranscriptHelper._DEFAULT_CHANNELS:
logger.info(f"Converting from {audio.channels} channels to mono")
audio = audio.set_channels(AudioTranscriptHelper._DEFAULT_CHANNELS)

# Create temporary WAV file with PCM encoding for best compatibility
with tempfile.NamedTemporaryFile(suffix="_video_audio.wav", delete=False) as temp_audio:
audio.export(
temp_audio.name,
format="wav",
parameters=AudioTranscriptHelper._DEFAULT_EXPORT_PARAMS,
)
logger.info(
f"Audio exported to: {temp_audio.name} (duration={len(audio)}ms, rate={audio.frame_rate}Hz, mono)"
)
return temp_audio.name
logger.info(f"Audio exported to: {output_path} (rate={AudioTranscriptHelper._DEFAULT_SAMPLE_RATE}Hz, mono)")
return output_path
except Exception as e:
logger.warning(f"Failed to extract audio from video {video_path}: {e}")
return None
8 changes: 7 additions & 1 deletion pyrit/score/float_scale/audio_float_scale_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
*,
text_capable_scorer: FloatScaleScorer,
validator: Optional[ScorerPromptValidator] = None,
use_entra_auth: Optional[bool] = None,
) -> None:
"""
Initialize the AudioFloatScaleScorer.
Expand All @@ -33,12 +34,17 @@ def __init__(
text_capable_scorer: A FloatScaleScorer capable of processing text.
This scorer will be used to evaluate the transcribed audio content.
validator: Validator for the scorer. Defaults to audio_path data type validator.
use_entra_auth: Whether to use Entra ID authentication for Azure Speech.
Defaults to True if None.

Raises:
ValueError: If text_capable_scorer does not support text data type.
"""
super().__init__(validator=validator or self._default_validator)
self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer)
self._audio_helper = AudioTranscriptHelper(
text_capable_scorer=text_capable_scorer,
use_entra_auth=use_entra_auth,
)

def _build_identifier(self) -> ComponentIdentifier:
"""
Expand Down
8 changes: 7 additions & 1 deletion pyrit/score/true_false/audio_true_false_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
*,
text_capable_scorer: TrueFalseScorer,
validator: Optional[ScorerPromptValidator] = None,
use_entra_auth: Optional[bool] = None,
) -> None:
"""
Initialize the AudioTrueFalseScorer.
Expand All @@ -33,12 +34,17 @@ def __init__(
text_capable_scorer: A TrueFalseScorer capable of processing text.
This scorer will be used to evaluate the transcribed audio content.
validator: Validator for the scorer. Defaults to audio_path data type validator.
use_entra_auth: Whether to use Entra ID authentication for Azure Speech.
Defaults to True if None.

Raises:
ValueError: If text_capable_scorer does not support text data type.
"""
super().__init__(validator=validator or self._DEFAULT_VALIDATOR)
self._audio_helper = AudioTranscriptHelper(text_capable_scorer=text_capable_scorer)
self._audio_helper = AudioTranscriptHelper(
text_capable_scorer=text_capable_scorer,
use_entra_auth=use_entra_auth,
)

def _build_identifier(self) -> ComponentIdentifier:
"""
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/score/test_audio_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,98 @@ async def test_score_piece_empty_transcript(self, audio_message_piece):

# Empty transcript returns empty list
assert len(scores) == 0


class TestPyAVAudioConversion:
    """Tests for PyAV audio conversion functions"""

    @staticmethod
    def _write_sine_wav(output_path: str, *, sample_rate: int, duration: float = 0.5) -> None:
        """
        Write a mono 16-bit PCM WAV containing a 440 Hz sine tone.

        Args:
            output_path (str): Destination file path (overwritten).
            sample_rate (int): Sample rate of the generated file in Hz.
            duration (float): Length of the tone in seconds.
        """
        import av
        import numpy as np

        t = np.linspace(0, duration, int(sample_rate * duration), dtype=np.float32)
        audio_data = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

        with av.open(output_path, "w", format="wav") as container:
            stream = container.add_stream("pcm_s16le", rate=sample_rate, layout="mono")
            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format="s16", layout="mono")
            frame.rate = sample_rate
            for packet in stream.encode(frame):
                container.mux(packet)
            # Flush the encoder
            for packet in stream.encode(None):
                container.mux(packet)

    @pytest.fixture
    def compliant_wav_file(self):
        """Create a compliant 16kHz mono PCM WAV file"""
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name

        # try/finally so the temp file is removed even if writing it fails
        try:
            self._write_sine_wav(output_path, sample_rate=16000)
            yield output_path
        finally:
            if os.path.exists(output_path):
                os.remove(output_path)

    @pytest.fixture
    def non_compliant_wav_file(self):
        """Create a 44100Hz mono WAV (wrong sample rate)"""
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name

        # try/finally so the temp file is removed even if writing it fails
        try:
            self._write_sine_wav(output_path, sample_rate=44100)  # Wrong sample rate
            yield output_path
        finally:
            if os.path.exists(output_path):
                os.remove(output_path)

    def test_is_compliant_wav_true(self, compliant_wav_file):
        """Test that _is_compliant_wav returns True for compliant files"""
        from pyrit.score.audio_transcript_scorer import _is_compliant_wav

        assert _is_compliant_wav(compliant_wav_file, sample_rate=16000, channels=1) is True

    def test_is_compliant_wav_false_wrong_rate(self, non_compliant_wav_file):
        """Test that _is_compliant_wav returns False for wrong sample rate"""
        from pyrit.score.audio_transcript_scorer import _is_compliant_wav

        assert _is_compliant_wav(non_compliant_wav_file, sample_rate=16000, channels=1) is False

    def test_is_compliant_wav_nonexistent_file(self):
        """Test that _is_compliant_wav returns False for nonexistent files"""
        from pyrit.score.audio_transcript_scorer import _is_compliant_wav

        assert _is_compliant_wav("/nonexistent/file.wav", sample_rate=16000, channels=1) is False

    def test_audio_to_wav_returns_original_for_compliant(self, compliant_wav_file):
        """Test that _audio_to_wav returns the original path for compliant files"""
        from pyrit.score.audio_transcript_scorer import _audio_to_wav

        result = _audio_to_wav(compliant_wav_file, sample_rate=16000, channels=1)
        assert result == compliant_wav_file

    def test_audio_to_wav_converts_non_compliant(self, non_compliant_wav_file):
        """Test that _audio_to_wav converts non-compliant files"""
        from pyrit.score.audio_transcript_scorer import _audio_to_wav, _is_compliant_wav

        result = _audio_to_wav(non_compliant_wav_file, sample_rate=16000, channels=1)
        try:
            assert result != non_compliant_wav_file
            assert _is_compliant_wav(result, sample_rate=16000, channels=1) is True
        finally:
            # The converted file is a caller-owned temp file; remove it after the checks.
            if result != non_compliant_wav_file and os.path.exists(result):
                os.remove(result)
Loading