from __future__ import annotations

import os
import os.path
from typing import Optional

import pytest
from google.adk.cli.cli_eval import get_default_metric_info
from google.adk.evaluation.agent_evaluator import NUM_RUNS, AgentEvaluator
from google.adk.evaluation.base_eval_service import (
    EvalCaseResult,
    EvaluateConfig,
    EvaluateRequest,
    InferenceConfig,
    InferenceRequest,
)
from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
from google.adk.evaluation.eval_config import EvalConfig, get_eval_metrics_from_config
from google.adk.evaluation.eval_metrics import BaseCriterion
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.evaluation.metric_evaluator_registry import (
    _get_default_metric_evaluator_registry,
)
from google.adk.evaluation.simulation.user_simulator_provider import (
    UserSimulatorProvider,
)
from google.adk.runners import Aclosing


@pytest.mark.skip(reason="`adk eval` only supports custom metrics")
@pytest.mark.asyncio
async def test_with_single_test_file():
    await AgentEvaluator.evaluate(
        agent_module="home_automation_agent",
        eval_dataset_file_path_or_dir="tests/fixtures/home_automation_agent/simple_test.test.json",
        num_runs=1,
    )


@pytest.mark.asyncio
async def test_with_single_test_file_workaround():
    await CustomMetricsSupportAgentEvaluator.evaluate(
        agent_module="home_automation_agent",
        eval_dataset_file_path_or_dir="tests/fixtures/home_automation_agent/simple_test.test.json",
        num_runs=1,
    )


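# Workaround: AgentEvaluator does not register custom metrics, so this
# subclass rebuilds the same eval pipeline with a metric evaluator registry
# that includes the custom metrics declared in the eval config.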
class CustomMetricsSupportAgentEvaluator(AgentEvaluator):

    @staticmethod
    async def evaluate_eval_set(
        agent_module: str,
        eval_set: EvalSet,
        criteria: Optional[dict[str, float]] = None,
        eval_config: Optional[EvalConfig] = None,
        num_runs: int = NUM_RUNS,
        agent_name: Optional[str] = None,
        print_detailed_results: bool = True,
    ):
61+ """Evaluates an agent using the given EvalSet with custom metrics."""
62+ if criteria :
63+ base_criteria = {k : BaseCriterion (threshold = v ) for k , v in criteria .items ()}
64+ eval_config = EvalConfig (criteria = base_criteria )
65+
66+ if eval_config is None :
67+ raise ValueError ("`eval_config` is required." )
68+
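        # Load the agent under test and resolve the metrics named in the eval
        # config.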
        agent_for_eval = await AgentEvaluator._get_agent_for_eval(
            module_name=agent_module, agent_name=agent_name
        )
        eval_metrics = get_eval_metrics_from_config(eval_config)

        user_simulator_provider = UserSimulatorProvider(
            user_simulator_config=eval_config.user_simulator_config
        )

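        # Extend the default registry with every custom metric from the eval
        # config; `_CustomMetricEvaluator` provides the evaluation logic.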
        metric_evaluator_registry = _get_default_metric_evaluator_registry()
        if eval_config.custom_metrics:
            for metric_name, config in eval_config.custom_metrics.items():
                if config.metric_info:
                    metric_info = config.metric_info.model_copy()
                    metric_info.metric_name = metric_name
                else:
                    metric_info = get_default_metric_info(
                        metric_name=metric_name, description=config.description
                    )
                metric_evaluator_registry.register_evaluator(
                    metric_info, _CustomMetricEvaluator
                )

        # Any placeholder app name works here.
        app_name = "test_app"
        eval_service = LocalEvalService(
            root_agent=agent_for_eval,
            eval_sets_manager=AgentEvaluator._get_eval_sets_manager(
                app_name=app_name, eval_set=eval_set
            ),
            user_simulator_provider=user_simulator_provider,
            metric_evaluator_registry=metric_evaluator_registry,
        )

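        # Issue the same inference request once per run; every run re-executes
        # all eval cases in the eval set.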
        inference_requests = [
            InferenceRequest(
                app_name=app_name,
                eval_set_id=eval_set.eval_set_id,
                inference_config=InferenceConfig(),
            )
        ] * num_runs

        # Generate inferences.
        inference_results = []
        for inference_request in inference_requests:
            async with Aclosing(
                eval_service.perform_inference(inference_request=inference_request)
            ) as agen:
                async for inference_result in agen:
                    inference_results.append(inference_result)

        # Evaluate metrics. Since each eval case may be run more than once, we
        # collect the eval results by eval id.
        eval_results_by_eval_id: dict[str, list[EvalCaseResult]] = {}
        evaluate_request = EvaluateRequest(
            inference_results=inference_results,
            evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
        )
        async with Aclosing(
            eval_service.evaluate(evaluate_request=evaluate_request)
        ) as agen:
            async for eval_result in agen:
                eval_results_by_eval_id.setdefault(eval_result.eval_id, []).append(
                    eval_result
                )

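        # For each eval case, fold the per-run results into metric results and
        # collect any failures.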
        failures: list[str] = []

        for eval_results_per_eval_id in eval_results_by_eval_id.values():
            eval_metric_results = (
                AgentEvaluator._get_eval_metric_results_with_invocation(
                    eval_results_per_eval_id
                )
            )
            failures_per_eval_case = AgentEvaluator._process_metrics_and_get_failures(
                eval_metric_results=eval_metric_results,
                print_detailed_results=print_detailed_results,
                agent_module=agent_name,
            )

            failures.extend(failures_per_eval_case)

        failure_message = "The following are all the test failures."
        if not print_detailed_results:
            failure_message += (
                " If you are looking for more details on the failures, please"
                " re-run this test with `print_detailed_results` set to `True`."
            )
        failure_message += "\n" + "\n".join(failures)
        assert not failures, failure_message

    @staticmethod
    async def evaluate(
        agent_module: str,
        eval_dataset_file_path_or_dir: str,
        num_runs: int = NUM_RUNS,
        agent_name: Optional[str] = None,
        initial_session_file: Optional[str] = None,
        print_detailed_results: bool = True,
    ):
172+ """Evaluates an Agent given eval data with custom metrics."""
173+ test_files = []
174+ if isinstance (eval_dataset_file_path_or_dir , str ) and os .path .isdir (
175+ eval_dataset_file_path_or_dir
176+ ):
177+ for root , _ , files in os .walk (eval_dataset_file_path_or_dir ):
178+ for file in files :
179+ if file .endswith (".test.json" ):
180+ test_files .append (os .path .join (root , file ))
181+ else :
182+ test_files = [eval_dataset_file_path_or_dir ]
183+
        initial_session = CustomMetricsSupportAgentEvaluator._get_initial_session(
            initial_session_file
        )

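        # Evaluate each test file as its own EvalSet, using the eval config
        # associated with that file.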
        for test_file in test_files:
            eval_config = CustomMetricsSupportAgentEvaluator.find_config_for_test_file(
                test_file
            )
            eval_set = CustomMetricsSupportAgentEvaluator._load_eval_set_from_file(
                test_file, eval_config, initial_session
            )

            await CustomMetricsSupportAgentEvaluator.evaluate_eval_set(
                agent_module=agent_module,
                eval_set=eval_set,
                eval_config=eval_config,
                num_runs=num_runs,
                agent_name=agent_name,
                print_detailed_results=print_detailed_results,
            )