diff --git a/assets/photosynthesis_question.wav b/assets/photosynthesis_question.wav new file mode 100644 index 0000000000..44f9e54141 Binary files /dev/null and b/assets/photosynthesis_question.wav differ diff --git a/doc/code/executor/attack/barge_in_attack.ipynb b/doc/code/executor/attack/barge_in_attack.ipynb new file mode 100644 index 0000000000..a891a47f85 --- /dev/null +++ b/doc/code/executor/attack/barge_in_attack.ipynb @@ -0,0 +1,398 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Barge-In Attack (Streaming Audio)\n", + "\n", + "`BargeInAttack` streams user audio to a `RealtimeTarget` and uses server-side voice-activity\n", + "detection (VAD) to detect turn boundaries. When the user speaks while the assistant is still\n", + "responding, server VAD cancels the in-flight response (barge-in). Interrupted turns are\n", + "persisted with `prompt_metadata[\"interrupted\"] = True`.\n", + "\n", + "Audio converters are applied per turn after VAD commits. The raw audio drives interruption\n", + "timing while the model responds to the converted version.\n", + "\n", + "> **Note:** Memory must be initialized via `initialize_pyrit_async`. See the\n", + "> [Memory Configuration Guide](../../memory/0_memory.md)." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "`BargeInAttack` requires a `RealtimeTarget` with `server_vad=True` (or a `ServerVadConfig`\n", + "for custom tuning)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['./.pyrit/.env']\n", + "Loaded environment file: ./.pyrit/.env\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No new upgrade operations detected.\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import wave\n", + "from pathlib import Path\n", + "\n", + "from pyrit.executor.attack import (\n", + " AttackConverterConfig,\n", + " BargeInAttack,\n", + " BargeInAttackContext,\n", + " ConsoleAttackResultPrinter,\n", + ")\n", + "from pyrit.executor.attack.core import AttackParameters\n", + "from pyrit.memory import CentralMemory\n", + "from pyrit.prompt_converter import AudioFrequencyConverter\n", + "from pyrit.prompt_normalizer import PromptConverterConfiguration\n", + "from pyrit.prompt_target import RealtimeTarget\n", + "from pyrit.setup import IN_MEMORY, initialize_pyrit_async\n", + "\n", + "await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Shared setup\n", + "\n", + "Both sections use a pre-recorded 24 kHz mono PCM16 question about photosynthesis. The\n", + "format matches what the OpenAI Realtime API expects. Any async generator yielding 24 kHz\n", + "PCM16 bytes works as a chunk source (live mic, TTS, etc.)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded question: 3.94s @ 24 kHz\n" + ] + } + ], + "source": [ + "CHUNK_MS = 100\n", + "CHUNK_SIZE = CHUNK_MS * 48 # PCM16 @ 24 kHz mono = 48 bytes per millisecond.\n", + "SILENCE_CHUNK = b\"\\x00\" * CHUNK_SIZE\n", + "audio_path = Path(\"../../../../assets/photosynthesis_question.wav\").resolve()\n", + "\n", + "\n", + "def _load_pcm(path: Path) -> bytes:\n", + " \"\"\"Read a WAV at 24 kHz / mono / PCM16 into raw PCM bytes.\"\"\"\n", + " with wave.open(str(path), \"rb\") as wav:\n", + " assert wav.getframerate() == 24000 and wav.getnchannels() == 1 and wav.getsampwidth() == 2\n", + " return wav.readframes(wav.getnframes())\n", + "\n", + "\n", + "async def _yield_chunks(pcm: bytes, real_time: bool = True):\n", + " \"\"\"Yield PCM in 100ms slices, optionally pacing at real-time.\"\"\"\n", + " for offset in range(0, len(pcm), CHUNK_SIZE):\n", + " yield pcm[offset : offset + CHUNK_SIZE]\n", + " if real_time:\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "question_pcm_24k = _load_pcm(audio_path)\n", + "print(f\"Loaded question: {len(question_pcm_24k) / 48 / 1000:.2f}s @ 24 kHz\")\n", + "\n", + "converters = PromptConverterConfiguration.from_converters(converters=[AudioFrequencyConverter(shift_value=200)])" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## Section 1: Single-turn streaming with a converter\n", + "\n", + "Streams one user statement, applies a frequency-shift converter after VAD commits the turn,\n", + "and gets the model's response. Exercises the full pipeline (chunk push, convert-on-commit,\n", + "item swap, response trigger, memory persistence) without barge-in." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "executed_turns: 1\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294332341158.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Sure! Photosynthesis is the process plants use to convert light energy into chemical energy, which they store as sugars. It mainly takes place in the chloroplasts of leaf cells. Here's how it works:\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 1. Light absorption: Chlorophyll, the green pigment, captures sunlight. This energy excites electrons within the chlorophyll.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 2. Water splitting: The plant takes in water (H₂O) from the roots and transfers it to the leaves. The light energy splits the water molecules into oxygen, protons, and electrons. 
The oxygen is\u001b[0m\n", + "\u001b[33m released as a byproduct.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 3. Conversion of energy: The excited electrons move through a chain of proteins, creating ATP and NADPH, which are energy carriers.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 4. Carbon fixation: Using that stored energy, the plant takes in carbon dioxide (CO₂) from the air. Through the Calvin cycle, it combines the CO₂ with the energy carriers to form glucose.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m This glucose feeds the plant and can be stored as starch. In essence, photosynthesis fuels plant growth and provides oxygen for us.\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294332344158.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" + ] + } + ], + "source": [ + "async def single_turn_source():\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " # Trailing silence helps server VAD recognize end-of-turn.\n", + " for _ in range(25): # 2.5s trailing silence, above the 1.5s VAD threshold\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "target = RealtimeTarget(server_vad=True)\n", + "attack = BargeInAttack(\n", + " objective_target=target,\n", + " attack_converter_config=AttackConverterConfig(request_converters=converters),\n", + ")\n", + "\n", + "context = BargeInAttackContext(\n", + " params=AttackParameters(objective=\"Observe a single converted user turn end-to-end\"),\n", + " audio_chunks=single_turn_source(),\n", + ")\n", + "\n", + "result = await attack.execute_with_context_async(context=context) # type: ignore\n", + "print(f\"executed_turns: {result.executed_turns}\")\n", + "await 
ConsoleAttackResultPrinter(width=200).print_conversation_async(result=result) # type: ignore\n", + "await target.cleanup_target() # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Section 2: Barge-in (interrupting the assistant mid-response)\n", + "\n", + "Plays the question twice with timing arranged so turn 2's speech arrives during turn 1's\n", + "response. Server VAD detects the new speech, cancels turn 1's response, and resolves it\n", + "with `interrupted=True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "executed_turns: 2\n", + "\n", + "Persisted pieces (4 messages):\n", + " user audio_path: 1779294342848770.mp3\n", + " assistant text [INTERRUPTED]: Sure! Photosynthesis is the process plants use to convert light energy into chem...\n", + " assistant audio_path [INTERRUPTED]: 1779294342850774.mp3\n", + " user audio_path: 1779294366566679.mp3\n", + " assistant text: Absolutely! Let’s break it down step by step.\n", + "\n", + "1. 
**Where it happens**: Photosyn...\n", + " assistant audio_path: 1779294366569687.mp3\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294342848770.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Sure! 
Photosynthesis is the process plants use to convert light energy into chemical energy they can use as\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294342850774.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 2 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294366566679.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Absolutely! Let’s break it down step by step.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 1. **Where it happens**: Photosynthesis takes place in chloroplasts, which are specialized structures inside plant cells. These contain chlorophyll, the green pigment that captures light energy from\u001b[0m\n", + "\u001b[33m the sun.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 2. **The raw materials**: Plants use carbon dioxide from the air (taken in through tiny pores called stomata) and water from the soil (absorbed through their roots).\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 3. 
**The light-dependent reactions**: Inside the chloroplasts, chlorophyll absorbs sunlight, which excites electrons. This energy splits water molecules into oxygen, protons, and electrons. Oxygen\u001b[0m\n", + "\u001b[33m is released as a byproduct (that’s the oxygen we breathe!). The electrons and protons help generate energy-rich molecules called ATP and NADPH.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 4. **The Calvin cycle (light-independent reactions)**: Using the ATP and NADPH, plants convert carbon dioxide into glucose through a series of enzyme-driven steps. Glucose is a simple sugar that\u001b[0m\n", + "\u001b[33m plants use to build more complex carbohydrates like starch and cellulose, fueling growth and development.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 5. **Energy storage and use**: The glucose can be used immediately for energy, or it can be stored as starch. This stored energy supports the plant’s metabolism, growth, and reproduction.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m In short, plants take in sunlight, water, and carbon dioxide, and through photosynthesis they produce oxygen and energy-rich sugars that sustain both themselves and, ultimately, life on Earth.\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294366569687.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" + ] + } + ], + "source": [ + "TURN1_RESPONSE_WAIT_S = 0.2 # how long to let the model start speaking before barging in\n", + "\n", + "\n", + "async def barge_in_source():\n", + " # Turn 1: speak the question, then 1.5s of silence so VAD commits.\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " for _ in range(25): # 2.5s trailing silence\n", + " yield 
SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + " # Let the model get partway into its response before we interrupt.\n", + " for _ in range(int(TURN1_RESPONSE_WAIT_S * 10)):\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + " # Turn 2: speak the question again. VAD's speech_started fires while turn 1's response\n", + " # is still streaming → server cancels + truncates turn 1.\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " for _ in range(25): # 2.5s trailing silence\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "target2 = RealtimeTarget(server_vad=True)\n", + "attack2 = BargeInAttack(\n", + " objective_target=target2,\n", + " attack_converter_config=AttackConverterConfig(request_converters=converters),\n", + ")\n", + "\n", + "barge_in_context = BargeInAttackContext(\n", + " params=AttackParameters(objective=\"Demonstrate barge-in by interrupting a benign answer\"),\n", + " audio_chunks=barge_in_source(),\n", + ")\n", + "\n", + "barge_in_result = await attack2.execute_with_context_async(context=barge_in_context) # type: ignore\n", + "print(f\"executed_turns: {barge_in_result.executed_turns}\")\n", + "\n", + "# Inspect memory to verify the barge-in landed in metadata.\n", + "memory = CentralMemory.get_memory_instance()\n", + "turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id)\n", + "print(f\"\\nPersisted pieces ({len(turns)} messages):\")\n", + "for message in turns:\n", + " for piece in message.message_pieces:\n", + " interrupted = piece.prompt_metadata.get(\"interrupted\")\n", + " marker = \" [INTERRUPTED]\" if interrupted else \"\"\n", + " val = piece.converted_value\n", + " if piece.converted_value_data_type == \"audio_path\":\n", + " val = Path(val).name\n", + " value_preview = (val[:80] + \"...\") if len(val) > 80 else val\n", + " print(f\" {piece._role} 
{piece.converted_value_data_type}{marker}: {value_preview}\")\n", + "\n", + "await ConsoleAttackResultPrinter(width=200).print_conversation_async(result=barge_in_result) # type: ignore\n", + "await target2.cleanup_target() # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "### Reading the barge-in output\n", + "\n", + "If barge-in fired successfully:\n", + "- `executed_turns: 2` (two VAD-detected user turns)\n", + "- First assistant turn shows `[INTERRUPTED]` with a truncated transcript\n", + "- Second assistant turn completes normally\n", + "\n", + "If you don't see `[INTERRUPTED]`, decrease `TURN1_RESPONSE_WAIT_S` so turn 2's audio\n", + "arrives earlier in turn 1's response window." + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "## Alternate chunk sources\n", + "\n", + "The chunk source is the main strategy hook:\n", + "\n", + "- **Pre-recorded WAV** (this notebook): most common starting point\n", + "- **TTS converter**: generate audio from text prompts dynamically\n", + "- **Live microphone**: use `sounddevice` or similar; yield what the mic produces\n", + "\n", + "For adaptive attacks (e.g., score-driven strategies), subclass `BargeInAttack` and override\n", + "`_perform_async` to interleave turn observation with chunk generation." 
+ ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/executor/attack/barge_in_attack.py b/doc/code/executor/attack/barge_in_attack.py new file mode 100644 index 0000000000..d96899e910 --- /dev/null +++ b/doc/code/executor/attack/barge_in_attack.py @@ -0,0 +1,206 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# --- + +# %% [markdown] +# # Barge-In Attack (Streaming Audio) +# +# `BargeInAttack` streams user audio to a `RealtimeTarget` and uses server-side voice-activity +# detection (VAD) to detect turn boundaries. When the user speaks while the assistant is still +# responding, server VAD cancels the in-flight response (barge-in). Interrupted turns are +# persisted with `prompt_metadata["interrupted"] = True`. +# +# Audio converters are applied per turn after VAD commits. The raw audio drives interruption +# timing while the model responds to the converted version. +# +# > **Note:** Memory must be initialized via `initialize_pyrit_async`. See the +# > [Memory Configuration Guide](../../memory/0_memory.md). + +# %% [markdown] +# ## Setup +# +# `BargeInAttack` requires a `RealtimeTarget` with `server_vad=True` (or a `ServerVadConfig` +# for custom tuning). 
+ +# %% +import asyncio +import wave +from pathlib import Path + +from pyrit.executor.attack import ( + AttackConverterConfig, + BargeInAttack, + BargeInAttackContext, + ConsoleAttackResultPrinter, +) +from pyrit.executor.attack.core import AttackParameters +from pyrit.memory import CentralMemory +from pyrit.prompt_converter import AudioFrequencyConverter +from pyrit.prompt_normalizer import PromptConverterConfiguration +from pyrit.prompt_target import RealtimeTarget +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + +await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore + +# %% [markdown] +# ## Shared setup +# +# Both sections use a pre-recorded 24 kHz mono PCM16 question about photosynthesis. The +# format matches what the OpenAI Realtime API expects. Any async generator yielding 24 kHz +# PCM16 bytes works as a chunk source (live mic, TTS, etc.). + +# %% +CHUNK_MS = 100 +CHUNK_SIZE = CHUNK_MS * 48 # PCM16 @ 24 kHz mono = 48 bytes per millisecond. +SILENCE_CHUNK = b"\x00" * CHUNK_SIZE +audio_path = Path("../../../../assets/photosynthesis_question.wav").resolve() + + +def _load_pcm(path: Path) -> bytes: + """Read a WAV at 24 kHz / mono / PCM16 into raw PCM bytes.""" + with wave.open(str(path), "rb") as wav: + assert wav.getframerate() == 24000 and wav.getnchannels() == 1 and wav.getsampwidth() == 2 + return wav.readframes(wav.getnframes()) + + +async def _yield_chunks(pcm: bytes, real_time: bool = True): + """Yield PCM in 100ms slices, optionally pacing at real-time.""" + for offset in range(0, len(pcm), CHUNK_SIZE): + yield pcm[offset : offset + CHUNK_SIZE] + if real_time: + await asyncio.sleep(CHUNK_MS / 1000) + + +question_pcm_24k = _load_pcm(audio_path) +print(f"Loaded question: {len(question_pcm_24k) / 48 / 1000:.2f}s @ 24 kHz") + +converters = PromptConverterConfiguration.from_converters(converters=[AudioFrequencyConverter(shift_value=200)]) + + +# %% [markdown] +# ## Section 1: Single-turn streaming with a converter +# +# Streams 
one user statement, applies a frequency-shift converter after VAD commits the turn, +# and gets the model's response. Exercises the full pipeline (chunk push, convert-on-commit, +# item swap, response trigger, memory persistence) without barge-in. + + +# %% +async def single_turn_source(): + async for chunk in _yield_chunks(question_pcm_24k): + yield chunk + # Trailing silence helps server VAD recognize end-of-turn. + for _ in range(25): # 2.5s trailing silence, above the 1.5s VAD threshold + yield SILENCE_CHUNK + await asyncio.sleep(CHUNK_MS / 1000) + + +target = RealtimeTarget(server_vad=True) +attack = BargeInAttack( + objective_target=target, + attack_converter_config=AttackConverterConfig(request_converters=converters), +) + +context = BargeInAttackContext( + params=AttackParameters(objective="Observe a single converted user turn end-to-end"), + audio_chunks=single_turn_source(), +) + +result = await attack.execute_with_context_async(context=context) # type: ignore +print(f"executed_turns: {result.executed_turns}") +await ConsoleAttackResultPrinter(width=200).print_conversation_async(result=result) # type: ignore +await target.cleanup_target() # type: ignore + +# %% [markdown] +# ## Section 2: Barge-in (interrupting the assistant mid-response) +# +# Plays the question twice with timing arranged so turn 2's speech arrives during turn 1's +# response. Server VAD detects the new speech, cancels turn 1's response, and resolves it +# with `interrupted=True`. + +# %% +TURN1_RESPONSE_WAIT_S = 0.2 # how long to let the model start speaking before barging in + + +async def barge_in_source(): + # Turn 1: speak the question, then 1.5s of silence so VAD commits. + async for chunk in _yield_chunks(question_pcm_24k): + yield chunk + for _ in range(25): # 2.5s trailing silence + yield SILENCE_CHUNK + await asyncio.sleep(CHUNK_MS / 1000) + + # Let the model get partway into its response before we interrupt. 
+ for _ in range(int(TURN1_RESPONSE_WAIT_S * 10)): + yield SILENCE_CHUNK + await asyncio.sleep(CHUNK_MS / 1000) + + # Turn 2: speak the question again. VAD's speech_started fires while turn 1's response + # is still streaming → server cancels + truncates turn 1. + async for chunk in _yield_chunks(question_pcm_24k): + yield chunk + for _ in range(25): # 2.5s trailing silence + yield SILENCE_CHUNK + await asyncio.sleep(CHUNK_MS / 1000) + + +target2 = RealtimeTarget(server_vad=True) +attack2 = BargeInAttack( + objective_target=target2, + attack_converter_config=AttackConverterConfig(request_converters=converters), +) + +barge_in_context = BargeInAttackContext( + params=AttackParameters(objective="Demonstrate barge-in by interrupting a benign answer"), + audio_chunks=barge_in_source(), +) + +barge_in_result = await attack2.execute_with_context_async(context=barge_in_context) # type: ignore +print(f"executed_turns: {barge_in_result.executed_turns}") + +# Inspect memory to verify the barge-in landed in metadata. 
+memory = CentralMemory.get_memory_instance() +turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id) +print(f"\nPersisted pieces ({len(turns)} messages):") +for message in turns: + for piece in message.message_pieces: + interrupted = piece.prompt_metadata.get("interrupted") + marker = " [INTERRUPTED]" if interrupted else "" + val = piece.converted_value + if piece.converted_value_data_type == "audio_path": + val = Path(val).name + value_preview = (val[:80] + "...") if len(val) > 80 else val + print(f" {piece._role} {piece.converted_value_data_type}{marker}: {value_preview}") + +await ConsoleAttackResultPrinter(width=200).print_conversation_async(result=barge_in_result) # type: ignore +await target2.cleanup_target() # type: ignore + +# %% [markdown] +# ### Reading the barge-in output +# +# If barge-in fired successfully: +# - `executed_turns: 2` (two VAD-detected user turns) +# - First assistant turn shows `[INTERRUPTED]` with a truncated transcript +# - Second assistant turn completes normally +# +# If you don't see `[INTERRUPTED]`, decrease `TURN1_RESPONSE_WAIT_S` so turn 2's audio +# arrives earlier in turn 1's response window. + +# %% [markdown] +# ## Alternate chunk sources +# +# The chunk source is the main strategy hook: +# +# - **Pre-recorded WAV** (this notebook): most common starting point +# - **TTS converter**: generate audio from text prompts dynamically +# - **Live microphone**: use `sounddevice` or similar; yield what the mic produces +# +# For adaptive attacks (e.g., score-driven strategies), subclass `BargeInAttack` and override +# `_perform_async` to interleave turn observation with chunk generation. 
diff --git a/doc/myst.yml b/doc/myst.yml index 491d875568..3a66bd3763 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -87,6 +87,7 @@ project: - file: code/executor/attack/role_play_attack.ipynb - file: code/executor/attack/skeleton_key_attack.ipynb - file: code/executor/attack/tap_attack.ipynb + - file: code/executor/attack/barge_in_attack.ipynb - file: code/executor/attack/violent_durian_attack.ipynb - file: code/executor/workflow/0_workflow.md children: diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index e0c4f44fc6..dc1589a1da 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -49,6 +49,7 @@ SingleTurnAttackStrategy, SkeletonKeyAttack, ) +from pyrit.executor.attack.streaming import BargeInAttack, BargeInAttackContext # Backward-compatibility aliases — import from pyrit.output.attack_result directly. # TODO: Remove these re-exports in two releases (target removal: 0.16.0). @@ -96,6 +97,8 @@ "ConversationState", "AttackExecutor", "AttackExecutorResult", + "BargeInAttack", + "BargeInAttackContext", "PrependedConversationConfig", "generate_simulated_conversation_async", ] diff --git a/pyrit/executor/attack/streaming/__init__.py b/pyrit/executor/attack/streaming/__init__.py new file mode 100644 index 0000000000..b743ea7961 --- /dev/null +++ b/pyrit/executor/attack/streaming/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Streaming attack strategies (barge-in over realtime audio targets).""" + +from pyrit.executor.attack.streaming.barge_in import BargeInAttack, BargeInAttackContext + +__all__ = [ + "BargeInAttack", + "BargeInAttackContext", +] diff --git a/pyrit/executor/attack/streaming/barge_in.py b/pyrit/executor/attack/streaming/barge_in.py new file mode 100644 index 0000000000..141289ec39 --- /dev/null +++ b/pyrit/executor/attack/streaming/barge_in.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +"""Streaming barge-in attack over realtime audio targets.""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults +from pyrit.executor.attack.core.attack_config import AttackConverterConfig +from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT +from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy +from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.memory import CentralMemory +from pyrit.models import ( + AttackOutcome, + AttackResult, + Message, + MessagePiece, + construct_response_from_request, +) +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from pyrit.identifiers import ComponentIdentifier + from pyrit.prompt_target import PromptTarget + from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeEventDispatcher, + RealtimeTargetResult, + ) + from pyrit.prompt_target.openai.openai_realtime_target import RealtimeTarget + +logger = logging.getLogger(__name__) + +_REALTIME_SAMPLE_RATE_HZ = 24000 + + +@dataclass +class BargeInAttackContext(AttackContext[AttackParamsT]): + """Context for a streaming barge-in attack with audio chunk source and session config.""" + + conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) + audio_chunks: AsyncIterator[bytes] | None = None + system_prompt: str = "You are a helpful AI assistant" + + +@dataclass +class _BargeInRunState: + """Mutable per-session state accumulated as turns commit.""" + + raw_buffer: bytearray = field(default_factory=bytearray) + turn_lock: 
@dataclass
class _BargeInRunState:
    """Mutable per-session state accumulated as turns commit."""

    # Raw user PCM pushed since the last snapshot; drained by ``_snapshot_user_audio``.
    raw_buffer: bytearray = field(default_factory=bytearray)
    # Serializes committed-turn handling so only one turn converts/persists at a time.
    turn_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    # Assistant message of the most recently persisted turn (None until one completes).
    last_assistant_message: Message | None = None
    # Number of turns that were fully persisted.
    executed_turns: int = 0
    # One task handle per ``on_committed`` invocation, i.e. per VAD commit that reached the attack.
    turn_tasks: list[asyncio.Task[None]] = field(default_factory=list)


class BargeInAttack(AttackStrategy["BargeInAttackContext[Any]", AttackResult]):
    """
    Streaming attack that drives a Realtime API session with server VAD + barge-in.

    The attack pushes user audio chunks through the target, lets server VAD detect
    turn boundaries, manually fires ``response.create`` after each commit, and
    observes assistant turns (including interrupted ones) via per-turn futures
    returned by the target's ``request_response_async``.
    """

    TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements(
        required=frozenset({CapabilityName.STREAMING_BARGE_IN}),
    )

    #: Maximum time to wait after the chunk source exhausts for any in-flight VAD-committed
    #: turn to finish (commit → convert → response.create → response.done → persist). Acts as
    #: a safety cap; the attack returns as soon as the last turn actually completes.
    _MAX_POST_STREAM_WAIT_SECONDS = 30.0

    @apply_defaults
    def __init__(
        self,
        *,
        objective_target: PromptTarget = REQUIRED_VALUE,  # type: ignore[ty:invalid-parameter-default]
        attack_converter_config: AttackConverterConfig | None = None,
        params_type: type[AttackParamsT] = AttackParameters,  # type: ignore[ty:invalid-parameter-default]
    ) -> None:
        """
        Initialize the streaming barge-in attack.

        Args:
            objective_target: Target to attack. Must declare ``STREAMING_BARGE_IN`` capability.
                Audio normalization is delegated to ``objective_target.audio_normalizer``.
            attack_converter_config: Converters applied to each committed user turn.
                Defaults to an empty ``AttackConverterConfig``.
            params_type: Attack parameter dataclass type.
        """
        super().__init__(
            objective_target=objective_target,
            context_type=BargeInAttackContext,
            params_type=params_type,
            logger=logger,
        )
        attack_converter_config = attack_converter_config or AttackConverterConfig()
        self._request_converters = attack_converter_config.request_converters
        # Stored for parity with other attacks; nothing in this class reads it.
        self._response_converters = attack_converter_config.response_converters

    def _validate_context(self, *, context: BargeInAttackContext[Any]) -> None:
        """
        Validate the context before executing.

        Args:
            context: The streaming attack context.

        Raises:
            ValueError: If the objective is empty/whitespace or ``audio_chunks`` is unset.
        """
        if not context.objective or context.objective.isspace():
            raise ValueError("Attack objective must be provided and non-empty in the context")
        if context.audio_chunks is None:
            raise ValueError("BargeInAttackContext.audio_chunks must be set to an async iterator of PCM bytes")

    async def _setup_async(self, *, context: BargeInAttackContext[Any]) -> None:
        """
        Set up the attack: nothing beyond ensuring a conversation id is present.

        Args:
            context: The streaming attack context; a fresh UUID4 is assigned when
                ``conversation_id`` is empty.
        """
        if not context.conversation_id:
            context.conversation_id = str(uuid.uuid4())

    async def _teardown_async(self, *, context: BargeInAttackContext[Any]) -> None:
        """No-op teardown — connection / dispatcher are closed inside ``_perform_async``."""
        return

    async def _perform_async(self, *, context: BargeInAttackContext[Any]) -> AttackResult:
        """
        Run the streaming session: connect, subscribe, push chunks, await final turn, tear down.

        The connection is closed on every exit path once ``connect_async`` has returned,
        including a failure while subscribing to events or while stopping the dispatcher.

        Args:
            context: Streaming attack context with ``audio_chunks`` source.

        Returns:
            An ``AttackResult`` capturing the last assistant turn (if any) and the
            number of completed turns.

        Raises:
            ValueError: If ``context.audio_chunks`` is ``None``.
        """
        target = cast("RealtimeTarget", self._objective_target)
        if context.audio_chunks is None:
            raise ValueError("BargeInAttackContext.audio_chunks must be set before executing the attack.")

        connection = await target.connect_async(conversation_id=context.conversation_id)
        state = _BargeInRunState()

        async def on_committed(event: CommittedEvent) -> None:
            # Record this callback's own task so the post-stream wait can find it.
            current_task = asyncio.current_task()
            if current_task is not None:
                state.turn_tasks.append(current_task)
            try:
                await self._handle_committed_turn_async(
                    state=state,
                    event=event,
                    target=target,
                    connection=connection,
                    dispatcher=dispatcher,
                    conversation_id=context.conversation_id,
                )
            except Exception:
                logger.exception("BargeInAttack turn failed in convert-on-commit handler.")

        try:
            # Subscribing happens inside the outer try so a failure here still closes the connection.
            dispatcher: RealtimeEventDispatcher = await target.subscribe_events_async(
                connection=connection,
                on_user_audio_committed=on_committed,
            )
            try:
                await target.send_streaming_session_config_async(
                    connection=connection, system_prompt=context.system_prompt
                )

                async for chunk in context.audio_chunks:
                    # A dead dispatch loop can never resolve another turn, so stop feeding audio.
                    if dispatcher.failure is not None:
                        logger.error(
                            f"Realtime dispatch loop failed; stopping audio stream early: {dispatcher.failure}"
                        )
                        break
                    if chunk:
                        state.raw_buffer.extend(chunk)
                        await target.push_audio_chunk_async(connection=connection, pcm_bytes=chunk)

                # Wait for any in-flight committed-turn tasks to finish (convert + response +
                # persistence), capped by a safety timeout. The chunk source must end with enough
                # trailing silence for server VAD's silence threshold to fire commit — otherwise
                # the last turn never enters the convert pipeline and there is nothing to wait on.
                await self._wait_for_pending_turns_async(state.turn_tasks)
            finally:
                await dispatcher.stop()
        finally:
            # Runs even when subscribe or dispatcher.stop() raised, so the socket is not leaked.
            try:
                await connection.close()
            except Exception as e:
                logger.warning(f"Error closing streaming connection: {e}")

        return self._build_result(state=state, context=context)

    async def _handle_committed_turn_async(
        self,
        *,
        state: _BargeInRunState,
        event: CommittedEvent,
        target: RealtimeTarget,
        connection: Any,
        dispatcher: RealtimeEventDispatcher,
        conversation_id: str,
    ) -> None:
        """
        Run the convert-on-commit dance for one VAD-committed user audio turn.

        Holds ``state.turn_lock`` for the whole turn. The audio snapshot is whatever was
        buffered up to the moment the lock is acquired, so it can include audio pushed
        after the server's commit point when an earlier turn was still in progress.

        Args:
            state: Per-session run state (buffer, lock, counters).
            event: The committed-audio event that triggered this turn.
            target: Realtime target used for normalization, swapping, and the response request.
            connection: Open realtime connection for this session.
            dispatcher: Event dispatcher passed through to ``request_response_async``.
            conversation_id: Conversation id under which the turn is persisted.
        """
        async with state.turn_lock:
            snapshot = self._snapshot_user_audio(state)

            try:
                converted_pcm, applied_identifiers = await target.audio_normalizer.normalize_async(
                    pcm_bytes=snapshot,
                    sample_rate=_REALTIME_SAMPLE_RATE_HZ,
                    converter_configurations=self._request_converters,
                )
            except Exception:
                logger.exception("Audio converters failed; dropping turn.")
                return

            # Only replace the server-side item when converters exist and actually changed the audio.
            using_converted_audio = bool(self._request_converters) and converted_pcm != snapshot
            if using_converted_audio:
                await target.swap_user_audio_async(
                    connection=connection,
                    committed_event=event,
                    converted_pcm=converted_pcm,
                )

            turn_future = await target.request_response_async(connection=connection, dispatcher=dispatcher)
            turn_result = await turn_future

            user_audio_pcm = converted_pcm if using_converted_audio else snapshot
            state.last_assistant_message = await self._persist_turn_async(
                target=target,
                conversation_id=conversation_id,
                user_audio_pcm=user_audio_pcm,
                applied_converter_identifiers=applied_identifiers,
                turn_result=turn_result,
            )
            state.executed_turns += 1

    def _snapshot_user_audio(self, state: _BargeInRunState) -> bytes:
        """
        Snapshot the accumulated user PCM and clear the buffer for the next turn.

        Args:
            state: Run state whose ``raw_buffer`` is copied and then emptied.

        Returns:
            Snapshot of buffered PCM bytes prior to clearing.
        """
        snapshot = bytes(state.raw_buffer)
        state.raw_buffer.clear()
        return snapshot

    def _build_result(
        self,
        *,
        state: _BargeInRunState,
        context: BargeInAttackContext[Any],
    ) -> AttackResult:
        """
        Assemble the final ``AttackResult`` from accumulated run state.

        The outcome is always ``UNDETERMINED`` (no scorer runs here). The outcome reason
        distinguishes "VAD never committed" from "commits arrived but no turn was persisted".

        Args:
            state: Run state holding the last assistant message and turn counters.
            context: Attack context supplying objective, ids, and labels.

        Returns:
            ``AttackResult`` with the last assistant message, executed turn count, and outcome reason.
        """
        outcome_reason: str | None
        if state.executed_turns > 0:
            outcome_reason = f"{state.executed_turns} assistant turn(s) completed; no scorer configured"
        elif state.turn_tasks:
            # At least one commit callback ran, so blaming VAD would be wrong.
            outcome_reason = (
                f"No assistant turns completed ({len(state.turn_tasks)} committed user turn(s) "
                "failed or timed out before a response was persisted)"
            )
        else:
            outcome_reason = "No assistant turns completed (server VAD did not commit any user audio)"

        return AttackResult(
            conversation_id=context.conversation_id,
            objective=context.objective,
            atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()),
            last_response=(state.last_assistant_message.message_pieces[0] if state.last_assistant_message else None),
            last_score=None,
            related_conversations=context.related_conversations,
            outcome=AttackOutcome.UNDETERMINED,
            outcome_reason=outcome_reason,
            executed_turns=state.executed_turns,
            labels=context.memory_labels,
        )

    async def _wait_for_pending_turns_async(self, turn_tasks: list[asyncio.Task[None]]) -> None:
        """
        Wait for any in-flight VAD-committed turn tasks to finish, with a safety timeout.

        Returns as soon as all known turn tasks complete (or the cap elapses, whichever
        comes first). The timeout is a safety net for stuck turns; the common case is to
        return immediately once the last turn's persistence finishes.

        Args:
            turn_tasks: Task handles for every ``on_committed`` invocation launched so far.
                Tasks added after this method starts are not waited on; the dispatcher
                callback machinery makes this race vanishingly unlikely in practice.
        """
        if not turn_tasks:
            return
        try:
            await asyncio.wait_for(
                asyncio.gather(*turn_tasks, return_exceptions=True),
                timeout=self._MAX_POST_STREAM_WAIT_SECONDS,
            )
        except asyncio.TimeoutError:
            logger.warning(
                f"Timed out after {self._MAX_POST_STREAM_WAIT_SECONDS}s waiting for in-flight turn tasks to "
                "finish; teardown will cancel them. Increase _MAX_POST_STREAM_WAIT_SECONDS if responses "
                "regularly take longer."
            )

    async def _persist_turn_async(
        self,
        *,
        target: RealtimeTarget,
        conversation_id: str,
        user_audio_pcm: bytes,
        applied_converter_identifiers: list[ComponentIdentifier],
        turn_result: RealtimeTargetResult,
    ) -> Message:
        """
        Persist the user+assistant Message pair for one completed turn to CentralMemory.

        Args:
            target: Realtime target whose ``save_audio`` writes the PCM to storage.
            conversation_id: Conversation id stamped on the user piece.
            user_audio_pcm: The user audio recorded for this turn (converted or raw).
            applied_converter_identifiers: Identifiers of converters that ran on the user audio.
            turn_result: Assistant audio, transcripts, and interruption flag for the turn.

        Returns:
            The assistant Message so callers can surface it as ``last_response``.
        """
        user_audio_path = await target.save_audio(
            user_audio_pcm,
            num_channels=1,
            sample_width=2,
            sample_rate=_REALTIME_SAMPLE_RATE_HZ,
        )
        user_piece = MessagePiece(
            role="user",
            original_value=user_audio_path,
            original_value_data_type="audio_path",
            converted_value=user_audio_path,
            converted_value_data_type="audio_path",
            conversation_id=conversation_id,
        )
        user_piece.converter_identifiers.extend(applied_converter_identifiers)
        user_message = Message(message_pieces=[user_piece])

        response_audio_path = await target.save_audio(
            turn_result.audio_bytes,
            num_channels=1,
            sample_width=2,
            sample_rate=_REALTIME_SAMPLE_RATE_HZ,
        )
        text_piece = construct_response_from_request(
            request=user_piece,
            response_text_pieces=[turn_result.flatten_transcripts()],
            response_type="text",
        ).message_pieces[0]
        audio_piece = construct_response_from_request(
            request=user_piece,
            response_text_pieces=[response_audio_path],
            response_type="audio_path",
        ).message_pieces[0]
        # Mark both assistant pieces so downstream consumers can tell the turn was barged in on.
        if turn_result.interrupted:
            text_piece.prompt_metadata["interrupted"] = True
            audio_piece.prompt_metadata["interrupted"] = True
        assistant_message = Message(message_pieces=[text_piece, audio_piece])

        memory = CentralMemory.get_memory_instance()
        memory.add_message_to_memory(request=user_message)
        memory.add_message_to_memory(request=assistant_message)
        return assistant_message
class AudioStreamNormalizer:
    """
    Bridge between raw streaming PCM and PyRIT's file-based audio converters.

    A streaming attack holds PCM bytes for the turn in progress instead of a
    ``Message``. This class writes those bytes to a temporary WAV file, feeds the
    file path through each converter's ``convert_tokens_async`` with
    ``input_type="audio_path"``, and reads the PCM frames of the final output file
    back. Subclass it to change how the bridging is done.
    """

    def __init__(self, *, start_token: str = "⟪", end_token: str = "⟫") -> None:
        """Store the token delimiters that are forwarded to every converter call."""
        self._start_token = start_token
        self._end_token = end_token

    async def normalize_async(
        self,
        *,
        pcm_bytes: bytes,
        sample_rate: int,
        converter_configurations: list[PromptConverterConfiguration],
    ) -> tuple[bytes, list[ComponentIdentifier]]:
        """
        Apply the audio-capable converter configurations to ``pcm_bytes``.

        Args:
            pcm_bytes: Raw PCM16 mono audio.
            sample_rate: Sample rate in Hz.
            converter_configurations: Same shape consumed by ``PromptNormalizer.convert_values``.

        Returns:
            ``(converted_pcm, identifiers_that_ran)``. The input is returned untouched with an
            empty identifier list when it is empty or when no configuration targets ``audio_path``.

        Raises:
            ValueError: If converter output is not mono PCM16 at ``sample_rate``.
        """
        if not pcm_bytes:
            return pcm_bytes, []

        # Keep only configurations that are unrestricted or explicitly list audio_path, so
        # text-only converters on a streaming attack never trigger the WAV round trip.
        audio_configs = []
        for candidate in converter_configurations:
            if not candidate.prompt_data_types_to_apply or "audio_path" in candidate.prompt_data_types_to_apply:
                audio_configs.append(candidate)
        if not audio_configs:
            return pcm_bytes, []

        ran: list[ComponentIdentifier] = []

        with tempfile.TemporaryDirectory() as scratch_dir:
            wav_path = os.path.join(scratch_dir, "streaming_input.wav")
            with wave.open(wav_path, "wb") as writer:
                writer.setnchannels(1)
                writer.setsampwidth(2)
                writer.setframerate(sample_rate)
                writer.writeframes(pcm_bytes)

            # Flatten config → converter into one ordered pipeline; each step consumes the
            # previous step's output path.
            pipeline = (step for cfg in audio_configs for step in cfg.converters)
            for step in pipeline:
                parent = get_execution_context()
                strategy_name = parent.attack_strategy_name if parent else None
                attack_id = parent.attack_identifier if parent else None
                target_conversation = parent.objective_target_conversation_id if parent else None
                with execution_context(
                    component_role=ComponentRole.CONVERTER,
                    attack_strategy_name=strategy_name,
                    attack_identifier=attack_id,
                    component_identifier=step.get_identifier(),
                    objective_target_conversation_id=target_conversation,
                ):
                    outcome = await step.convert_tokens_async(
                        prompt=wav_path,
                        input_type="audio_path",
                        start_token=self._start_token,
                        end_token=self._end_token,
                    )
                    wav_path = outcome.output_text
                    ran.append(step.get_identifier())

            with wave.open(wav_path, "rb") as reader:
                channels = reader.getnchannels()
                width = reader.getsampwidth()
                rate = reader.getframerate()
                format_ok = channels == 1 and width == 2 and rate == sample_rate
                if not format_ok:
                    raise ValueError(
                        "Converter output incompatible with streaming target: "
                        f"expected mono PCM16 @ {sample_rate} Hz, got channels={channels} "
                        f"sampwidth={width} rate={rate}."
                    )
                frames = reader.readframes(reader.getnframes())
                return frames, ran
+ +"""Shared types for realtime audio prompt targets.""" + +import asyncio +import contextlib +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ServerVadConfig: + """Server-side voice activity detection (VAD) tuning for realtime audio targets.""" + + threshold: float = 0.4 + prefix_padding_ms: int = 200 + silence_duration_ms: int = 1500 + + def __post_init__(self) -> None: + """ + Validate VAD tuning values. + + Raises: + ValueError: If any field is outside its valid range. + """ + if not 0.0 <= self.threshold <= 1.0: + raise ValueError(f"threshold must be in [0.0, 1.0], got {self.threshold}") + if self.prefix_padding_ms < 0: + raise ValueError(f"prefix_padding_ms must be non-negative, got {self.prefix_padding_ms}") + if self.silence_duration_ms < 0: + raise ValueError(f"silence_duration_ms must be non-negative, got {self.silence_duration_ms}") + + +@dataclass +class RealtimeTargetResult: + """Result of a Realtime API turn: delivered audio, transcripts, and interruption status.""" + + audio_bytes: bytes = b"" + transcripts: list[str] = field(default_factory=list) + interrupted: bool = False + + def flatten_transcripts(self) -> str: + """Return all transcript deltas concatenated into a single string.""" + return "".join(self.transcripts) + + +@dataclass +class RealtimeTurnState: + """Mutable per-turn state assembled by the dispatcher from incoming events.""" + + completion: asyncio.Future[RealtimeTargetResult] + is_responding: bool = False + delivered_audio: bytearray = field(default_factory=bytearray) + delivered_transcripts: list[str] = field(default_factory=list) + current_item_id: str | None = None + last_response_id: str | None = None + interrupted: bool = False + + +@dataclass(frozen=True) +class CommittedEvent: + """Payload passed to ``on_user_audio_committed`` 
@dataclass(frozen=True)
class CommittedEvent:
    """Payload passed to ``on_user_audio_committed`` callbacks when server VAD commits."""

    # Id of the committed conversation item.
    item_id: str
    # Start offset of the committed audio in milliseconds, when the provider reports one.
    audio_start_ms: int | None = None


class RealtimeEventDispatcher(ABC):
    """
    Owns a realtime connection's event stream and routes events to the active turn.

    Provider-specific event routing and cancel logic are isolated to the abstract methods.
    """

    def __init__(
        self,
        *,
        connection: Any,
        on_user_audio_committed: Callable[[CommittedEvent], Coroutine[Any, Any, None]] | None = None,
    ) -> None:
        """
        Args:
            connection: An open realtime connection exposing an async iterator
                of server events. The dispatcher owns reading from it.
            on_user_audio_committed: Optional callback fired when the server
                commits a user audio buffer (e.g. server VAD finalizing a turn).
                Invoked as a background task so converter work in the callback
                does not block the dispatch loop. Default None disables it.
        """
        self._connection = connection
        self._on_user_audio_committed = on_user_audio_committed
        self._current_turn: RealtimeTurnState | None = None
        self._task: asyncio.Task[None] | None = None
        self._callback_tasks: set[asyncio.Task[None]] = set()
        self._failure: BaseException | None = None

    @property
    def failure(self) -> BaseException | None:
        """
        The exception that killed the dispatch loop, or None if it is still healthy.

        Set when the outer event iterator raises. Callers (e.g. ``BargeInAttack``)
        poll this between operations to detect a dead connection without needing a
        callback. Once set, ``stop()`` should be called and the attack torn down.
        """
        return self._failure

    async def start(self) -> None:
        """Start the background dispatch task. Idempotent."""
        if self._task is None:
            self._task = asyncio.create_task(self._dispatch_loop())

    async def stop(self) -> None:
        """
        Cancel the background dispatch task and release the reference.

        In-flight callback tasks are cancelled and awaited (with exception
        suppression) so they don't deadlock waiting on the turn future that the
        now-dead dispatch loop would have resolved.
        """
        if self._task is not None:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._task
            self._task = None
        if self._callback_tasks:
            # Copy first: each task's done-callback mutates the set as it finishes.
            pending = list(self._callback_tasks)
            self._callback_tasks.clear()
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

    def register_turn(self, state: RealtimeTurnState) -> None:
        """
        Bind a new turn as the active turn.

        Args:
            state (RealtimeTurnState): The turn whose completion future will be
                resolved when this turn ends.

        Raises:
            RuntimeError: If another turn is already active on this dispatcher.
        """
        if self._current_turn is not None and not self._current_turn.completion.done():
            raise RuntimeError("Another turn is already active on this dispatcher")
        self._current_turn = state

    async def _dispatch_loop(self) -> None:
        """
        Consume events from the connection and route each to the active turn.

        The router is called for every event with the current turn (which may
        be None during the gap between turns). Concrete routers are expected to
        handle ``state is None`` for input-side events that need no turn state
        and return early on output-side events when no turn is registered.

        If the connection's event iterator ends without raising while a turn is
        still unresolved, that turn's ``completion`` is failed with a
        ``ConnectionError`` so awaiters do not wait forever on a stream that will
        never deliver another event.

        Raises:
            asyncio.CancelledError: Propagated when ``stop()`` cancels the task.
        """
        try:
            async for event in self._connection:
                turn = self._current_turn
                # A finished turn must not receive further events.
                if turn is not None and turn.completion.done():
                    turn = None
                try:
                    await self._route_event(event=event, state=turn)
                except Exception as e:
                    logger.exception(f"Realtime event router raised: {e}")
                    if turn is not None and not turn.completion.done():
                        turn.completion.set_exception(e)

            # The iterator ended without an exception (e.g. the peer closed the stream).
            # Nothing will resolve an in-flight turn after this point, so fail it now.
            stranded = self._current_turn
            if stranded is not None and not stranded.completion.done():
                stranded.completion.set_exception(
                    ConnectionError("Realtime event stream ended before the active turn completed")
                )
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.exception(f"Realtime dispatch loop crashed: {e}")
            self._failure = e
            turn = self._current_turn
            if turn is not None and not turn.completion.done():
                turn.completion.set_exception(e)

    def _fire_committed_callback(self, event: CommittedEvent) -> None:
        """
        Schedule the ``on_user_audio_committed`` callback as a background task.

        Tracks the resulting task so ``stop()`` can cancel and await it; the task
        removes itself from the tracking set when it finishes.

        Args:
            event: The committed-audio payload handed to the callback.
        """
        if self._on_user_audio_committed is None:
            return
        task = asyncio.create_task(self._on_user_audio_committed(event))
        self._callback_tasks.add(task)
        task.add_done_callback(self._callback_tasks.discard)

    @abstractmethod
    async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None:
        """
        Route a single provider-specific event.

        Concrete implementations:
        - When the event is output-side (response lifecycle, audio/transcript
          deltas, etc.) and ``state`` is non-None, mutate ``state`` and resolve
          ``state.completion`` at end-of-turn or on interruption.
        - When ``state`` is None (no active turn) or
          ``state.completion.done()``, output-side events should be dropped.
        - When the event is input-side (e.g. ``input_audio_buffer.committed``),
          fire any subscribed callback via ``self._fire_committed_callback(...)``.
          These callbacks may run regardless of ``state``.
        - On error events, resolve ``state.completion`` via ``set_exception``
          when a turn is active.

        Args:
            event: A single provider-specific event from the connection iterator.
            state (RealtimeTurnState | None): The currently-active turn, or None
                if no turn is registered (e.g. between turns in a streaming
                session).
        """

    @abstractmethod
    async def _cancel(self, *, state: RealtimeTurnState) -> None:
        """
        Send provider-specific cancel and truncate events for the in-flight response.

        Must set ``state.interrupted = True`` even on wire-call failure so callers
        can tell the turn was cut short. Must not resolve ``state.completion``;
        that is the dispatcher's responsibility.

        Args:
            state (RealtimeTurnState): The turn whose response should be cancelled.
        """
input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 3deffe6287..0f9471bb69 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -6,8 +6,8 @@ import logging import re import wave -from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, Literal, Optional from openai import AsyncOpenAI @@ -22,11 +22,21 @@ data_serializer_factory, ) from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeEventDispatcher, + RealtimeTargetResult, + RealtimeTurnState, + ServerVadConfig, +) from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.openai.openai_target import OpenAITarget +if TYPE_CHECKING: + from pyrit.prompt_normalizer import AudioStreamNormalizer + logger = logging.getLogger(__name__) # Voices supported by the OpenAI Realtime API. @@ -35,29 +45,6 @@ RealTimeVoice = Literal["alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse", "marin", "cedar"] -@dataclass -class RealtimeTargetResult: - """ - Represents the result of a Realtime API request, containing audio data and transcripts. 
- - Attributes: - audio_bytes: Raw audio data returned by the API - transcripts: List of text transcripts generated from the audio - """ - - audio_bytes: bytes = field(default_factory=lambda: b"") - transcripts: list[str] = field(default_factory=list) - - def flatten_transcripts(self) -> str: - """ - Flattens the list of transcripts into a single string. - - Returns: - A single string containing all transcripts concatenated together. - """ - return "".join(self.transcripts) - - class RealtimeTarget(OpenAITarget, PromptTarget): """ A prompt target for Azure OpenAI Realtime API. @@ -75,6 +62,7 @@ class RealtimeTarget(OpenAITarget, PromptTarget): supports_editable_history=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_streaming_barge_in=True, input_modalities=frozenset( { frozenset(["text"]), @@ -97,6 +85,8 @@ def __init__( voice: Optional[RealTimeVoice] = None, existing_convo: Optional[dict[str, Any]] = None, custom_configuration: Optional[TargetConfiguration] = None, + server_vad: bool | ServerVadConfig = False, + audio_normalizer: Optional["AudioStreamNormalizer"] = None, **kwargs: Any, ) -> None: """ @@ -120,6 +110,15 @@ def __init__( existing_convo (dict[str, websockets.WebSocketClientProtocol], Optional): Existing conversations. custom_configuration (TargetConfiguration, Optional): Override the default configuration for this target instance. Defaults to None. + server_vad (bool | ServerVadConfig): Server-side voice activity detection (VAD). + ``False`` (default) keeps the existing atomic send/receive behavior. + ``True`` enables VAD with default tuning. + Pass a ``ServerVadConfig`` to enable with custom tuning. Streaming/interruption plumbing + arrives in subsequent changes; this currently only affects the emitted session config. + audio_normalizer (AudioStreamNormalizer, Optional): Normalizer applied to raw PCM + mid-turn before it is sent back into the conversation. 
Defaults to a stock + ``AudioStreamNormalizer`` that bridges PCM to PyRIT's file-based converter + pipeline. Override to plug in custom format adaptation. **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` constructor. For example, to specify a 3 minute timeout: ``httpx_client_kwargs={"timeout": 180}`` @@ -130,6 +129,17 @@ def __init__( self._existing_conversation = existing_convo if existing_convo is not None else {} self._realtime_client: Optional[AsyncOpenAI] = None + if isinstance(server_vad, ServerVadConfig): + self._server_vad: Optional[ServerVadConfig] = server_vad + elif server_vad: + self._server_vad = ServerVadConfig() + else: + self._server_vad = None + + from pyrit.prompt_normalizer import AudioStreamNormalizer + + self.audio_normalizer: AudioStreamNormalizer = audio_normalizer or AudioStreamNormalizer() + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_REALTIME_MODEL" self.endpoint_environment_variable = "OPENAI_REALTIME_ENDPOINT" @@ -241,7 +251,7 @@ def _get_openai_client(self) -> AsyncOpenAI: return self._realtime_client - async def connect(self, conversation_id: str) -> Any: + async def connect_async(self, conversation_id: str) -> Any: """ Connect to Realtime API using AsyncOpenAI client and return the realtime connection. 
@@ -290,6 +300,16 @@ def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict[str, An }, } + if self._server_vad is not None: + session_config["audio"]["input"]["turn_detection"] = { # type: ignore[ty:invalid-assignment] + "type": "server_vad", + "threshold": self._server_vad.threshold, + "prefix_padding_ms": self._server_vad.prefix_padding_ms, + "silence_duration_ms": self._server_vad.silence_duration_ms, + "create_response": True, + "interrupt_response": True, + } + if self.voice: session_config["audio"]["output"]["voice"] = self.voice # type: ignore[ty:invalid-assignment] @@ -359,7 +379,7 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me message = normalized_conversation[-1] conversation_id = message.message_pieces[0].conversation_id if conversation_id not in self._existing_conversation: - connection = await self.connect(conversation_id=conversation_id) + connection = await self.connect_async(conversation_id=conversation_id) self._existing_conversation[conversation_id] = connection # Only send config when creating a new connection @@ -393,6 +413,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me request=request, response_text_pieces=[output_audio_path], response_type="audio_path" ).message_pieces[0] + if result.interrupted: + text_response_piece.prompt_metadata["interrupted"] = True + audio_response_piece.prompt_metadata["interrupted"] = True + response_entry = Message(message_pieces=[text_response_piece, audio_response_piece]) return [response_entry] @@ -476,6 +500,230 @@ async def send_response_create(self, conversation_id: str) -> None: connection = self._get_connection(conversation_id=conversation_id) await connection.response.create() + async def push_audio_chunk_async(self, *, connection: Any, pcm_bytes: bytes) -> None: + """ + Append a single PCM16 mono @ 24 kHz audio chunk to the server's input buffer. + + Used by streaming-style callers (e.g. 
``BargeInAttack``) that source chunks + from an iterator and want to control commit timing externally. Server VAD, + when enabled on the session, decides when to commit and fire response logic. + Empty buffers are accepted as no-ops. + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + pcm_bytes: Raw PCM16 mono audio for this chunk. + """ + if not pcm_bytes: + return + audio_b64 = base64.b64encode(pcm_bytes).decode("ascii") + await connection.input_audio_buffer.append(audio=audio_b64) + + async def insert_user_audio_async(self, *, connection: Any, pcm_bytes: bytes) -> None: + """ + Insert a user message containing the given PCM16 mono @ 24 kHz audio into the conversation. + + Use for the convert-on-commit dance — after deleting the server's raw user item, + the attack inserts the converted audio via this method before manually triggering + ``response.create``. + + Args: + connection: Active Realtime API connection. + pcm_bytes: Converted PCM16 mono audio. + """ + audio_b64 = base64.b64encode(pcm_bytes).decode("ascii") + await connection.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [{"type": "input_audio", "audio": audio_b64}], + } + ) + + async def insert_user_text_async(self, *, connection: Any, text: str) -> None: + """ + Insert a user message containing the given text into the conversation. + + Lets streaming attacks mix text turns into an otherwise audio-driven session. + The caller is responsible for triggering ``response.create`` after insertion. + + Args: + connection: Active Realtime API connection. + text: User-side text content. + """ + await connection.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}], + } + ) + + async def delete_conversation_item_async(self, *, connection: Any, item_id: str) -> None: + """ + Delete a conversation item by id (e.g. the server's raw user audio item).
+ + Used during convert-on-commit to remove the raw audio item before replacing + it with a converted one. Errors are propagated; callers that want best-effort + deletion should wrap with ``contextlib.suppress``. + + Args: + connection: Active Realtime API connection. + item_id: Server-assigned item id to delete. + """ + await connection.conversation.item.delete(item_id=item_id) + + async def swap_user_audio_async( + self, + *, + connection: Any, + committed_event: CommittedEvent, + converted_pcm: bytes, + ) -> None: + """ + Replace the server's just-committed user audio with converted PCM. + + Inserts ``converted_pcm`` as a new user item and best-effort deletes the original + item identified by ``committed_event``. Hides OpenAI's item-id concept from + callers so streaming attacks can stay provider-agnostic. + + Args: + connection: Active Realtime API connection. + committed_event: Payload received in the on-committed callback. + converted_pcm: PCM16 mono @ 24 kHz audio to insert in place of the original. + """ + await self.insert_user_audio_async(connection=connection, pcm_bytes=converted_pcm) + try: + await self.delete_conversation_item_async(connection=connection, item_id=committed_event.item_id) + except Exception as e: + logger.warning(f"conversation.item.delete failed for {committed_event.item_id}: {e}") + + async def subscribe_events_async( + self, + *, + connection: Any, + on_user_audio_committed: (Callable[[CommittedEvent], Coroutine[Any, Any, None]] | None) = None, + ) -> RealtimeEventDispatcher: + """ + Start consuming events from the connection and route them via the OpenAI dispatcher. + + Streaming-style callers (``BargeInAttack``) use this to receive normalized + events (``user_audio_committed``). The returned dispatcher exposes + ``stop()`` to tear down the background task and drain in-flight callback + tasks, and a ``failure`` property that callers can poll between operations + to detect a dead dispatch loop (e.g. websocket closed). 
Callers should + call ``stop()`` before closing the connection. + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + on_user_audio_committed: Async callback fired when server VAD finalizes + a user audio buffer. Called as a background task. + + Returns: + The started dispatcher. Pass it to ``request_response_async`` for turn + futures, poll ``failure`` for dispatch-loop errors, and call ``stop()`` + to tear it down. + """ + dispatcher = _OpenAIRealtimeDispatcher( + connection=connection, + on_user_audio_committed=on_user_audio_committed, + ) + await dispatcher.start() + return dispatcher + + async def request_response_async( + self, + *, + connection: Any, + dispatcher: RealtimeEventDispatcher, + ) -> asyncio.Future[RealtimeTargetResult]: + """ + Trigger ``response.create`` and return a future that resolves when the turn ends. + + Constructs a fresh ``RealtimeTurnState``, binds it to the dispatcher as the + active turn, then sends ``response.create``. The dispatcher resolves the + returned future via ``response.done`` (with ``interrupted=False``) or via + the barge-in cancel path (with ``interrupted=True``). + + Args: + connection: Active Realtime API connection. + dispatcher: Subscription handle previously returned by + ``subscribe_events_async``. Must not have another turn pending. + + Returns: + Future resolved with the assembled ``RealtimeTargetResult`` when this + turn ends (normally or via barge-in); fails with ``RuntimeError`` on an API ``error`` event. + + Raises: + RuntimeError: If another turn is already pending on the dispatcher. + """ + state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future()) + dispatcher.register_turn(state) + await connection.response.create() + return state.completion + + async def send_streaming_session_config_async(self, *, connection: Any, system_prompt: str) -> None: + """ + Configure the realtime session for streaming use: server VAD with manual response creation.
+ + Emits the same session config as the atomic path except ``turn_detection.create_response`` + is forced to False so the streaming attack can swap the raw user audio item for converted + audio before triggering ``response.create``. + + Args: + connection: Active Realtime API connection. + system_prompt: System prompt for the realtime session. + + Raises: + ValueError: If the target was constructed without server VAD. + """ + if self._server_vad is None: + raise ValueError( + "send_streaming_session_config_async requires server VAD; " + "construct RealtimeTarget(server_vad=True) or pass a ServerVadConfig." + ) + config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) + turn_detection = config.get("audio", {}).get("input", {}).get("turn_detection") + if turn_detection is not None: + turn_detection["create_response"] = False + await connection.session.update(session=config) + + async def _stream_pcm_async( + self, + *, + connection: Any, + pcm_bytes: bytes, + commit: bool, + chunk_ms: int = 100, + sample_rate: int = 24000, + ) -> None: + """ + Stream raw PCM16 audio to the Realtime API as ``input_audio_buffer.append`` chunks. + + Operates on raw PCM bytes (not WAV) so this helper can back both the + WAV-file path and future per-frame streaming consumers (e.g. browser audio + forwarded by a GUI backend). Caller decides whether to manually commit; + server VAD commits automatically when enabled. + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + pcm_bytes (bytes): Raw PCM16 mono audio. Empty buffers are accepted + and result in zero appends. + commit (bool): When True, sends ``input_audio_buffer.commit`` after the + final chunk. Pass False when server VAD is committing automatically. + chunk_ms (int): Milliseconds of audio per chunk. Defaults to 100. + sample_rate (int): PCM sample rate in Hz. Defaults to 24000.
+ """ + bytes_per_sample = 2 # PCM16 + chunk_size = (chunk_ms * sample_rate * bytes_per_sample) // 1000 + + for offset in range(0, len(pcm_bytes), chunk_size): + chunk = pcm_bytes[offset : offset + chunk_size] + audio_b64 = base64.b64encode(chunk).decode("ascii") + await connection.input_audio_buffer.append(audio=audio_b64) + + if commit: + await connection.input_audio_buffer.commit() + async def receive_events(self, conversation_id: str) -> RealtimeTargetResult: """ Continuously receive events from the OpenAI Realtime API connection. @@ -806,3 +1054,121 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> This implementation exists to satisfy the abstract base class requirement. """ raise NotImplementedError("RealtimeTarget uses receive_events for message construction") + + +class _OpenAIRealtimeDispatcher(RealtimeEventDispatcher): + """ + Concrete ``RealtimeEventDispatcher`` for the OpenAI Realtime API. + + Routes OpenAI server events into the active ``RealtimeTurnState`` and issues + ``conversation.item.truncate`` on barge-in; the server cancels the response itself. + """ + + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + """Route an OpenAI Realtime event to the active turn or to an input-side callback.""" + event_type = getattr(event, "type", "") + + # Input-side events fire callbacks regardless of whether a turn is registered. + if event_type == "input_audio_buffer.committed": + item_id = getattr(event, "item_id", None) + if item_id is None: + return + self._fire_committed_callback( + CommittedEvent( + item_id=item_id, + audio_start_ms=getattr(event, "audio_start_ms", None), + ) + ) + # No per-turn bookkeeping uses the committed event, so stop here instead of falling through. + return + + # Remaining events read or mutate per-turn state; drop them if no turn is active.
+ if state is None or state.completion.done(): + return + + if event_type == "response.created": + state.is_responding = True + response = getattr(event, "response", None) + if response is not None: + state.last_response_id = getattr(response, "id", None) + return + + if event_type in ("response.output_item.added", "response.output_item.created"): + item = getattr(event, "item", None) + if item is not None: + state.current_item_id = getattr(item, "id", None) + return + + if event_type in ("response.audio.delta", "response.output_audio.delta"): + delta = getattr(event, "delta", "") + if delta: + state.delivered_audio.extend(base64.b64decode(delta)) + return + + if event_type in ("response.audio_transcript.delta", "response.output_audio_transcript.delta"): + delta = getattr(event, "delta", "") + if delta: + state.delivered_transcripts.append(delta) + return + + if event_type == "response.done": + response = getattr(event, "response", None) + done_response_id = getattr(response, "id", None) if response is not None else None + if state.last_response_id is not None and done_response_id != state.last_response_id: + # Stale event from a cancelled response; drop without resolving. 
+ return + state.is_responding = False + state.completion.set_result( + RealtimeTargetResult( + audio_bytes=bytes(state.delivered_audio), + transcripts=list(state.delivered_transcripts), + ) + ) + return + + if event_type == "input_audio_buffer.speech_started" and state.is_responding: + await self._cancel(state=state) + state.is_responding = False + state.completion.set_result( + RealtimeTargetResult( + audio_bytes=bytes(state.delivered_audio), + transcripts=list(state.delivered_transcripts), + interrupted=True, + ) + ) + return + + if event_type == "error": + error = getattr(event, "error", None) + message = getattr(error, "message", "unknown") if error is not None else "unknown" + state.completion.set_exception(RuntimeError(f"Realtime API error: {message}")) + return + + async def _cancel(self, *, state: RealtimeTurnState) -> None: + """ + Truncate the in-flight response's conversation item to what was actually delivered. + + The server auto-cancels the response when it detects new speech, so we only need to + trim the conversation history to match the audio we received. + + Marks ``state.interrupted = True`` even when the truncate call fails. + Does not resolve ``state.completion``; the caller (``_route_event``) does that. + + Args: + state (RealtimeTurnState): The turn whose response should be cancelled. + """ + if state.current_item_id is not None: + # PCM16 @ 24 kHz: 48 bytes per millisecond. 
+ audio_end_ms = len(state.delivered_audio) // 48 + try: + await self._connection.conversation.item.truncate( + item_id=state.current_item_id, + content_index=0, + audio_end_ms=audio_end_ms, + ) + except Exception as e: + logger.warning( + f"conversation.item.truncate failed for item {state.current_item_id} " + f"(audio_end_ms={audio_end_ms}): {e}" + ) + state.interrupted = True diff --git a/tests/unit/executor/attack/streaming/__init__.py b/tests/unit/executor/attack/streaming/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/executor/attack/streaming/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/executor/attack/streaming/test_barge_in.py b/tests/unit/executor/attack/streaming/test_barge_in.py new file mode 100644 index 0000000000..b8f8cb81b7 --- /dev/null +++ b/tests/unit/executor/attack/streaming/test_barge_in.py @@ -0,0 +1,628 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Unit tests for ``BargeInAttack`` and supporting helpers.""" + +from __future__ import annotations + +import asyncio +import os +import tempfile +import wave +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.executor.attack import BargeInAttack, BargeInAttackContext +from pyrit.executor.attack.core import AttackConverterConfig, AttackParameters +from pyrit.identifiers import ComponentIdentifier +from pyrit.models import AttackOutcome +from pyrit.prompt_normalizer import AudioStreamNormalizer, PromptConverterConfiguration +from pyrit.prompt_target import RealtimeTarget +from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeTargetResult, +) + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + +_CLEAN_ENV = {"OPENAI_REALTIME_UNDERLYING_MODEL": ""} + + +@pytest.fixture +@patch.dict("os.environ", _CLEAN_ENV) +def vad_target(sqlite_instance): + return RealtimeTarget(api_key="test_key", endpoint="wss://test_url", model_name="test", server_vad=True) + + +async def _aiter(chunks: list[bytes]) -> AsyncIterator[bytes]: + for c in chunks: + yield c + + +def _attack_context(*, audio_chunks: AsyncIterator[bytes], objective: str = "obj") -> BargeInAttackContext[Any]: + return BargeInAttackContext( + params=AttackParameters(objective=objective), + audio_chunks=audio_chunks, + ) + + +def _mock_connection() -> AsyncMock: + connection = AsyncMock() + connection.input_audio_buffer.append = AsyncMock() + connection.conversation.item.create = AsyncMock() + connection.conversation.item.delete = AsyncMock() + connection.response.create = AsyncMock() + connection.session.update = AsyncMock() + connection.close = AsyncMock() + return connection + + +# ---- Construction validation ----------------------------------------------------------------- + + +@patch.dict("os.environ", _CLEAN_ENV) +def test_constructor_rejects_target_without_streaming_capability(sqlite_instance): 
+ """A target whose capabilities lack STREAMING_BARGE_IN must be rejected at construction.""" + from pyrit.prompt_target import OpenAIChatTarget + + no_streaming = OpenAIChatTarget(api_key="k", endpoint="https://x", model_name="m") + with pytest.raises(Exception, match="streaming_barge_in"): + BargeInAttack(objective_target=no_streaming) + + +def test_constructor_succeeds_with_vad_target(vad_target): + """A RealtimeTarget declares STREAMING_BARGE_IN — construction succeeds.""" + attack = BargeInAttack(objective_target=vad_target) + assert attack.get_objective_target() is vad_target + + +def test_constructor_succeeds_even_without_server_vad_enabled(sqlite_instance): + """Capability check passes; server VAD is a runtime config concern surfaced when used.""" + with patch.dict("os.environ", _CLEAN_ENV): + no_vad = RealtimeTarget(api_key="k", endpoint="wss://test_url", model_name="test") + # Construction succeeds — capability is about the target type, not server_vad config. + attack = BargeInAttack(objective_target=no_vad) + assert attack.get_objective_target() is no_vad + + +# ---- Context validation ---------------------------------------------------------------------- + + +async def test_validate_context_requires_objective(vad_target): + attack = BargeInAttack(objective_target=vad_target) + ctx = BargeInAttackContext( + params=AttackParameters(objective=""), + audio_chunks=_aiter([b"\x00" * 96]), + ) + with pytest.raises(ValueError, match="objective"): + attack._validate_context(context=ctx) + + +async def test_validate_context_requires_audio_chunks(vad_target): + attack = BargeInAttack(objective_target=vad_target) + ctx = BargeInAttackContext( + params=AttackParameters(objective="o"), + audio_chunks=None, + ) + with pytest.raises(ValueError, match="audio_chunks"): + attack._validate_context(context=ctx) + + +# ---- Streaming loop end-to-end --------------------------------------------------------------- + + +async def 
test_perform_async_streams_chunks_and_tears_down(vad_target): + """Happy path: connect, send config, subscribe, push chunks, stop, close — no commits.""" + attack = BargeInAttack(objective_target=vad_target) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + dispatcher = AsyncMock() + dispatcher.stop = AsyncMock() + vad_target.subscribe_events_async = AsyncMock(return_value=dispatcher) + + chunks = [b"\x11" * 480, b"\x22" * 480, b"\x33" * 240] + ctx = _attack_context(audio_chunks=_aiter(chunks)) + + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + result = await attack._perform_async(context=ctx) + + vad_target.connect_async.assert_awaited_once_with(conversation_id=ctx.conversation_id) + vad_target.send_streaming_session_config_async.assert_awaited_once() + vad_target.subscribe_events_async.assert_awaited_once() + assert vad_target.push_audio_chunk_async.await_count == len(chunks) + pushed = [call.kwargs["pcm_bytes"] for call in vad_target.push_audio_chunk_async.await_args_list] + assert pushed == chunks + dispatcher.stop.assert_awaited_once() + connection.close.assert_awaited_once() + assert result.executed_turns == 0 + assert result.outcome == AttackOutcome.UNDETERMINED + + +async def test_perform_async_fires_request_response_on_commit(vad_target): + """A commit event must drive request_response_async and increment the turn counter.""" + attack = BargeInAttack(objective_target=vad_target) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + + # Capture the registered on_user_audio_committed so we can drive it. 
+ captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + + expected = RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["hello"]) + expected_future: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + expected_future.set_result(expected) + vad_target.request_response_async = AsyncMock(return_value=expected_future) + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield b"\x00" * 480 + # Drive a fake commit mid-stream. + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_1"))) + + ctx = _attack_context(audio_chunks=chunks_then_commit()) + + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + result = await attack._perform_async(context=ctx) + + vad_target.request_response_async.assert_awaited_once() + assert result.executed_turns == 1 + assert "1 assistant turn" in (result.outcome_reason or "") + + +async def test_perform_async_stops_dispatcher_even_on_exception(vad_target): + """If the chunk loop raises, dispatcher.stop() and connection.close() still run.""" + attack = BargeInAttack(objective_target=vad_target) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock(side_effect=RuntimeError("push exploded")) + dispatcher = AsyncMock() + vad_target.subscribe_events_async = AsyncMock(return_value=dispatcher) + + ctx = _attack_context(audio_chunks=_aiter([b"\x00" * 96])) + + with pytest.raises(RuntimeError, match="push exploded"): + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + await attack._perform_async(context=ctx) + + dispatcher.stop.assert_awaited_once() + connection.close.assert_awaited_once() + + +# ---- 
send_streaming_session_config_async (target-side helper added in R4a) ------------------- + + +async def test_send_streaming_session_config_async_emits_create_response_false(vad_target): + """The streaming session config must flip create_response to False on turn_detection.""" + connection = _mock_connection() + await vad_target.send_streaming_session_config_async(connection=connection, system_prompt="hi") + connection.session.update.assert_awaited_once() + config = connection.session.update.call_args.kwargs["session"] + assert config["audio"]["input"]["turn_detection"]["create_response"] is False + + +@patch.dict("os.environ", _CLEAN_ENV) +async def test_send_streaming_session_config_async_requires_server_vad(sqlite_instance): + """Without server VAD, sending streaming session config must raise.""" + no_vad = RealtimeTarget(api_key="k", endpoint="wss://test_url", model_name="test") + connection = _mock_connection() + with pytest.raises(ValueError, match="server VAD"): + await no_vad.send_streaming_session_config_async(connection=connection, system_prompt="hi") + + +# Placeholder for R4b tests + + +# ---- Convert-on-commit dance (R4b) ---------------------------------------------------------- + + +def _make_audio_converter(transformer, *, identifier_name: str = "MockAudioConverter"): + """Mock audio converter whose convert_tokens_async runs transformer(pcm) and emits a new WAV path.""" + converter = MagicMock() + converter.get_identifier = MagicMock( + return_value=ComponentIdentifier(class_name=identifier_name, class_module="tests.unit.mocks"), + ) + + async def _convert(*, prompt, input_type, start_token=None, end_token=None): + assert input_type == "audio_path" + with wave.open(prompt, "rb") as wf_in: + sample_rate = wf_in.getframerate() + pcm = wf_in.readframes(wf_in.getnframes()) + new_pcm = transformer(pcm) + out_dir = tempfile.mkdtemp() + out_path = os.path.join(out_dir, "out.wav") + with wave.open(out_path, "wb") as wf_out: + wf_out.setnchannels(1) + 
wf_out.setsampwidth(2) + wf_out.setframerate(sample_rate) + wf_out.writeframes(new_pcm) + result = MagicMock() + result.output_text = out_path + return result + + converter.convert_tokens_async = AsyncMock(side_effect=_convert) + return converter + + +def _converter_config(converters: list[Any]) -> AttackConverterConfig: + """Wrap a list of converters into an AttackConverterConfig.""" + return AttackConverterConfig( + request_converters=PromptConverterConfiguration.from_converters(converters=converters), + ) + + +async def test_perform_async_swaps_raw_item_when_converters_change_audio(vad_target): + """When converters change the audio, the attack must delete the raw item + insert converted.""" + bump = _make_audio_converter(lambda pcm: bytes((b + 1) & 0xFF for b in pcm)) + attack = BargeInAttack(objective_target=vad_target, attack_converter_config=_converter_config([bump])) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.delete_conversation_item_async = AsyncMock() + vad_target.insert_user_audio_async = AsyncMock() + + captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + + result_future: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + result_future.set_result(RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["ok"])) + vad_target.request_response_async = AsyncMock(return_value=result_future) + + raw_chunk = b"\x05" * 96 # PCM16 sample-aligned + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield raw_chunk + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_99"))) + + ctx = BargeInAttackContext( + 
params=AttackParameters(objective="obj"), + audio_chunks=chunks_then_commit(), + ) + + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + result = await attack._perform_async(context=ctx) + + vad_target.delete_conversation_item_async.assert_awaited_once_with(connection=connection, item_id="raw_99") + vad_target.insert_user_audio_async.assert_awaited_once() + inserted_pcm = vad_target.insert_user_audio_async.call_args.kwargs["pcm_bytes"] + assert inserted_pcm == bytes((b + 1) & 0xFF for b in raw_chunk) + vad_target.request_response_async.assert_awaited_once() + assert result.executed_turns == 1 + + +async def test_perform_async_skips_swap_when_no_converters(vad_target): + """Empty converter list: don't delete raw, don't insert converted, just request response.""" + attack = BargeInAttack(objective_target=vad_target) # no converter config + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.delete_conversation_item_async = AsyncMock() + vad_target.insert_user_audio_async = AsyncMock() + + captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + result_future: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + result_future.set_result(RealtimeTargetResult(audio_bytes=b"", transcripts=[])) + vad_target.request_response_async = AsyncMock(return_value=result_future) + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield b"\x00" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_42"))) + + ctx = BargeInAttackContext( + params=AttackParameters(objective="obj"), + audio_chunks=chunks_then_commit(), + ) + + with 
patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + result = await attack._perform_async(context=ctx) + + vad_target.delete_conversation_item_async.assert_not_called() + vad_target.insert_user_audio_async.assert_not_called() + vad_target.request_response_async.assert_awaited_once() + assert result.executed_turns == 1 + + +async def test_perform_async_clears_raw_buffer_between_commits(vad_target): + """A commit must snapshot+reset the raw buffer so the next turn doesn't see prior audio.""" + bump = _make_audio_converter(lambda pcm: bytes((b + 1) & 0xFF for b in pcm)) + attack = BargeInAttack(objective_target=vad_target, attack_converter_config=_converter_config([bump])) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.delete_conversation_item_async = AsyncMock() + vad_target.insert_user_audio_async = AsyncMock() + + captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + + def _future_with(result: RealtimeTargetResult) -> asyncio.Future[RealtimeTargetResult]: + fut: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + fut.set_result(result) + return fut + + vad_target.request_response_async = AsyncMock( + side_effect=lambda **_: _future_with(RealtimeTargetResult(audio_bytes=b"", transcripts=[])) + ) + + async def chunks_then_two_commits() -> AsyncIterator[bytes]: + yield b"\x01" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_1"))) + yield b"\x02" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_2"))) + + ctx = BargeInAttackContext( + params=AttackParameters(objective="obj"), + 
audio_chunks=chunks_then_two_commits(), + ) + + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + await attack._perform_async(context=ctx) + + insert_calls = vad_target.insert_user_audio_async.await_args_list + assert len(insert_calls) == 2 + assert insert_calls[0].kwargs["pcm_bytes"] == bytes((b + 1) & 0xFF for b in (b"\x01" * 96)) + assert insert_calls[1].kwargs["pcm_bytes"] == bytes((b + 1) & 0xFF for b in (b"\x02" * 96)) + + +async def test_perform_async_uses_target_audio_normalizer(vad_target): + """The attack must delegate audio conversion to the target's audio_normalizer.""" + fake_normalizer = MagicMock(spec=AudioStreamNormalizer) + fake_normalizer.normalize_async = AsyncMock(return_value=(b"\xff" * 96, [])) + vad_target.audio_normalizer = fake_normalizer + attack = BargeInAttack( + objective_target=vad_target, + attack_converter_config=_converter_config([_make_audio_converter(lambda pcm: pcm)]), + ) + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.delete_conversation_item_async = AsyncMock() + vad_target.insert_user_audio_async = AsyncMock() + + captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + fut: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + fut.set_result(RealtimeTargetResult(audio_bytes=b"", transcripts=[])) + vad_target.request_response_async = AsyncMock(return_value=fut) + + raw = b"\x05" * 96 + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield raw + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="raw_z"))) + + ctx = BargeInAttackContext( + params=AttackParameters(objective="obj"), + 
audio_chunks=chunks_then_commit(), + ) + + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + await attack._perform_async(context=ctx) + + fake_normalizer.normalize_async.assert_awaited_once() + kwargs = fake_normalizer.normalize_async.call_args.kwargs + assert kwargs["pcm_bytes"] == raw + assert kwargs["sample_rate"] == 24000 + vad_target.insert_user_audio_async.assert_awaited_once() + assert vad_target.insert_user_audio_async.call_args.kwargs["pcm_bytes"] == b"\xff" * 96 + + +# Placeholder for R4c tests + + +# ---- Per-turn persistence to CentralMemory (R4c) -------------------------------------------- + + +async def _drive_one_audio_turn( + attack, + vad_target, + *, + raw_chunk: bytes, + item_id: str, + turn_result: RealtimeTargetResult, +): + """Helper that runs a single audio-driven turn end-to-end against a mocked target.""" + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.delete_conversation_item_async = AsyncMock() + vad_target.insert_user_audio_async = AsyncMock() + + captured: dict[str, Any] = {} + + async def fake_subscribe(*, connection, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + fut: asyncio.Future[RealtimeTargetResult] = asyncio.get_event_loop().create_future() + fut.set_result(turn_result) + vad_target.request_response_async = AsyncMock(return_value=fut) + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield raw_chunk + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id=item_id))) + + ctx = BargeInAttackContext( + params=AttackParameters(objective="obj"), + audio_chunks=chunks_then_commit(), + ) + with patch.object(attack, "_MAX_POST_STREAM_WAIT_SECONDS", 0): + return await 
attack._perform_async(context=ctx) + + +async def test_persists_user_and_assistant_messages_per_turn(vad_target): + """A successful turn writes 1 user piece + 2 assistant pieces sharing the conversation id.""" + attack = BargeInAttack(objective_target=vad_target) + add_calls: list[Any] = [] + mock_memory = MagicMock() + mock_memory.add_message_to_memory = MagicMock(side_effect=lambda **kw: add_calls.append(kw["request"])) + + with patch("pyrit.executor.attack.streaming.barge_in.CentralMemory") as mock_cm: + mock_cm.get_memory_instance.return_value = mock_memory + result = await _drive_one_audio_turn( + attack, + vad_target, + raw_chunk=b"\x00" * 96, + item_id="raw_1", + turn_result=RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["hello"]), + ) + + assert len(add_calls) == 2 + user_msg, assistant_msg = add_calls + assert len(user_msg.message_pieces) == 1 + assert user_msg.message_pieces[0].converted_value_data_type == "audio_path" + assert user_msg.message_pieces[0].conversation_id == result.conversation_id + assert len(assistant_msg.message_pieces) == 2 + piece_types = sorted(p.converted_value_data_type for p in assistant_msg.message_pieces) + assert piece_types == ["audio_path", "text"] + text_piece = next(p for p in assistant_msg.message_pieces if p.converted_value_data_type == "text") + assert text_piece.converted_value == "hello" + + +async def test_persists_interrupted_metadata_on_assistant_pieces(vad_target): + """Interrupted turns mark both assistant pieces with prompt_metadata['interrupted'] = True.""" + attack = BargeInAttack(objective_target=vad_target) + add_calls: list[Any] = [] + mock_memory = MagicMock() + mock_memory.add_message_to_memory = MagicMock(side_effect=lambda **kw: add_calls.append(kw["request"])) + + with patch("pyrit.executor.attack.streaming.barge_in.CentralMemory") as mock_cm: + mock_cm.get_memory_instance.return_value = mock_memory + await _drive_one_audio_turn( + attack, + vad_target, + raw_chunk=b"\x00" * 96, + 
item_id="raw_int", + turn_result=RealtimeTargetResult(audio_bytes=b"\xbb" * 96, transcripts=["partial"], interrupted=True), + ) + + assistant_msg = add_calls[1] + for piece in assistant_msg.message_pieces: + assert piece.prompt_metadata.get("interrupted") is True + + +async def test_persists_converter_identifiers_on_user_piece(vad_target): + """Converter identifiers reported by convert_audio_async must land on the user piece.""" + bump = _make_audio_converter( + lambda pcm: bytes((b + 1) & 0xFF for b in pcm), + identifier_name="BumpConverter", + ) + attack = BargeInAttack( + objective_target=vad_target, + attack_converter_config=AttackConverterConfig( + request_converters=PromptConverterConfiguration.from_converters(converters=[bump]), + ), + ) + add_calls: list[Any] = [] + mock_memory = MagicMock() + mock_memory.add_message_to_memory = MagicMock(side_effect=lambda **kw: add_calls.append(kw["request"])) + + with patch("pyrit.executor.attack.streaming.barge_in.CentralMemory") as mock_cm: + mock_cm.get_memory_instance.return_value = mock_memory + await _drive_one_audio_turn( + attack, + vad_target, + raw_chunk=b"\x05" * 96, + item_id="raw_c", + turn_result=RealtimeTargetResult(audio_bytes=b"", transcripts=[]), + ) + + user_msg = add_calls[0] + identifiers = user_msg.message_pieces[0].converter_identifiers + assert len(identifiers) == 1 + assert identifiers[0].class_name == "BumpConverter" + + +async def test_persists_converted_audio_when_converters_changed_bytes(vad_target): + """The user piece's audio_path must point at the converted PCM, not the raw snapshot.""" + bump = _make_audio_converter(lambda pcm: bytes((b + 1) & 0xFF for b in pcm)) + attack = BargeInAttack( + objective_target=vad_target, + attack_converter_config=AttackConverterConfig( + request_converters=PromptConverterConfiguration.from_converters(converters=[bump]), + ), + ) + saved_calls: list[bytes] = [] + + async def fake_save_audio(audio_bytes, **_): + saved_calls.append(audio_bytes) + return 
f"/tmp/audio_{len(saved_calls)}.wav" + + vad_target.save_audio = AsyncMock(side_effect=fake_save_audio) + mock_memory = MagicMock() + mock_memory.add_message_to_memory = MagicMock() + + raw = b"\x05" * 96 + with patch("pyrit.executor.attack.streaming.barge_in.CentralMemory") as mock_cm: + mock_cm.get_memory_instance.return_value = mock_memory + await _drive_one_audio_turn( + attack, + vad_target, + raw_chunk=raw, + item_id="raw_x", + turn_result=RealtimeTargetResult(audio_bytes=b"\xff" * 96, transcripts=[]), + ) + + # save_audio called twice per turn: first for user audio (must be CONVERTED), then assistant audio. + assert len(saved_calls) == 2 + assert saved_calls[0] == bytes((b + 1) & 0xFF for b in raw) + assert saved_calls[1] == b"\xff" * 96 + + +async def test_attack_result_last_response_is_final_assistant_text_piece(vad_target): + """AttackResult.last_response must point at the last assistant message's first piece (text).""" + attack = BargeInAttack(objective_target=vad_target) + mock_memory = MagicMock() + mock_memory.add_message_to_memory = MagicMock() + + with patch("pyrit.executor.attack.streaming.barge_in.CentralMemory") as mock_cm: + mock_cm.get_memory_instance.return_value = mock_memory + result = await _drive_one_audio_turn( + attack, + vad_target, + raw_chunk=b"\x00" * 96, + item_id="raw_lr", + turn_result=RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["final answer"]), + ) + + assert result.last_response is not None + assert result.last_response.converted_value_data_type == "text" + assert result.last_response.converted_value == "final answer" diff --git a/tests/unit/prompt_normalizer/test_audio_stream_normalizer.py b/tests/unit/prompt_normalizer/test_audio_stream_normalizer.py new file mode 100644 index 0000000000..f48dff68c7 --- /dev/null +++ b/tests/unit/prompt_normalizer/test_audio_stream_normalizer.py @@ -0,0 +1,126 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Unit tests for ``AudioStreamNormalizer``.""" + +import os +import tempfile +import wave +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.identifiers import ComponentIdentifier +from pyrit.prompt_normalizer import AudioStreamNormalizer +from pyrit.prompt_normalizer.prompt_converter_configuration import ( + PromptConverterConfiguration, +) + + +def _make_audio_converter(transformer, *, output_sample_rate=24000, identifier_name="MockAudioConverter"): + """Mock audio converter whose convert_tokens_async runs transformer(pcm) and emits a new WAV path.""" + converter = MagicMock() + converter.get_identifier = MagicMock( + return_value=ComponentIdentifier(class_name=identifier_name, class_module="tests.unit.mocks"), + ) + + async def _convert(*, prompt, input_type, start_token=None, end_token=None): + assert input_type == "audio_path" + with wave.open(prompt, "rb") as wf_in: + pcm = wf_in.readframes(wf_in.getnframes()) + new_pcm = transformer(pcm) + out_dir = tempfile.mkdtemp() + out_path = os.path.join(out_dir, "out.wav") + with wave.open(out_path, "wb") as wf_out: + wf_out.setnchannels(1) + wf_out.setsampwidth(2) + wf_out.setframerate(output_sample_rate) + wf_out.writeframes(new_pcm) + result = MagicMock() + result.output_text = out_path + return result + + converter.convert_tokens_async = AsyncMock(side_effect=_convert) + return converter + + +async def test_normalize_async_no_configurations_returns_input(): + normalizer = AudioStreamNormalizer() + pcm = b"\xaa" * 1024 + out, ids = await normalizer.normalize_async(pcm_bytes=pcm, sample_rate=24000, converter_configurations=[]) + assert out == pcm + assert ids == [] + + +async def test_normalize_async_empty_pcm_returns_input(): + normalizer = AudioStreamNormalizer() + bump = _make_audio_converter(lambda pcm: pcm) + out, ids = await normalizer.normalize_async( + pcm_bytes=b"", + sample_rate=24000, + 
converter_configurations=PromptConverterConfiguration.from_converters(converters=[bump]), + ) + assert out == b"" + assert ids == [] + + +async def test_normalize_async_chains_converters_and_returns_identifiers(): + normalizer = AudioStreamNormalizer() + bump_a = _make_audio_converter(lambda pcm: bytes((b + 1) & 0xFF for b in pcm)) + bump_b = _make_audio_converter(lambda pcm: bytes((b + 2) & 0xFF for b in pcm)) + + out, ids = await normalizer.normalize_async( + pcm_bytes=b"\x00\x10\x20\x30", + sample_rate=24000, + converter_configurations=PromptConverterConfiguration.from_converters(converters=[bump_a, bump_b]), + ) + + assert out == b"\x03\x13\x23\x33" + assert len(ids) == 2 # one identifier per converter that ran + + +async def test_normalize_async_respects_data_type_filter(): + """A configuration with prompt_data_types_to_apply not including audio_path must be skipped.""" + normalizer = AudioStreamNormalizer() + skipped = _make_audio_converter(lambda pcm: bytes((b + 9) & 0xFF for b in pcm)) + applied = _make_audio_converter(lambda pcm: bytes((b + 1) & 0xFF for b in pcm)) + + configs = [ + PromptConverterConfiguration(converters=[skipped], prompt_data_types_to_apply=["text"]), + PromptConverterConfiguration(converters=[applied], prompt_data_types_to_apply=["audio_path"]), + ] + out, ids = await normalizer.normalize_async( + pcm_bytes=b"\x00\x10", sample_rate=24000, converter_configurations=configs + ) + + # Only the audio_path-applicable converter ran (+1 not +9). 
+ assert out == b"\x01\x11" + assert len(ids) == 1 + + +async def test_normalize_async_short_circuits_when_all_configs_filtered_out(): + """When every config is text-only, skip the WAV round-trip entirely.""" + normalizer = AudioStreamNormalizer() + text_only = _make_audio_converter(lambda pcm: bytes((b + 9) & 0xFF for b in pcm)) + + configs = [ + PromptConverterConfiguration(converters=[text_only], prompt_data_types_to_apply=["text"]), + ] + pcm = b"\x00\x10\x20\x30" + out, ids = await normalizer.normalize_async(pcm_bytes=pcm, sample_rate=24000, converter_configurations=configs) + + assert out == pcm # bytes unchanged + assert ids == [] + text_only.convert_tokens_async.assert_not_awaited() + + +async def test_normalize_async_rejects_mismatched_sample_rate(): + """Converter output at a different sample rate must raise ValueError.""" + normalizer = AudioStreamNormalizer() + bad = _make_audio_converter(lambda pcm: pcm, output_sample_rate=16000) + with pytest.raises(ValueError, match="incompatible"): + await normalizer.normalize_async( + pcm_bytes=b"\x00" * 1024, + sample_rate=24000, + converter_configurations=PromptConverterConfiguration.from_converters(converters=[bad]), + ) diff --git a/tests/unit/prompt_target/target/test_realtime_audio.py b/tests/unit/prompt_target/target/test_realtime_audio.py new file mode 100644 index 0000000000..005814a5e4 --- /dev/null +++ b/tests/unit/prompt_target/target/test_realtime_audio.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeEventDispatcher, + RealtimeTargetResult, + RealtimeTurnState, +) + + +async def test_realtime_turn_state_defaults(): + """Newly constructed turn state must be empty: no audio, no transcripts, not responding, not interrupted.""" + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + assert state.is_responding is False + assert state.interrupted is False + assert bytes(state.delivered_audio) == b"" + assert state.delivered_transcripts == [] + assert state.current_item_id is None + assert state.last_response_id is None + + +def test_realtime_target_result_interrupted_defaults_false(): + """RealtimeTargetResult must default interrupted=False so atomic callers see no change.""" + result = RealtimeTargetResult() + assert result.interrupted is False + assert result.audio_bytes == b"" + assert result.transcripts == [] + + +def test_realtime_target_result_carries_interrupted_when_set(): + """The interrupted flag round-trips through construction.""" + result = RealtimeTargetResult(audio_bytes=b"partial", transcripts=["hi"], interrupted=True) + assert result.interrupted is True + + +class _RecordingDispatcher(RealtimeEventDispatcher): + """Minimal concrete dispatcher for testing the generic base class behavior.""" + + def __init__(self, *, connection: Any) -> None: + super().__init__(connection=connection) + self.routed_events: list[Any] = [] + self.cancel_calls: int = 0 + + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + self.routed_events.append(event) + # End the turn on a sentinel event so tests can drain the loop. 
+ if state is not None and getattr(event, "_finish", False): + state.completion.set_result(RealtimeTargetResult()) + + async def _cancel(self, *, state: RealtimeTurnState) -> None: + self.cancel_calls += 1 + state.interrupted = True + + +class _ScriptedConnection: + """Async-iterable connection that yields a fixed event list once registered.""" + + def __init__(self, events: list[Any]) -> None: + self._events = events + + async def __aiter__(self): + for event in self._events: + yield event + + +def _sentinel_event(*, finish: bool = False) -> AsyncMock: + event = AsyncMock() + event._finish = finish + return event + + +async def test_dispatcher_start_is_idempotent(): + """Calling start twice must not spawn two tasks.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + await dispatcher.start() + first_task = dispatcher._task + await dispatcher.start() + assert dispatcher._task is first_task + await dispatcher.stop() + + +async def test_dispatcher_stop_releases_task(): + """stop must cancel the task and clear the reference.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + await dispatcher.start() + await dispatcher.stop() + assert dispatcher._task is None + + +async def test_dispatcher_register_turn_rejects_concurrent_active_turn(): + """Registering a turn while another is active and unresolved must raise.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + first = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + second = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + dispatcher.register_turn(first) + with pytest.raises(RuntimeError, match="already active"): + dispatcher.register_turn(second) + + +async def test_dispatcher_register_turn_allows_replacement_after_completion(): + """Once the active turn's future is done, register_turn may bind a new turn.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + first = 
RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + first.completion.set_result(RealtimeTargetResult()) + second = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + dispatcher.register_turn(first) + dispatcher.register_turn(second) + assert dispatcher._current_turn is second + + +async def test_dispatcher_loop_routes_events_to_active_turn(): + """The dispatch loop must forward events from the connection to _route_event.""" + finish = _sentinel_event(finish=True) + other = _sentinel_event() + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([other, finish])) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + dispatcher.register_turn(state) + + await dispatcher.start() + await asyncio.wait_for(state.completion, timeout=1.0) + await dispatcher.stop() + + assert dispatcher.routed_events == [other, finish] + + +async def test_dispatcher_loop_routes_events_with_no_turn_as_state_none(): + """When no turn is registered, events still reach _route_event so input callbacks can fire; state is None.""" + finish = _sentinel_event(finish=True) + other = _sentinel_event() + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([other, finish])) + + # No register_turn called. + await dispatcher.start() + await asyncio.sleep(0.05) + await dispatcher.stop() + + # Both events were routed but no turn was completed (state was None, sentinel branch skipped). 
+ assert dispatcher.routed_events == [other, finish] + + +async def test_dispatcher_loop_sets_exception_on_router_failure(): + """A router exception must propagate to the active turn's completion future.""" + + class _ExplodingDispatcher(_RecordingDispatcher): + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + raise ValueError("router boom") + + event = _sentinel_event() + dispatcher = _ExplodingDispatcher(connection=_ScriptedConnection([event])) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + dispatcher.register_turn(state) + + await dispatcher.start() + with pytest.raises(ValueError, match="router boom"): + await asyncio.wait_for(state.completion, timeout=1.0) + await dispatcher.stop() + + +async def test_dispatcher_fires_committed_callback_as_background_task(): + """The on_user_audio_committed callback must be invoked and awaited via background tasks.""" + + received: list[Any] = [] + blocked = asyncio.Event() + release = asyncio.Event() + + async def slow_callback(event): + received.append(event) + blocked.set() + # Block until the test releases us; this proves the dispatch loop did not wait. + await release.wait() + + class _CallbackDispatcher(RealtimeEventDispatcher): + async def _route_event(self, *, event, state): + # Synthesize a committed callback fire on every event for the test. + self._fire_committed_callback(event) + + async def _cancel(self, *, state): # pragma: no cover - not exercised here + return + + fake_event_1 = MagicMock(spec=CommittedEvent) + fake_event_2 = MagicMock(spec=CommittedEvent) + dispatcher = _CallbackDispatcher( + connection=_ScriptedConnection([fake_event_1, fake_event_2]), + on_user_audio_committed=slow_callback, + ) + + await dispatcher.start() + # Both events should reach the slow callback even though the first is "blocked" awaiting release. 
+ await asyncio.wait_for(blocked.wait(), timeout=1.0) + # Give the loop a tick to process the second event despite the first callback still running. + await asyncio.sleep(0.05) + release.set() + await dispatcher.stop() + + # Both events fired the callback; the loop did not serialize behind the slow first call. + assert len(received) == 2 + + +async def test_dispatcher_records_failure_on_iterator_crash(): + """When the connection iterator raises, the dispatcher's failure property captures the exception.""" + + class _NoopDispatcher(RealtimeEventDispatcher): + async def _route_event(self, *, event, state): # pragma: no cover - never called + return + + async def _cancel(self, *, state): # pragma: no cover + return + + class _ExplodingConnection: + def __aiter__(self): + return self + + async def __anext__(self): + raise RuntimeError("iterator died") + + dispatcher = _NoopDispatcher(connection=_ExplodingConnection()) + await dispatcher.start() + for _ in range(50): + if dispatcher.failure is not None: + break + await asyncio.sleep(0.01) + await dispatcher.stop() + + assert isinstance(dispatcher.failure, RuntimeError) and str(dispatcher.failure) == "iterator died" diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index d0aa9cc5e2..c810587a8e 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -1,14 +1,22 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from unittest.mock import AsyncMock, MagicMock, patch +import asyncio +import base64 +from typing import Any +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pyrit.exceptions.exception_classes import ServerErrorException from pyrit.models import Message, MessagePiece -from pyrit.prompt_target import RealtimeTarget -from pyrit.prompt_target.openai.openai_realtime_target import RealtimeTargetResult +from pyrit.prompt_target import RealtimeTarget, ServerVadConfig +from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeTargetResult, + RealtimeTurnState, +) +from pyrit.prompt_target.openai.openai_realtime_target import _OpenAIRealtimeDispatcher # Env vars that may leak from .env files loaded by other tests in parallel workers. _CLEAN_UNDERLYING_MODEL_ENV = { @@ -29,7 +37,7 @@ async def test_connect_success(target): mock_client.realtime.connect.return_value.__aenter__ = AsyncMock(return_value=mock_connection) with patch.object(target, "_get_openai_client", return_value=mock_client): - connection = await target.connect(conversation_id="test_conv") + connection = await target.connect_async(conversation_id="test_conv") assert connection == mock_connection mock_client.realtime.connect.assert_called_once_with(model="test") await target.cleanup_target() @@ -37,7 +45,7 @@ async def test_connect_success(target): async def test_send_prompt_async(target): # Mock the necessary methods - target.connect = AsyncMock(return_value=AsyncMock()) + target.connect_async = AsyncMock(return_value=AsyncMock()) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"file", transcripts=["hello"]) target.send_text_async = AsyncMock(return_value=("output.wav", result)) @@ -70,6 +78,58 @@ async def test_send_prompt_async(target): await target.cleanup_target() +async def test_send_prompt_async_propagates_interrupted_to_metadata(target): + """When a turn result carries interrupted=True, both response pieces' 
metadata must reflect it.""" + target.connect_async = AsyncMock(return_value=AsyncMock()) + target.send_config = AsyncMock() + interrupted_result = RealtimeTargetResult(audio_bytes=b"partial", transcripts=["hi"], interrupted=True) + target.send_text_async = AsyncMock(return_value=("partial.wav", interrupted_result)) + + message_piece = MessagePiece( + original_value="Hello", + original_value_data_type="text", + converted_value="Hello", + converted_value_data_type="text", + role="user", + conversation_id="test_conv", + ) + message = Message(message_pieces=[message_piece]) + + response = await target.send_prompt_async(message=message) + + text_piece, audio_piece = response[0].message_pieces + assert text_piece.prompt_metadata.get("interrupted") is True + assert audio_piece.prompt_metadata.get("interrupted") is True + + await target.cleanup_target() + + +async def test_send_prompt_async_omits_interrupted_metadata_when_not_set(target): + """A non-interrupted result must not write an interrupted key to MessagePiece metadata.""" + target.connect_async = AsyncMock(return_value=AsyncMock()) + target.send_config = AsyncMock() + normal_result = RealtimeTargetResult(audio_bytes=b"full", transcripts=["hi"]) + target.send_text_async = AsyncMock(return_value=("full.wav", normal_result)) + + message_piece = MessagePiece( + original_value="Hello", + original_value_data_type="text", + converted_value="Hello", + converted_value_data_type="text", + role="user", + conversation_id="test_conv", + ) + message = Message(message_pieces=[message_piece]) + + response = await target.send_prompt_async(message=message) + + text_piece, audio_piece = response[0].message_pieces + assert "interrupted" not in text_piece.prompt_metadata + assert "interrupted" not in audio_piece.prompt_metadata + + await target.cleanup_target() + + async def test_get_system_prompt_from_conversation_with_system_message(target): """Test that system prompt is extracted from conversation history when present.""" @@ -123,7 
+183,7 @@ async def test_get_system_prompt_empty_conversation(target): async def test_multiple_websockets_created_for_multiple_conversations(target): # Mock the necessary methods - target.connect = AsyncMock(return_value=AsyncMock()) + target.connect_async = AsyncMock(return_value=AsyncMock()) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"event1", transcripts=["event2"]) target.send_text_async = AsyncMock(return_value=("output_audio_path", result)) @@ -346,7 +406,7 @@ async def test_multi_turn_reuses_connection(target): This ensures that the server-side conversation context is preserved. """ mock_connection = AsyncMock() - target.connect = AsyncMock(return_value=mock_connection) + target.connect_async = AsyncMock(return_value=mock_connection) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"audio", transcripts=["response"]) target.send_text_async = AsyncMock(return_value=("output.wav", result)) @@ -376,7 +436,7 @@ async def test_multi_turn_reuses_connection(target): await target.send_prompt_async(message=Message(message_pieces=[message_piece_2])) # Connection should only be created once for the conversation - target.connect.assert_called_once_with(conversation_id=conversation_id) + target.connect_async.assert_called_once_with(conversation_id=conversation_id) target.send_config.assert_called_once() # Both turns should use the same connection @@ -430,3 +490,566 @@ async def test_receive_events_skips_stale_response_done(target): # Should have processed through to the real response.done with actual audio assert result.audio_bytes == b"dummyaudio" assert result.transcripts == ["hello"] + + +# --------------------------------------------------------------------------- +# Chunk 1 — ServerVadConfig + session config +# --------------------------------------------------------------------------- + + +def test_session_config_omits_turn_detection_when_vad_disabled(target): + """Default construction must not emit a 
turn_detection block; pins atomic flow.""" + config = target._set_system_prompt_and_config_vars(system_prompt="test prompt") + + assert "turn_detection" not in config["audio"]["input"] + assert config["instructions"] == "test prompt" + + +@patch.dict("os.environ", _CLEAN_UNDERLYING_MODEL_ENV) +def test_session_config_emits_server_vad_block_with_defaults(sqlite_instance): + """server_vad=True must emit defaults.""" + vad_target = RealtimeTarget( + api_key="test_key", + endpoint="wss://test_url", + model_name="test", + server_vad=True, + ) + + config = vad_target._set_system_prompt_and_config_vars(system_prompt="test prompt") + + turn_detection = config["audio"]["input"]["turn_detection"] + assert turn_detection == { + "type": "server_vad", + "threshold": 0.4, + "prefix_padding_ms": 200, + "silence_duration_ms": 1500, + "create_response": True, + "interrupt_response": True, + } + + +@patch.dict("os.environ", _CLEAN_UNDERLYING_MODEL_ENV) +def test_session_config_honors_custom_vad_tuning(sqlite_instance): + """Passing a ServerVadConfig must flow through to the emitted turn_detection block.""" + vad_target = RealtimeTarget( + api_key="test_key", + endpoint="wss://test_url", + model_name="test", + server_vad=ServerVadConfig(threshold=0.7, prefix_padding_ms=350, silence_duration_ms=800), + ) + + turn_detection = vad_target._set_system_prompt_and_config_vars(system_prompt="x")["audio"]["input"][ + "turn_detection" + ] + + assert turn_detection["threshold"] == 0.7 + assert turn_detection["prefix_padding_ms"] == 350 + assert turn_detection["silence_duration_ms"] == 800 + + +@pytest.mark.parametrize( + "kwargs", + [ + {"threshold": -0.1}, + {"threshold": 1.5}, + {"prefix_padding_ms": -1}, + {"silence_duration_ms": -1}, + ], +) +def test_server_vad_config_rejects_invalid_values(kwargs): + """ServerVadConfig must reject out-of-range tuning values at construction.""" + with pytest.raises(ValueError): + ServerVadConfig(**kwargs) + + +# 
# ---------------------------------------------------------------------------
# Chunk 2 — _stream_pcm_async helper
# ---------------------------------------------------------------------------


def _make_mock_connection():
    """Return an AsyncMock connection with input_audio_buffer wired up."""
    connection = AsyncMock()
    connection.input_audio_buffer.append = AsyncMock()
    connection.input_audio_buffer.commit = AsyncMock()
    return connection


async def test_stream_pcm_even_split_no_commit(target):
    """A buffer that divides evenly into chunks emits N appends and no commit when commit=False."""
    connection = _make_mock_connection()
    # 100ms @ 24kHz @ 2 bytes/sample = 4800 bytes per chunk. 9600 bytes = 2 chunks.
    pcm = b"\x00" * 9600

    await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False)

    assert connection.input_audio_buffer.append.call_count == 2
    connection.input_audio_buffer.commit.assert_not_called()


async def test_stream_pcm_partial_final_chunk(target):
    """A buffer not a clean multiple of chunk size sends the final partial chunk as-is."""
    connection = _make_mock_connection()
    # 5000 bytes => one full 4800-byte chunk + one 200-byte tail.
    pcm = b"\x01" * 5000

    await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False)

    assert connection.input_audio_buffer.append.call_count == 2
    # Inspect the second call's chunk size by base64-decoding its audio kwarg.
    second_call_audio_b64 = connection.input_audio_buffer.append.call_args_list[1].kwargs["audio"]
    assert len(base64.b64decode(second_call_audio_b64)) == 200


async def test_stream_pcm_empty_buffer(target):
    """An empty buffer yields zero appends. commit=False produces no commit either."""
    connection = _make_mock_connection()

    await target._stream_pcm_async(connection=connection, pcm_bytes=b"", commit=False)

    connection.input_audio_buffer.append.assert_not_called()
    connection.input_audio_buffer.commit.assert_not_called()


async def test_stream_pcm_commits_when_asked(target):
    """commit=True triggers exactly one input_audio_buffer.commit after all appends."""
    connection = _make_mock_connection()
    pcm = b"\x02" * 4800

    await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=True)

    assert connection.input_audio_buffer.append.call_count == 1
    connection.input_audio_buffer.commit.assert_awaited_once_with()


async def test_stream_pcm_empty_buffer_still_commits_when_asked(target):
    """commit=True with an empty buffer should still fire commit (e.g. to flush an existing buffer)."""
    connection = _make_mock_connection()

    await target._stream_pcm_async(connection=connection, pcm_bytes=b"", commit=True)

    connection.input_audio_buffer.append.assert_not_called()
    connection.input_audio_buffer.commit.assert_awaited_once_with()


async def test_stream_pcm_appends_base64_encoded_chunks(target):
    """Each append's audio kwarg must be the base64 encoding of the corresponding PCM chunk."""
    connection = _make_mock_connection()
    # Build a recognizable buffer: 4800 bytes of 0xAA then 4800 bytes of 0xBB.
    pcm = (b"\xaa" * 4800) + (b"\xbb" * 4800)

    await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False)

    first_audio = connection.input_audio_buffer.append.call_args_list[0].kwargs["audio"]
    second_audio = connection.input_audio_buffer.append.call_args_list[1].kwargs["audio"]
    assert base64.b64decode(first_audio) == b"\xaa" * 4800
    assert base64.b64decode(second_audio) == b"\xbb" * 4800


# ---- Wire primitives for streaming attacks ---------------------------------------------------


async def test_push_audio_chunk_async_base64_encodes_and_appends(target):
    """push_audio_chunk_async must append the base64 encoding of the PCM bytes exactly once."""
    connection = _make_mock_connection()
    pcm = b"\x33" * 480

    await target.push_audio_chunk_async(connection=connection, pcm_bytes=pcm)

    connection.input_audio_buffer.append.assert_awaited_once()
    audio_b64 = connection.input_audio_buffer.append.call_args.kwargs["audio"]
    assert base64.b64decode(audio_b64) == pcm


async def test_push_audio_chunk_async_empty_is_noop(target):
    """Empty PCM input must not trigger an input_audio_buffer.append call."""
    connection = _make_mock_connection()

    await target.push_audio_chunk_async(connection=connection, pcm_bytes=b"")

    connection.input_audio_buffer.append.assert_not_called()


async def test_insert_user_audio_async_creates_input_audio_item(target):
    """insert_user_audio_async must create a user message item carrying base64 input_audio."""
    connection = AsyncMock()
    pcm = b"\x44" * 480

    await target.insert_user_audio_async(connection=connection, pcm_bytes=pcm)

    connection.conversation.item.create.assert_awaited_once()
    item = connection.conversation.item.create.call_args.kwargs["item"]
    assert item["type"] == "message"
    assert item["role"] == "user"
    assert item["content"][0]["type"] == "input_audio"
    assert base64.b64decode(item["content"][0]["audio"]) == pcm


async def test_insert_user_text_async_creates_input_text_item(target):
    """insert_user_text_async must create a user item whose first content part is the input text."""
    connection = AsyncMock()

    await target.insert_user_text_async(connection=connection, text="hello model")

    connection.conversation.item.create.assert_awaited_once()
    item = connection.conversation.item.create.call_args.kwargs["item"]
    assert item["role"] == "user"
    assert item["content"][0] == {"type": "input_text", "text": "hello model"}


async def test_delete_conversation_item_async_forwards_item_id(target):
    """delete_conversation_item_async must forward item_id unchanged to conversation.item.delete."""
    connection = AsyncMock()

    await target.delete_conversation_item_async(connection=connection, item_id="raw_item_99")

    connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_item_99")


async def test_swap_user_audio_async_inserts_converted_then_deletes_original(target):
    """``swap_user_audio_async`` must insert the converted PCM then delete the original item."""
    connection = AsyncMock()
    event = CommittedEvent(item_id="raw_swap_1")

    await target.swap_user_audio_async(
        connection=connection,
        committed_event=event,
        converted_pcm=b"\xab" * 96,
    )

    connection.conversation.item.create.assert_awaited_once()
    connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_swap_1")
    # Insert must precede delete: any future refactor that swaps the order or runs them
    # concurrently would corrupt the streaming session — pin the ordering here.
    create_index = connection.method_calls.index(call.conversation.item.create(item=ANY))
    delete_index = connection.method_calls.index(call.conversation.item.delete(item_id="raw_swap_1"))
    assert create_index < delete_index


async def test_swap_user_audio_async_logs_and_swallows_delete_failure(target, caplog):
    """Best-effort delete: if ``delete`` raises, ``swap`` logs a warning and returns normally."""
    connection = AsyncMock()
    connection.conversation.item.delete.side_effect = RuntimeError("delete blew up")
    event = CommittedEvent(item_id="raw_swap_fail")

    with caplog.at_level("WARNING"):
        await target.swap_user_audio_async(
            connection=connection,
            committed_event=event,
            converted_pcm=b"\x01" * 96,
        )

    connection.conversation.item.create.assert_awaited_once()
    connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_swap_fail")
    # Even on delete failure, insert must have happened first.
    create_index = connection.method_calls.index(call.conversation.item.create(item=ANY))
    delete_index = connection.method_calls.index(call.conversation.item.delete(item_id="raw_swap_fail"))
    assert create_index < delete_index
    assert any("delete failed for raw_swap_fail" in record.message for record in caplog.records)


def _turn_state(*, response_id: str | None = "resp_abc", item_id: str | None = "item_xyz") -> RealtimeTurnState:
    """
    Build a turn state with the named ids preset; completion future is unused by cancel tests.

    Must be called from code running inside an event loop (e.g. an async test), because the
    completion future is created on the currently running loop.
    """
    return RealtimeTurnState(
        # get_running_loop() is the documented replacement for get_event_loop() when a loop
        # is already running; it never implicitly creates a new loop.
        completion=asyncio.get_running_loop().create_future(),
        is_responding=True,
        last_response_id=response_id,
        current_item_id=item_id,
    )


def _make_dispatcher(connection):
    """Build an _OpenAIRealtimeDispatcher around the given mock connection."""
    return _OpenAIRealtimeDispatcher(connection=connection)


async def test_cancel_does_not_send_response_cancel():
    """_cancel must NOT send response.cancel (server auto-cancels on speech detection)."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = _turn_state(response_id="resp_42")
    state.delivered_audio.extend(b"\x00" * 4800)

    await dispatcher._cancel(state=state)

    connection.response.cancel.assert_not_awaited()


async def test_cancel_truncates_to_delivered_audio_ms():
    """Truncate must be called with audio_end_ms computed from delivered_audio length."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = _turn_state(item_id="item_99")
    # 4800 delivered bytes / 48 bytes-per-ms = 100ms
    state.delivered_audio.extend(b"\x00" * 4800)

    await dispatcher._cancel(state=state)

    connection.conversation.item.truncate.assert_awaited_once_with(
        item_id="item_99",
        content_index=0,
        audio_end_ms=100,
    )
    assert state.interrupted is True


async def test_cancel_only_truncates_no_response_cancel():
    """_cancel must only truncate, not send response.cancel (server handles cancellation)."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = _turn_state(item_id="item_1")
    state.delivered_audio.extend(b"\x00" * 4800)

    await dispatcher._cancel(state=state)

    assert state.interrupted is True
    connection.conversation.item.truncate.assert_awaited_once()
    connection.response.cancel.assert_not_awaited()


async def test_cancel_marks_interrupted_when_truncate_raises(caplog):
    """A failed conversation.item.truncate must log a warning and still flip state.interrupted."""
    connection = AsyncMock()
    connection.conversation.item.truncate.side_effect = RuntimeError("boom")
    dispatcher = _make_dispatcher(connection)
    state = _turn_state()

    # Pin the capture level explicitly so the assertion below does not depend on whatever
    # logging configuration happens to be active when the test runs.
    with caplog.at_level("WARNING"):
        await dispatcher._cancel(state=state)

    assert state.interrupted is True
    assert any(
        "conversation.item.truncate failed" in record.message and record.levelname == "WARNING"
        for record in caplog.records
    )


def _scripted_event(event_type, **fields):
    """
    Build a MagicMock event with the named type plus any extra attribute paths.

    Args:
        event_type: Value assigned to the event's ``type`` attribute.
        **fields: Attribute values keyed by attribute name. A dotted key such as
            ``"response.id"`` sets the leaf attribute on the nested mock.

    Returns:
        The configured MagicMock event.
    """
    event = MagicMock()
    event.type = event_type
    for path, value in fields.items():
        # Allow dotted attribute paths like "response.id" by walking nested MagicMocks.
        parts = path.split(".")
        target_attr = event
        for part in parts[:-1]:
            target_attr = getattr(target_attr, part)
        setattr(target_attr, parts[-1], value)
    return event


async def _poll_until_async(predicate, *, attempts: int = 20, interval_s: float = 0.01) -> None:
    """
    Yield to the event loop until ``predicate()`` is truthy or the attempts run out.

    The helper never raises on timeout; callers assert on the observed state afterwards so a
    timeout surfaces as an ordinary assertion failure.

    Args:
        predicate: Zero-argument callable checked before each sleep.
        attempts: Maximum number of checks to perform.
        interval_s: Seconds to sleep between checks.
    """
    for _ in range(attempts):
        if predicate():
            return
        await asyncio.sleep(interval_s)


async def test_route_event_happy_path_resolves_completion_with_assembled_result():
    """response.created -> output_item.added -> audio.delta -> transcript.delta -> response.done."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future())

    await dispatcher._route_event(event=_scripted_event("response.created", **{"response.id": "r1"}), state=state)
    await dispatcher._route_event(event=_scripted_event("response.output_item.added", **{"item.id": "i1"}), state=state)
    await dispatcher._route_event(
        event=_scripted_event("response.audio.delta", delta=base64.b64encode(b"\xaa" * 4800).decode("ascii")),
        state=state,
    )
    await dispatcher._route_event(event=_scripted_event("response.audio_transcript.delta", delta="hello "), state=state)
    await dispatcher._route_event(event=_scripted_event("response.audio_transcript.delta", delta="world"), state=state)
    await dispatcher._route_event(event=_scripted_event("response.done", **{"response.id": "r1"}), state=state)

    assert state.completion.done()
    result = state.completion.result()
    assert result.audio_bytes == b"\xaa" * 4800
    assert result.transcripts == ["hello ", "world"]
    assert state.interrupted is False


async def test_route_event_speech_started_while_responding_cancels_and_resolves_interrupted():
    """speech_started during a response triggers cancel and resolves with interrupted=True."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future())

    await dispatcher._route_event(event=_scripted_event("response.created", **{"response.id": "r1"}), state=state)
    await dispatcher._route_event(event=_scripted_event("response.output_item.added", **{"item.id": "i1"}), state=state)
    await dispatcher._route_event(
        event=_scripted_event("response.audio.delta", delta=base64.b64encode(b"\xbb" * 2400).decode("ascii")),
        state=state,
    )
    await dispatcher._route_event(event=_scripted_event("input_audio_buffer.speech_started"), state=state)

    connection.response.cancel.assert_not_awaited()
    connection.conversation.item.truncate.assert_awaited_once_with(
        item_id="i1",
        content_index=0,
        audio_end_ms=50,  # 2400 / 48
    )
    # Fail with a clear assertion (not InvalidStateError) if the turn was never resolved.
    assert state.completion.done()
    result = state.completion.result()
    assert result.audio_bytes == b"\xbb" * 2400
    assert result.interrupted is True
    assert state.interrupted is True


async def test_route_event_stale_response_done_after_cancel_is_dropped():
    """A response.done with a stale response_id must not re-resolve a completed future."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future())
    # Pretend a turn just resolved as interrupted on response_id r1.
    state.last_response_id = "r1"
    resolved_result = RealtimeTargetResult()
    state.completion.set_result(resolved_result)

    # Late response.done for r1 arrives; router must not raise InvalidStateError.
    await dispatcher._route_event(event=_scripted_event("response.done", **{"response.id": "r1"}), state=state)

    # The previously resolved result must be left exactly as it was.
    assert state.completion.result() is resolved_result


async def test_route_event_error_resolves_with_exception():
    """error events resolve the completion future via set_exception."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future())

    await dispatcher._route_event(event=_scripted_event("error", **{"error.message": "rate limited"}), state=state)

    with pytest.raises(RuntimeError, match="rate limited"):
        state.completion.result()


async def test_route_event_speech_started_without_responding_is_noop():
    """speech_started before a response is in flight does not call cancel or resolve."""
    connection = AsyncMock()
    dispatcher = _make_dispatcher(connection)
    state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future())

    await dispatcher._route_event(event=_scripted_event("input_audio_buffer.speech_started"), state=state)

    connection.response.cancel.assert_not_awaited()
    connection.conversation.item.truncate.assert_not_awaited()
    assert not state.completion.done()
    assert state.interrupted is False


async def test_route_event_committed_event_fires_user_audio_callback():
    """input_audio_buffer.committed must fire the registered on_user_audio_committed callback."""
    connection = AsyncMock()
    received: list[Any] = []

    async def on_committed(event):
        received.append(event)

    dispatcher = _OpenAIRealtimeDispatcher(connection=connection, on_user_audio_committed=on_committed)

    await dispatcher._route_event(
        event=_scripted_event("input_audio_buffer.committed", item_id="raw_item_42", audio_start_ms=1234),
        state=None,
    )
    # Background callback task may not have run yet; yield until it does.
    await _poll_until_async(lambda: bool(received))

    assert len(received) == 1
    assert received[0].item_id == "raw_item_42"
    assert received[0].audio_start_ms == 1234


async def test_route_event_committed_event_without_callback_is_noop():
    """A committed event with no callback configured must be ignored quietly."""
    connection = AsyncMock()
    dispatcher = _OpenAIRealtimeDispatcher(connection=connection)  # no callback

    # Must not raise.
    await dispatcher._route_event(
        event=_scripted_event("input_audio_buffer.committed", item_id="raw_item_99"),
        state=None,
    )


# ---- subscribe_events_async + request_response_async (R2) ------------------------------------


async def test_subscribe_events_async_returns_started_dispatcher(target):
    """Subscription handle must be a started dispatcher; closing tears the task down."""
    events = [_scripted_event("input_audio_buffer.committed", item_id="i_1")]

    async def event_iter():
        for e in events:
            yield e
        # Keep the iterator alive briefly so the dispatch task can run.
        await asyncio.sleep(0.01)

    connection = MagicMock()
    connection.__aiter__ = lambda self_: event_iter()

    received: list[CommittedEvent] = []

    async def on_committed(event):
        received.append(event)

    dispatcher = await target.subscribe_events_async(connection=connection, on_user_audio_committed=on_committed)
    try:
        # Yield until the dispatch loop processes the scripted event.
        await _poll_until_async(lambda: bool(received))
        # Separate asserts so a failure says which condition broke.
        assert len(received) == 1
        assert received[0].item_id == "i_1"
    finally:
        await dispatcher.stop()


async def test_subscribe_events_async_records_loop_failure_on_dispatcher(target):
    """A dispatcher loop crash must be reachable via the dispatcher's ``failure`` property."""

    async def boom_iter():
        raise RuntimeError("loop kaboom")
        yield  # pragma: no cover  # makes it a generator

    connection = MagicMock()
    connection.__aiter__ = lambda self_: boom_iter()

    dispatcher = await target.subscribe_events_async(connection=connection)
    try:
        await _poll_until_async(lambda: dispatcher.failure is not None, attempts=50)
        assert isinstance(dispatcher.failure, RuntimeError)
    finally:
        await dispatcher.stop()


async def test_request_response_async_registers_turn_and_sends_response_create(target):
    """request_response_async must register a fresh turn and call response.create."""
    connection = AsyncMock()
    dispatcher = MagicMock()
    dispatcher.register_turn = MagicMock()

    future = await target.request_response_async(connection=connection, dispatcher=dispatcher)

    dispatcher.register_turn.assert_called_once()
    registered_state = dispatcher.register_turn.call_args.args[0]
    assert isinstance(registered_state, RealtimeTurnState)
    assert registered_state.completion is future
    connection.response.create.assert_awaited_once_with()


async def test_request_response_async_future_resolves_with_dispatcher_result(target):
    """The future returned by request_response_async resolves when the turn ends."""
    connection = AsyncMock()
    dispatcher = MagicMock()
    expected_result = RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["ok"])

    def _register(state):
        # Resolve the turn immediately, standing in for the dispatcher finishing it.
        state.completion.set_result(expected_result)

    dispatcher.register_turn = MagicMock(side_effect=_register)

    future = await target.request_response_async(connection=connection, dispatcher=dispatcher)
    result = await future
    assert result is expected_result


async def test_request_response_async_propagates_register_turn_failure(target):
    """If another turn is already pending, register_turn raises and request_response_async surfaces it."""
    connection = AsyncMock()
    dispatcher = MagicMock()
    dispatcher.register_turn = MagicMock(side_effect=RuntimeError("turn already pending"))

    with pytest.raises(RuntimeError, match="turn already pending"):
        await target.request_response_async(connection=connection, dispatcher=dispatcher)

    connection.response.create.assert_not_called()