
Commit 3e845dd

test: Custom metric support as workaround
ref: google/adk-python#4344
1 parent 1835d83 commit 3e845dd

1 file changed: adk-evaluation/tests/test_home_automation_agent.py

Lines changed: 192 additions & 1 deletion
@@ -1,5 +1,31 @@
+from __future__ import annotations
+
+import os
+import os.path
+from typing import Optional
+
 import pytest
-from google.adk.evaluation.agent_evaluator import AgentEvaluator
+from google.adk.cli.cli_eval import get_default_metric_info
+from google.adk.evaluation.agent_evaluator import NUM_RUNS, AgentEvaluator
+from google.adk.evaluation.base_eval_service import (
+    EvalCaseResult,
+    EvaluateConfig,
+    EvaluateRequest,
+    InferenceConfig,
+    InferenceRequest,
+)
+from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
+from google.adk.evaluation.eval_config import EvalConfig, get_eval_metrics_from_config
+from google.adk.evaluation.eval_metrics import BaseCriterion
+from google.adk.evaluation.eval_set import EvalSet
+from google.adk.evaluation.local_eval_service import LocalEvalService
+from google.adk.evaluation.metric_evaluator_registry import (
+    _get_default_metric_evaluator_registry,
+)
+from google.adk.evaluation.simulation.user_simulator_provider import (
+    UserSimulatorProvider,
+)
+from google.adk.runners import Aclosing


 @pytest.mark.skip(reason="`adk eval` only supports custom metrics")
@@ -10,3 +36,168 @@ async def test_with_single_test_file():
         eval_dataset_file_path_or_dir="tests/fixtures/home_automation_agent/simple_test.test.json",
         num_runs=1,
     )
+
+
+@pytest.mark.asyncio
+async def test_with_single_test_file_workaround():
+    await CustomMetricsSupportAgentEvaluator.evaluate(
+        agent_module="home_automation_agent",
+        eval_dataset_file_path_or_dir="tests/fixtures/home_automation_agent/simple_test.test.json",
+        num_runs=1,
+    )
+
+
+class CustomMetricsSupportAgentEvaluator(AgentEvaluator):
+    @staticmethod
+    async def evaluate_eval_set(
+        agent_module: str,
+        eval_set: EvalSet,
+        criteria: Optional[dict[str, float]] = None,
+        eval_config: Optional[EvalConfig] = None,
+        num_runs: int = NUM_RUNS,
+        agent_name: Optional[str] = None,
+        print_detailed_results: bool = True,
+    ):
+        """Evaluates an agent using the given EvalSet with custom metrics."""
+        if criteria:
+            base_criteria = {k: BaseCriterion(threshold=v) for k, v in criteria.items()}
+            eval_config = EvalConfig(criteria=base_criteria)
+
+        if eval_config is None:
+            raise ValueError("`eval_config` is required.")
+
+        agent_for_eval = await AgentEvaluator._get_agent_for_eval(
+            module_name=agent_module, agent_name=agent_name
+        )
+        eval_metrics = get_eval_metrics_from_config(eval_config)
+
+        user_simulator_provider = UserSimulatorProvider(
+            user_simulator_config=eval_config.user_simulator_config
+        )
+
+        metric_evaluator_registry = _get_default_metric_evaluator_registry()
+        if eval_config.custom_metrics:
+            for metric_name, config in eval_config.custom_metrics.items():
+                if config.metric_info:
+                    metric_info = config.metric_info.model_copy()
+                    metric_info.metric_name = metric_name
+                else:
+                    metric_info = get_default_metric_info(
+                        metric_name=metric_name, description=config.description
+                    )
+                metric_evaluator_registry.register_evaluator(
+                    metric_info, _CustomMetricEvaluator
+                )
+
+        # It is okay to pick up this dummy name.
+        app_name = "test_app"
+        eval_service = LocalEvalService(
+            root_agent=agent_for_eval,
+            eval_sets_manager=AgentEvaluator._get_eval_sets_manager(
+                app_name=app_name, eval_set=eval_set
+            ),
+            user_simulator_provider=user_simulator_provider,
+            metric_evaluator_registry=metric_evaluator_registry,
+        )
+
+        inference_requests = [
+            InferenceRequest(
+                app_name=app_name,
+                eval_set_id=eval_set.eval_set_id,
+                inference_config=InferenceConfig(),
+            )
+        ] * num_runs  # Repeat inference request num_runs times.
+
+        # Generate inferences.
+        inference_results = []
+        for inference_request in inference_requests:
+            async with Aclosing(
+                eval_service.perform_inference(inference_request=inference_request)
+            ) as agen:
+                async for inference_result in agen:
+                    inference_results.append(inference_result)
+
+        # Evaluate metrics.
+        # As we perform more than one run for an eval case, we collect eval results
+        # by eval id.
+        eval_results_by_eval_id: dict[str, list[EvalCaseResult]] = {}
+        evaluate_request = EvaluateRequest(
+            inference_results=inference_results,
+            evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
+        )
+        async with Aclosing(
+            eval_service.evaluate(evaluate_request=evaluate_request)
+        ) as agen:
+            async for eval_result in agen:
+                eval_id = eval_result.eval_id
+                if eval_id not in eval_results_by_eval_id:
+                    eval_results_by_eval_id[eval_id] = []
+
+                eval_results_by_eval_id[eval_id].append(eval_result)
+
+        failures: list[str] = []
+
+        for _, eval_results_per_eval_id in eval_results_by_eval_id.items():
+            eval_metric_results = (
+                AgentEvaluator._get_eval_metric_results_with_invocation(
+                    eval_results_per_eval_id
+                )
+            )
+            failures_per_eval_case = AgentEvaluator._process_metrics_and_get_failures(
+                eval_metric_results=eval_metric_results,
+                print_detailed_results=print_detailed_results,
+                agent_module=agent_name,
+            )
+
+            failures.extend(failures_per_eval_case)
+
+        failure_message = "Following are all the test failures."
+        if not print_detailed_results:
+            failure_message += (
+                " If you are looking to get more details on the failures, then please"
+                " re-run this test with `print_detailed_results` set to `True`."
+            )
+        failure_message += "\n" + "\n".join(failures)
+        assert not failures, failure_message
+
+    @staticmethod
+    async def evaluate(
+        agent_module: str,
+        eval_dataset_file_path_or_dir: str,
+        num_runs: int = NUM_RUNS,
+        agent_name: Optional[str] = None,
+        initial_session_file: Optional[str] = None,
+        print_detailed_results: bool = True,
+    ):
+        """Evaluates an Agent given eval data with custom metrics."""
+        test_files = []
+        if isinstance(eval_dataset_file_path_or_dir, str) and os.path.isdir(
+            eval_dataset_file_path_or_dir
+        ):
+            for root, _, files in os.walk(eval_dataset_file_path_or_dir):
+                for file in files:
+                    if file.endswith(".test.json"):
+                        test_files.append(os.path.join(root, file))
+        else:
+            test_files = [eval_dataset_file_path_or_dir]
+
+        initial_session = CustomMetricsSupportAgentEvaluator._get_initial_session(
+            initial_session_file
+        )
+
+        for test_file in test_files:
+            eval_config = CustomMetricsSupportAgentEvaluator.find_config_for_test_file(
+                test_file
+            )
+            eval_set = CustomMetricsSupportAgentEvaluator._load_eval_set_from_file(
+                test_file, eval_config, initial_session
+            )
+
+            await CustomMetricsSupportAgentEvaluator.evaluate_eval_set(
+                agent_module=agent_module,
+                eval_set=eval_set,
+                eval_config=eval_config,
+                num_runs=num_runs,
+                agent_name=agent_name,
+                print_detailed_results=print_detailed_results,
+            )
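
For reference, a minimal sketch of how the directory-walking branch of the new `evaluate` override could be exercised. This companion test is not part of the commit: its name and the idea of pointing the evaluator at the whole tests/fixtures/home_automation_agent fixture directory are assumptions; it would live in the same test module as the class above.

import pytest


# Hypothetical companion test (not in this commit): pass a directory so the
# workaround evaluator walks it and picks up every *.test.json file found.
@pytest.mark.asyncio
async def test_with_test_directory_workaround():
    await CustomMetricsSupportAgentEvaluator.evaluate(
        agent_module="home_automation_agent",
        eval_dataset_file_path_or_dir="tests/fixtures/home_automation_agent",
        num_runs=1,
    )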
