diff --git a/configs/prompts/judge.yaml b/configs/prompts/judge.yaml index 4b2fd36e..5fb0b6d3 100644 --- a/configs/prompts/judge.yaml +++ b/configs/prompts/judge.yaml @@ -829,6 +829,81 @@ judge: ], "explanation": "" }} + s2s_user_prompt: | + You are an expert evaluator checking the **speech clarity and articulation** of entities spoken by an AI voice agent. + + You will receive: + 1. A conversation trace showing what the user said and what data the agent retrieved via tools. Assistant responses are redacted — you must listen to the audio to hear what the agent actually said. + 2. An audio recording of the agent's side of the conversation only (the user is not audible). + + ## Conversation Trace + {conversation_trace_formatted} + + ## IMPORTANT: What This Metric Measures + + This metric measures **speech fidelity** — whether entities are clearly and correctly articulated in the audio. The conversation trace is provided so you know which entities to listen for, NOT so you can judge whether the agent gave the right answer. + + **This is NOT a faithfulness or correctness metric.** Do NOT evaluate: + - Whether the agent used the right entity from a tool response (e.g., agent says "$315" but tool says $300 — this is a faithfulness issue, NOT a speech fidelity issue) + - Whether the agent fabricated or hallucinated information not in the trace + - Whether the agent omitted information it should have mentioned + - Whether the agent's response is logical, helpful, or correct + + **What this metric DOES evaluate:** + When the agent speaks an entity that appears in the conversation trace (user utterances or tool responses), is it **clearly articulated** in the audio? Specifically: + - Can you clearly hear the entity as spoken? + - Does the spoken form sound like the correct entity, or is it garbled, mispronounced, or distorted? + - If the agent spells out a code letter by letter, is each letter/digit clearly distinguishable? + + ## Entity Categories to Listen For + - Confirmation codes (e.g., ZK3FFW, FAR0UM) — especially when spelled out letter by letter + - Flight numbers (e.g., SkyWay 410, SW302) + - Dollar amounts (e.g., $15, $1,285.00) — "fifteen" vs "fifty" matters + - Seat numbers (e.g., 21C, 14A) + - Reference/voucher IDs (e.g., REF-8JVSDF-001) — verify each segment is distinguishable + - Times (e.g., 3:55 PM, 10:30 AM) + - Dates (e.g., March 25th, February 3rd) + - Names (e.g., Mr. Rivera, Rodriguez) + + ## Examples + + **High fidelity (rating = 1):** + - Tool response contains confirmation code "YTM924". Agent says "Y T M nine two four" — each character is clearly audible. ✓ + - User says "last name Patel". Agent says "Patel" — clearly articulated. ✓ + - Tool response says fare is $300. Agent says "$315" — the amount is clearly spoken even though it doesn't match the tool response. This is a faithfulness issue, not a speech fidelity issue. Rate 1. ✓ + - Agent mentions "Dallas" which is not in the tool response — this is a hallucination, not a speech issue. Rate 1. ✓ + + **Low fidelity (rating = 0):** + - Tool response contains "YTM924". Agent tries to spell it out but audio sounds like "Y T N nine two four" — "M" sounds like "N". ✗ + - Agent says a dollar amount but the audio is garbled and you cannot tell if it's "fifty" or "fifteen". ✗ + - Agent spells a code but skips or slurs a letter so the spoken code has fewer characters than expected. ✗ + + **What to ignore (does NOT cause rating = 0):** + - Entities the agent mentions that are NOT in the conversation trace — do not evaluate these + - Minor pronunciation variations that do not change identity (e.g., "Ms." vs "Miss") + - Filler words, phrasing, word choice, sentence structure + - Slight pacing or prosody differences + + ## Rating Scale (per turn) + - **1 (High Fidelity)**: Every entity from the conversation trace that the agent speaks in this turn is clearly and correctly articulated. + - **0 (Low Fidelity)**: One or more entities from the conversation trace are garbled, mispronounced, or indistinguishable in the audio. + + If the assistant does not speak any entities from the conversation trace in a turn (e.g., a greeting, filler, or turn where it only mentions entities not in the trace), set `has_entities` to false. These turns are excluded from scoring. + + ## Response Format + Respond with a JSON object. Each turn entry must include the turn_id matching the turn number shown in the Conversation Trace above: + {{ + "turns": [ + {{ + "turn_id": , + "transcript": , + "has_entities": , + "explanation": "", + "rating": <0 or 1> + }} + ], + "explanation": "" + }} user_speech_fidelity: user_prompt: | diff --git a/src/eva/metrics/accuracy/__init__.py b/src/eva/metrics/accuracy/__init__.py index 35e21ae4..ab2fff74 100644 --- a/src/eva/metrics/accuracy/__init__.py +++ b/src/eva/metrics/accuracy/__init__.py @@ -1,11 +1,13 @@ """Task completion metrics - measuring whether the agent accomplished the user's goal.""" from . import agent_speech_fidelity # noqa +from . import agent_speech_fidelity_s2s # noqa from . import faithfulness # noqa from . import task_completion # noqa __all__ = [ "agent_speech_fidelity", + "agent_speech_fidelity_s2s", "faithfulness", "task_completion", ] diff --git a/src/eva/metrics/accuracy/agent_speech_fidelity_s2s.py b/src/eva/metrics/accuracy/agent_speech_fidelity_s2s.py new file mode 100644 index 00000000..c3ed2858 --- /dev/null +++ b/src/eva/metrics/accuracy/agent_speech_fidelity_s2s.py @@ -0,0 +1,239 @@ +"""Agent speech fidelity metric for S2S models — entity-focused evaluation. + +For S2S (speech-to-speech) models, there is no intended text to compare against. +Instead, this metric verifies that key entities spoken by the agent (from tool +responses and user utterances) are accurate by sending a redacted conversation +trace alongside the agent audio to Gemini. +""" + +import json +from typing import Any + +from eva.metrics.base import MetricContext +from eva.metrics.speech_fidelity_base import SpeechFidelityBaseMetric +from eva.metrics.utils import aggregate_per_turn_scores, normalize_rating, resolve_turn_id +from eva.models.results import MetricScore + + +class AgentSpeechFidelityS2SMetric(SpeechFidelityBaseMetric): + """Audio-based entity fidelity metric for S2S agent speech. + + Evaluates whether key entities (from tool responses and user utterances) are + spoken correctly by the agent, without requiring intended text. + + Rating scale: 0 (entity error) or 1 (all entities accurate) + """ + + name = "agent_speech_fidelity" + description = "Audio-based evaluation of agent entity fidelity for S2S models" + category = "accuracy" + role = "assistant" + rating_scale = (0, 1) + pass_at_k_threshold = 0.95 + + async def compute(self, context: MetricContext) -> MetricScore: + """Compute entity fidelity score using redacted conversation trace + audio.""" + try: + audio_segment = self.load_role_audio(context, self.role) + if audio_segment is None: + return MetricScore( + name=self.name, + score=0.0, + normalized_score=0.0, + error=f"No {self.role} audio file available", + ) + + redacted_trace = self._build_redacted_trace(context) + assistant_turn_ids = self._get_assistant_turn_ids(redacted_trace) + + if not assistant_turn_ids: + return MetricScore( + name=self.name, + score=0.0, + normalized_score=0.0, + error="No assistant turns found in conversation trace", + ) + + num_turns = len(assistant_turn_ids) + trace_formatted = self._format_redacted_trace(redacted_trace) + audio_b64 = self.encode_audio_segment(audio_segment) + + prompt = self.get_judge_prompt( + prompt_key="s2s_user_prompt", + conversation_trace_formatted=trace_formatted, + ) + + messages = self.create_audio_message(audio_b64, prompt) + + per_turn_ratings: dict[int, int | None] = {} + per_turn_explanations: dict[int, str] = {} + per_turn_transcripts: dict[int, str] = {} + per_turn_normalized: dict[int, float] = {} + min_rating, max_rating = self.rating_scale + valid_ratings_range = list(range(min_rating, max_rating + 1)) + + response_text, turns = await self._call_and_parse(messages, context, audio_segment, prompt) + + if response_text is None: + return MetricScore( + name=self.name, + score=0.0, + normalized_score=0.0, + error="No response from judge", + ) + + self.logger.debug(f"Raw judge response: {response_text[:200]}") + + if len(turns) != num_turns: + self.logger.warning( + f"[{context.record_id}] Expected {num_turns} ratings for S2S entity fidelity, got {len(turns)}" + ) + + per_turn_has_entities: dict[int, bool] = {} + + for response_item in turns: + turn_id = resolve_turn_id(response_item, assistant_turn_ids, self.name) + if turn_id is None: + continue + rating = response_item.get("rating") + transcript = response_item.get("transcript") + explanation = response_item.get("explanation", "") + has_entities = response_item.get("has_entities", True) + + per_turn_has_entities[turn_id] = has_entities + + if not has_entities: + # Exclude turns with no entities from scoring + per_turn_ratings[turn_id] = rating + per_turn_explanations[turn_id] = explanation + per_turn_transcripts[turn_id] = transcript + continue + + if rating not in valid_ratings_range: + self.logger.warning(f"[{context.record_id}] Invalid rating {rating} for turn {turn_id}") + per_turn_ratings[turn_id] = None + per_turn_explanations[turn_id] = f"Invalid rating: {rating}" + continue + + per_turn_ratings[turn_id] = rating + per_turn_explanations[turn_id] = explanation + per_turn_transcripts[turn_id] = transcript + per_turn_normalized[turn_id] = normalize_rating(rating, min_rating, max_rating) + + aggregated_score = aggregate_per_turn_scores(list(per_turn_normalized.values()), self.aggregation) + + # Only count turns with entities toward the score + valid_ratings = [ + per_turn_ratings[tid] + for tid in per_turn_ratings + if per_turn_ratings[tid] is not None and per_turn_has_entities.get(tid, True) + ] + avg_rating = sum(valid_ratings) / len(valid_ratings) if valid_ratings else 0.0 + num_skipped_no_entities = sum(1 for v in per_turn_has_entities.values() if not v) + + details: dict[str, Any] = { + "variant": "s2s", + "aggregation": self.aggregation, + "num_turns": num_turns, + "num_evaluated": len(valid_ratings), + "num_skipped_no_entities": num_skipped_no_entities, + "per_turn_ratings": per_turn_ratings, + "per_turn_has_entities": per_turn_has_entities, + "per_turn_explanations": per_turn_explanations, + "judge_prompt": prompt, + "judge_raw_response": response_text, + } + + return MetricScore( + name=self.name, + score=round(avg_rating, 3), + normalized_score=round(aggregated_score, 3) if aggregated_score is not None else 0, + details=details, + error="Aggregation failed" if aggregated_score is None else None, + ) + + except Exception as e: + return self._handle_error(e, context) + + @staticmethod + def _build_redacted_trace(context: MetricContext) -> list[dict]: + """Build a redacted conversation trace for entity fidelity evaluation. + + Keeps user entries and tool responses as-is (entity sources). + Replaces assistant entries with a single placeholder per turn_id + (a turn can have multiple assistant entries, e.g. before/after tool calls). + Drops tool_call entries (parameters, not entity sources). + + Note: conversation trace entries use different schemas by type: + - user/assistant entries have ``role`` + ``content`` + - tool entries have ``type`` (tool_call/tool_response) + ``tool_name`` + data fields + """ + redacted = [] + seen_assistant_turns: set[int] = set() + for entry in context.conversation_trace or []: + role = entry.get("role") + entry_type = entry.get("type") + + if role == "assistant": + turn_id = entry.get("turn_id") + if turn_id not in seen_assistant_turns: + seen_assistant_turns.add(turn_id) + redacted.append( + { + "role": "assistant", + "turn_id": turn_id, + "redacted": True, + } + ) + elif role == "user": + redacted.append( + { + "role": "user", + "content": entry.get("content", ""), + "turn_id": entry.get("turn_id"), + } + ) + elif entry_type == "tool_response": + redacted.append( + { + "role": "tool_response", + "tool_name": entry.get("tool_name", "unknown"), + "content": entry.get("tool_response", {}), + "turn_id": entry.get("turn_id"), + } + ) + # Skip tool_call entries — parameters are not entity sources + + return redacted + + @staticmethod + def _get_assistant_turn_ids(redacted_trace: list[dict]) -> list[int]: + """Extract sorted unique assistant turn IDs from the redacted trace.""" + turn_ids = set() + for entry in redacted_trace: + if entry.get("role") == "assistant" and entry.get("turn_id") is not None: + turn_ids.add(entry["turn_id"]) + return sorted(turn_ids) + + @staticmethod + def _format_redacted_trace(redacted_trace: list[dict]) -> str: + """Format the redacted trace as text for the prompt.""" + lines = [] + for entry in redacted_trace: + turn_id = entry.get("turn_id", "?") + role = entry["role"] + + if role == "user": + lines.append(f"Turn {turn_id} - User: {entry['content']}") + elif role == "assistant": + lines.append(f"Turn {turn_id} - [Assistant speaks]") + elif role == "tool_response": + tool_name = entry.get("tool_name", "unknown") + content = entry.get("content", {}) + if isinstance(content, (dict, list)): + content_str = json.dumps(content, indent=None) + else: + content_str = str(content) + lines.append(f"Turn {turn_id} - Tool Response ({tool_name}): {content_str}") + + return "\n".join(lines) diff --git a/src/eva/metrics/runner.py b/src/eva/metrics/runner.py index 289ce8cb..cce0ab2e 100644 --- a/src/eva/metrics/runner.py +++ b/src/eva/metrics/runner.py @@ -9,6 +9,7 @@ import yaml +from eva.metrics.accuracy.agent_speech_fidelity_s2s import AgentSpeechFidelityS2SMetric from eva.metrics.aggregation import compute_record_aggregates, compute_run_level_aggregates from eva.metrics.base import BaseMetric, MetricContext from eva.metrics.processor import MetricsContextProcessor @@ -118,6 +119,13 @@ def __init__( else: logger.warning(f"Metric '{name}' not found, skipping") + # For S2S pipelines, swap agent_speech_fidelity with entity-focused variant + if self._pipeline_type == PipelineType.S2S: + self.metrics = [ + AgentSpeechFidelityS2SMetric(config=m.config) if m.name == "agent_speech_fidelity" else m + for m in self.metrics + ] + logger.info(f"Metrics runner initialized with {len(self.metrics)} metrics") def _load_agent_config(self) -> dict[str, Any]: diff --git a/tests/unit/metrics/test_speech_fidelity_s2s.py b/tests/unit/metrics/test_speech_fidelity_s2s.py new file mode 100644 index 00000000..25fbacd3 --- /dev/null +++ b/tests/unit/metrics/test_speech_fidelity_s2s.py @@ -0,0 +1,381 @@ +"""Tests for agent_speech_fidelity S2S variant (entity-focused evaluation).""" + +import json +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from eva.metrics.accuracy.agent_speech_fidelity_s2s import AgentSpeechFidelityS2SMetric +from eva.models.config import PipelineType + +from .conftest import make_judge_metric, make_metric_context + + +def make_judge_response(turns: list[dict]) -> str: + """Create a JSON judge response with a ``turns`` wrapper.""" + return json.dumps({"turns": turns}) + + +@pytest.fixture +def s2s_metric(): + return make_judge_metric( + AgentSpeechFidelityS2SMetric, + mock_llm=True, + logger_name="test_agent_speech_fidelity_s2s", + ) + + +# --- Sample conversation traces --- + +# Conversation trace entries use different schemas: +# - user/assistant: have "role" + "content" + "type" (intended/transcribed) +# - tool entries: have "type" (tool_call/tool_response) + "tool_name" + data fields, no "role" + +SIMPLE_TRACE = [ + {"role": "user", "content": "Check reservation ABC123, last name Smith", "type": "intended", "turn_id": 0}, + {"role": "assistant", "content": "Looking that up for you.", "type": "transcribed", "turn_id": 1}, + { + "tool_name": "get_reservation", + "parameters": {"confirmation_number": "ABC123"}, + "type": "tool_call", + "turn_id": 1, + }, + { + "tool_name": "get_reservation", + "tool_response": {"confirmation_number": "ABC123", "last_name": "Smith", "flight": "UA456"}, + "type": "tool_response", + "turn_id": 1, + }, + {"role": "assistant", "content": "Your flight is UA456.", "type": "transcribed", "turn_id": 1}, + {"role": "user", "content": "Thanks", "type": "intended", "turn_id": 2}, + {"role": "assistant", "content": "You're welcome!", "type": "transcribed", "turn_id": 3}, +] + +MULTI_ASSISTANT_SAME_TURN_TRACE = [ + {"role": "user", "content": "Book me a flight", "type": "intended", "turn_id": 0}, + {"role": "assistant", "content": "Let me search.", "type": "transcribed", "turn_id": 1}, + {"tool_name": "search_flights", "parameters": {}, "type": "tool_call", "turn_id": 1}, + {"tool_name": "search_flights", "tool_response": {"flights": ["SW302"]}, "type": "tool_response", "turn_id": 1}, + {"role": "assistant", "content": "I found flight SW302.", "type": "transcribed", "turn_id": 1}, + {"role": "user", "content": "Great, book it", "type": "intended", "turn_id": 2}, + {"role": "assistant", "content": "Done!", "type": "transcribed", "turn_id": 3}, +] + +NO_TOOL_TRACE = [ + {"role": "user", "content": "Hello", "type": "intended", "turn_id": 0}, + {"role": "assistant", "content": "Hi there!", "type": "transcribed", "turn_id": 1}, +] + + +def _default_context(**overrides): + """Context for S2S speech fidelity tests.""" + defaults = { + "audio_assistant_path": "/fake/audio_assistant.wav", + "audio_user_path": "/fake/audio_user.wav", + "pipeline_type": PipelineType.S2S, + "conversation_trace": SIMPLE_TRACE, + } + defaults.update(overrides) + return make_metric_context(**defaults) + + +class TestClassAttributes: + def test_s2s_metric_attributes(self, s2s_metric): + assert s2s_metric.name == "agent_speech_fidelity" + assert s2s_metric.category == "accuracy" + assert s2s_metric.role == "assistant" + assert s2s_metric.rating_scale == (0, 1) + assert s2s_metric.pass_at_k_threshold == 0.95 + + +class TestBuildRedactedTrace: + def test_assistant_entries_are_redacted(self, s2s_metric): + redacted = s2s_metric._build_redacted_trace(_default_context()) + assistant_entries = [e for e in redacted if e["role"] == "assistant"] + for entry in assistant_entries: + assert entry.get("redacted") is True + assert "content" not in entry + + def test_user_entries_preserved(self, s2s_metric): + redacted = s2s_metric._build_redacted_trace(_default_context()) + user_entries = [e for e in redacted if e["role"] == "user"] + assert len(user_entries) == 2 + assert user_entries[0]["content"] == "Check reservation ABC123, last name Smith" + assert user_entries[1]["content"] == "Thanks" + + def test_tool_responses_preserved(self, s2s_metric): + redacted = s2s_metric._build_redacted_trace(_default_context()) + tool_entries = [e for e in redacted if e["role"] == "tool_response"] + assert len(tool_entries) == 1 + assert tool_entries[0]["tool_name"] == "get_reservation" + assert tool_entries[0]["content"]["confirmation_number"] == "ABC123" + assert tool_entries[0]["content"]["flight"] == "UA456" + + def test_tool_calls_dropped(self, s2s_metric): + """Tool call entries (type=tool_call, no role) should not appear in redacted trace.""" + redacted = s2s_metric._build_redacted_trace(_default_context()) + tool_call_entries = [e for e in redacted if e.get("type") == "tool_call" or e.get("role") == "tool_call"] + assert len(tool_call_entries) == 0 + + def test_multiple_assistant_entries_same_turn_deduplicated(self, s2s_metric): + """Multiple assistant entries in the same turn should produce one placeholder.""" + context = _default_context(conversation_trace=MULTI_ASSISTANT_SAME_TURN_TRACE) + redacted = s2s_metric._build_redacted_trace(context) + assistant_entries = [e for e in redacted if e["role"] == "assistant"] + # Turn 1 has two assistant entries, but should be deduplicated to one + turn_1_entries = [e for e in assistant_entries if e["turn_id"] == 1] + assert len(turn_1_entries) == 1 + + def test_empty_trace(self, s2s_metric): + context = _default_context(conversation_trace=[]) + redacted = s2s_metric._build_redacted_trace(context) + assert redacted == [] + + def test_none_trace(self, s2s_metric): + context = _default_context(conversation_trace=None) + redacted = s2s_metric._build_redacted_trace(context) + assert redacted == [] + + +class TestGetAssistantTurnIds: + def test_extracts_unique_turn_ids(self, s2s_metric): + redacted = s2s_metric._build_redacted_trace(_default_context()) + turn_ids = s2s_metric._get_assistant_turn_ids(redacted) + assert turn_ids == [1, 3] + + def test_deduplicates_same_turn(self, s2s_metric): + context = _default_context(conversation_trace=MULTI_ASSISTANT_SAME_TURN_TRACE) + redacted = s2s_metric._build_redacted_trace(context) + turn_ids = s2s_metric._get_assistant_turn_ids(redacted) + assert turn_ids == [1, 3] + + def test_empty_trace(self, s2s_metric): + turn_ids = s2s_metric._get_assistant_turn_ids([]) + assert turn_ids == [] + + +class TestFormatRedactedTrace: + def test_format_simple_trace(self, s2s_metric): + redacted = s2s_metric._build_redacted_trace(_default_context()) + formatted = s2s_metric._format_redacted_trace(redacted) + lines = formatted.split("\n") + + assert lines[0] == "Turn 0 - User: Check reservation ABC123, last name Smith" + assert lines[1] == "Turn 1 - [Assistant speaks]" + assert "Turn 1 - Tool Response (get_reservation):" in lines[2] + assert '"confirmation_number": "ABC123"' in lines[2] + assert lines[3] == "Turn 2 - User: Thanks" + assert lines[4] == "Turn 3 - [Assistant speaks]" + + def test_format_no_duplicate_assistant_lines(self, s2s_metric): + """Even with multiple assistant entries per turn, only one line appears.""" + context = _default_context(conversation_trace=MULTI_ASSISTANT_SAME_TURN_TRACE) + redacted = s2s_metric._build_redacted_trace(context) + formatted = s2s_metric._format_redacted_trace(redacted) + assert formatted.count("Turn 1 - [Assistant speaks]") == 1 + + +class TestNoAudio: + @pytest.mark.asyncio + async def test_no_audio_returns_error(self, s2s_metric): + context = _default_context(audio_assistant_path=None) + result = await s2s_metric.compute(context) + assert result.score == 0.0 + assert "No assistant audio" in result.error + + +class TestNoAssistantTurns: + @pytest.mark.asyncio + async def test_no_assistant_turns_returns_error(self, s2s_metric): + trace = [ + {"role": "user", "content": "Hello", "type": "intended", "turn_id": 0}, + ] + context = _default_context(conversation_trace=trace) + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + result = await s2s_metric.compute(context) + assert result.score == 0.0 + assert "No assistant turns" in result.error + + +class TestNoJudgeResponse: + @pytest.mark.asyncio + async def test_no_response_returns_error(self, s2s_metric): + s2s_metric.llm_client.generate_text.return_value = None + context = _default_context() + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + result = await s2s_metric.compute(context) + assert result.score == 0.0 + assert result.error == "No response from judge" + + +class TestS2SCompute: + @pytest.mark.asyncio + async def test_all_high_fidelity(self, s2s_metric): + """All turns rated 1 -> perfect score.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "explanation": "All entities correct"}, + {"turn_id": 3, "rating": 1, "explanation": "No entities to check"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.score == 1.0 + assert result.normalized_score == 1.0 + assert result.details["num_turns"] == 2 + assert result.details["num_evaluated"] == 2 + assert result.details["variant"] == "s2s" + assert result.error is None + + @pytest.mark.asyncio + async def test_all_low_fidelity(self, s2s_metric): + """All turns rated 0 -> zero score.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 0, "explanation": "Said UA465 instead of UA456"}, + {"turn_id": 3, "rating": 0, "explanation": "Wrong name"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.score == 0.0 + assert result.normalized_score == 0.0 + + @pytest.mark.asyncio + async def test_mixed_ratings(self, s2s_metric): + """One turn correct, one incorrect -> 0.5.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "explanation": "Correct"}, + {"turn_id": 3, "rating": 0, "explanation": "Wrong entity"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.score == 0.5 + assert result.normalized_score == 0.5 + + @pytest.mark.asyncio + async def test_invalid_rating_excluded(self, s2s_metric): + """Invalid ratings are excluded from aggregation.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "has_entities": True, "explanation": "Good"}, + {"turn_id": 3, "rating": 5, "has_entities": True, "explanation": "Invalid"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.details["num_evaluated"] == 1 + assert result.details["per_turn_ratings"][3] is None + assert result.score == 1.0 + + @pytest.mark.asyncio + async def test_no_entity_turns_excluded_from_score(self, s2s_metric): + """Turns with has_entities=false should not count toward the score.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 0, "has_entities": False, "explanation": "Greeting, no entities"}, + {"turn_id": 3, "rating": 1, "has_entities": True, "explanation": "Flight number correct"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + # Only turn 3 (has_entities=True) should be evaluated + assert result.details["num_evaluated"] == 1 + assert result.details["num_skipped_no_entities"] == 1 + assert result.score == 1.0 + assert result.normalized_score == 1.0 + + @pytest.mark.asyncio + async def test_all_turns_no_entities(self, s2s_metric): + """If all turns have no entities, score should be 0 with no evaluated turns.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "has_entities": False, "explanation": "No entities"}, + {"turn_id": 3, "rating": 1, "has_entities": False, "explanation": "No entities"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.details["num_evaluated"] == 0 + assert result.details["num_skipped_no_entities"] == 2 + assert result.score == 0.0 + + @pytest.mark.asyncio + async def test_has_entities_defaults_to_true(self, s2s_metric): + """If has_entities is missing from response, default to True (include in scoring).""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "explanation": "Good"}, + {"turn_id": 3, "rating": 0, "explanation": "Wrong entity"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.details["num_evaluated"] == 2 + assert result.details["num_skipped_no_entities"] == 0 + assert result.score == 0.5 + + +class TestTurnCountMismatch: + @pytest.mark.asyncio + async def test_fewer_turns_returned(self, s2s_metric, caplog): + """Fewer turns than expected logs a warning but still computes.""" + response = make_judge_response( + [ + {"turn_id": 1, "rating": 1, "explanation": "Good"}, + ] + ) + s2s_metric.llm_client.generate_text.return_value = response + with patch.object(s2s_metric, "load_role_audio", return_value=MagicMock()): + with patch.object(s2s_metric, "encode_audio_segment", return_value="base64audio"): + context = _default_context() + with caplog.at_level(logging.WARNING): + result = await s2s_metric.compute(context) + + assert "Expected 2 ratings" in caplog.text + assert result.details["num_evaluated"] == 1 + assert result.score == 1.0 + + +class TestErrorHandling: + @pytest.mark.asyncio + async def test_exception_returns_error_score(self, s2s_metric): + with patch.object(s2s_metric, "load_role_audio", side_effect=RuntimeError("boom")): + context = _default_context() + result = await s2s_metric.compute(context) + + assert result.score == 0.0 + assert result.normalized_score == 0.0 + assert "boom" in result.error