11import logging
2- from typing import Optional , Callable
2+ from collections .abc import Callable
3+ from typing import Optional
34
45import torch
56
67from eaa_core .api .llm_config import LLMConfig
78from eaa_core .api .memory import MemoryManagerConfig
89from eaa_core .task_manager .base import BaseTaskManager
910from eaa_core .tool .base import BaseTool
10-
1111from eaa_core .tool .optimization import BayesianOptimizationTool
12+ from eaa_core .util import to_tensor
1213
1314logger = logging .getLogger (__name__ )
1415
1516
class BayesianOptimizationStoppingCriterion:
    """Base class for stopping criteria used by BO task managers.

    Subclasses assign a short identifier to ``reason`` and implement
    :meth:`should_stop`.
    """

    # Identifier recorded by the task manager when this criterion fires.
    reason: str = ""

    def should_stop(self, task_manager: "BayesianOptimizationTaskManager") -> bool:
        """Decide whether the optimization loop should terminate now."""
        raise NotImplementedError
25+
26+
class MaxObservationsStoppingCriterion(BayesianOptimizationStoppingCriterion):
    """Stop the optimization once enough observations have been collected."""

    def __init__(self, max_observations: int) -> None:
        """Initialize the stopping criterion.

        Parameters
        ----------
        max_observations : int
            Maximum number of observed x/y pairs allowed.
        """
        self.max_observations = max_observations
        self.reason = "max_observations_reached"

    def should_stop(self, task_manager: "BayesianOptimizationTaskManager") -> bool:
        """Return True once the observation cap has been reached."""
        xs = task_manager.bayesian_optimization_tool.xs_untransformed
        return int(xs.shape[0]) >= self.max_observations
45+
46+
1647class BayesianOptimizationTaskManager (BaseTaskManager ):
17-
48+ """Task manager that runs an outer-loop Bayesian optimization workflow."""
49+
1850 def __init__ (
1951 self ,
2052 llm_config : LLMConfig = None ,
@@ -23,10 +55,13 @@ def __init__(
2355 additional_tools : list [BaseTool ] = (),
2456 initial_points : Optional [torch .Tensor ] = None ,
2557 n_initial_points : int = 20 ,
26- objective_function : Callable = None ,
58+ objective_function : BaseTool | Callable | None = None ,
59+ objective_function_method : str | None = None ,
60+ stopping_criteria : Optional [list [BayesianOptimizationStoppingCriterion ]] = None ,
2761 session_db_path : Optional [str ] = "session.sqlite" ,
2862 build : bool = True ,
29- * args , ** kwargs
63+ * args ,
64+ ** kwargs ,
3065 ) -> None :
3166 """Bayesian optimization task manager.
3267
@@ -37,27 +72,28 @@ def __init__(
3772 memory_config : MemoryManagerConfig, optional
3873 Memory configuration forwarded to the agent.
3974 additional_tools : list[BaseTool], optional
40- A list of tools for the agent (not including the
41- `BayesianOptimizationTool`) .
75+ Additional tools exposed to the task manager, excluding the
76+ Bayesian optimization tool and the objective tool .
4277 bayesian_optimization_tool : BayesianOptimizationTool
4378 The Bayesian optimization tool to use.
4479 initial_points : torch.Tensor, optional
45- A (n_points, n_features) tensor giving the initial points where
46- the objective function should be evaluated to initialize the
47- Gaussian process model. If None, random initial points will be
48- generated.
80+ Initial measurement points with shape ``(n_points, n_features)``.
81+ When omitted, random points are drawn from the optimization bounds.
4982 n_initial_points : int, optional
50- The number of initial points to generate if `initial_points` is None.
51- objective_function : Callable
52- The objective function to be maximized. This function should take
53- a single argument, which is a (n_points, n_features) tensor of
54- points to evaluate the objective function at. It should return
55- a (n_points, n_objectives) tensor of objective function values.
83+ Number of random initial points to draw when ``initial_points`` is not
84+ provided.
85+ objective_function : BaseTool | Callable
86+ Callable or tool used to evaluate points. The returned observations
87+ must have shape ``(n_samples, n_observations)``.
88+ objective_function_method : str | None, optional
89+ Method name to call when ``objective_function`` is a tool. When
90+ omitted, ``measure`` is preferred, then ``evaluate``, then a single
91+ exposed tool method.
92+ stopping_criteria : list[BayesianOptimizationStoppingCriterion], optional
93+ Additional stopping criteria checked after initialization and each
94+ update.
5695 session_db_path : Optional[str]
57- If provided, the entire chat history will be stored in
58- a SQLite database at the given path. This is essential
59- if you want to use the WebUI, which polls the database
60- for new messages.
96+ Optional SQLite path used by the shared chat/task-manager session.
6197 build : bool, optional
6298 Whether to build the internal state of the task manager.
6399 """
@@ -68,59 +104,149 @@ def __init__(
68104 )
69105 if objective_function is None :
70106 raise ValueError ("`objective_function` is required." )
71-
107+
72108 self .bayesian_optimization_tool = bayesian_optimization_tool
73-
74- for tool in additional_tools :
109+
110+ tools = list (additional_tools )
111+ if isinstance (objective_function , BaseTool ):
112+ tools .append (objective_function )
113+ for tool in tools :
75114 if isinstance (tool , BayesianOptimizationTool ):
76115 raise ValueError (
77116 "`BayesianOptimizationTool` should not be included in `tools`. "
78117 "Instead, pass it to `bayesian_optimization_tool`."
79118 )
80-
119+
81120 self .objective_function = objective_function
82-
121+ self . objective_function_method = objective_function_method
83122 self .initial_points = initial_points
84123 self .n_initial_points = n_initial_points
85-
124+ self .stopping_criteria = list (stopping_criteria or [])
125+ self .stop_reason : str | None = None
126+
86127 super ().__init__ (
87128 llm_config = llm_config ,
88129 memory_config = memory_config ,
89- tools = additional_tools ,
130+ tools = tools ,
90131 session_db_path = session_db_path ,
91132 build = build ,
92- * args , ** kwargs
133+ * args ,
134+ ** kwargs ,
93135 )
94-
136+
95137 def run (
96- self ,
97- n_iterations : int = 50 ,
98- * args , ** kwargs
138+ self ,
139+ n_iterations : int = 50 ,
140+ * args ,
141+ ** kwargs ,
99142 ) -> None :
100- """Run Bayesian optimization. Upon the second or later call,
101- this function continues from the last iteration.
102-
143+ """Run Bayesian optimization.
144+
145+ When the task manager already contains observations, the optimization
146+ continues from the current state.
147+
103148 Parameters
104149 ----------
105150 n_iterations : int, optional
106- The number of iterations to run .
151+ Maximum number of BO iterations to execute in this call .
107152 """
108153 if len (self .bayesian_optimization_tool .xs_untransformed ) == 0 :
109- if self .initial_points is None :
110- xs_init = self .bayesian_optimization_tool .get_random_initial_points (n_points = self .n_initial_points )
111- else :
112- xs_init = self .initial_points
113- logger .info (f"Initial points (shape: { xs_init .shape } ):\n { xs_init } " )
114-
115- for x in xs_init :
116- x = x [None , :]
117- y = self .objective_function (x )
118- self .bayesian_optimization_tool .update (x , y )
119- self .bayesian_optimization_tool .build ()
120-
121- for i in range (n_iterations ):
154+ self .collect_initial_observations ()
155+
156+ if self .should_stop ():
157+ return
158+
159+ for _ in range (n_iterations ):
122160 candidates = self .bayesian_optimization_tool .suggest (n_suggestions = 1 )
123- logger .info (f "Candidate suggested: { candidates [0 ]} " )
124- y = self .objective_function (candidates )
125- logger .info (f "Objective function value: { y . item () } " )
161+ logger .info ("Candidate suggested: %s" , candidates [0 ])
162+ y = self .evaluate_objective (candidates )
163+ logger .info ("Objective function value: %s" , y . reshape ( - 1 ) )
126164 self .bayesian_optimization_tool .update (candidates , y )
165+ self .configure_bayesian_optimization ()
166+ if self .should_stop ():
167+ break
168+
169+ def collect_initial_observations (self ) -> None :
170+ """Collect initial observations and build the GP model.
171+
172+ Initial ``x`` points are expected to have shape ``(n_points, n_features)``
173+ and the evaluated observations must have shape
174+ ``(n_points, n_observations)``.
175+ """
176+ if self .initial_points is None :
177+ xs_init = self .bayesian_optimization_tool .get_random_initial_points (
178+ n_points = self .n_initial_points
179+ )
180+ else :
181+ xs_init = to_tensor (self .initial_points )
182+ logger .info ("Initial points (shape: %s):\n %s" , xs_init .shape , xs_init )
183+
184+ for x in xs_init :
185+ x = x [None , :]
186+ y = self .evaluate_objective (x )
187+ self .bayesian_optimization_tool .update (x , y )
188+ self .bayesian_optimization_tool .build ()
189+ self .configure_bayesian_optimization ()
190+
    def evaluate_objective(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the objective function or objective tool at input points.

        Parameters
        ----------
        x : torch.Tensor
            Candidate locations with shape ``(n_samples, n_features)``.

        Returns
        -------
        torch.Tensor
            Objective values with shape ``(n_samples, n_observations)``.
        """
        if isinstance(self.objective_function, BaseTool):
            # Tool-based objectives go through the resolved tool method.
            objective_callable = self.resolve_objective_tool_callable()
            y = objective_callable(x)
        else:
            y = self.objective_function(x)
            # NOTE(review): conversion sits on the plain-callable branch here —
            # presumably tool methods already return tensors; confirm placement.
            y = to_tensor(y)
        if not isinstance(y, torch.Tensor):
            # Safety net for outputs that are still not tensors.
            y = torch.as_tensor(y)
        if y.ndim == 1:
            # Promote (n_samples,) results to a single-observation column.
            y = y[:, None]
        # Delegate shape/dtype validation to the BO tool before returning.
        self.bayesian_optimization_tool.check_y_data(y)
        return y
216+
217+ def resolve_objective_tool_callable (self ) -> Callable :
218+ """Resolve the callable used to evaluate the objective tool."""
219+ tool = self .objective_function
220+ if not isinstance (tool , BaseTool ):
221+ raise TypeError ("`objective_function` is not a tool instance." )
222+
223+ candidate_names : list [str ]
224+ if self .objective_function_method is not None :
225+ candidate_names = [self .objective_function_method ]
226+ else :
227+ candidate_names = ["measure" , "evaluate" ]
228+
229+ for method_name in candidate_names :
230+ method = getattr (tool , method_name , None )
231+ if callable (method ):
232+ return method
233+
234+ if len (tool .exposed_tools ) == 1 :
235+ return tool .exposed_tools [0 ].function
236+ raise ValueError (
237+ "Could not resolve a tool method for `objective_function`. "
238+ "Pass `objective_function_method` explicitly."
239+ )
240+
241+ def configure_bayesian_optimization (self ) -> None :
242+ """Hook for subclasses to update acquisition or stopping state."""
243+ return None
244+
245+ def should_stop (self ) -> bool :
246+ """Return whether any configured stopping criterion has triggered."""
247+ for criterion in self .stopping_criteria :
248+ if criterion .should_stop (self ):
249+ self .stop_reason = criterion .reason
250+ logger .info ("Stopping criterion triggered: %s" , self .stop_reason )
251+ return True
252+ return False
0 commit comments