Skip to content

Commit 98d93a9

Browse files
committed
FEAT: add eaa-spectroscopy subpackage and XANES sampling
1 parent f81912c commit 98d93a9

17 files changed

Lines changed: 2911 additions & 60 deletions

File tree

Lines changed: 179 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,52 @@
11
import logging
2-
from typing import Optional, Callable
2+
from collections.abc import Callable
3+
from typing import Optional
34

45
import torch
56

67
from eaa_core.api.llm_config import LLMConfig
78
from eaa_core.api.memory import MemoryManagerConfig
89
from eaa_core.task_manager.base import BaseTaskManager
910
from eaa_core.tool.base import BaseTool
10-
1111
from eaa_core.tool.optimization import BayesianOptimizationTool
12+
from eaa_core.util import to_tensor
1213

1314
logger = logging.getLogger(__name__)
1415

1516

17+
class BayesianOptimizationStoppingCriterion:
    """Abstract interface for stopping rules used by BO task managers.

    Subclasses override ``should_stop`` and set ``reason`` to a short
    machine-readable tag recorded by the task manager when the rule fires.
    """

    # Tag stored on the task manager when this criterion triggers;
    # subclasses assign a meaningful value.
    reason: str = ""

    def should_stop(self, task_manager: "BayesianOptimizationTaskManager") -> bool:
        """Decide whether the optimization loop should terminate.

        Parameters
        ----------
        task_manager : BayesianOptimizationTaskManager
            The running task manager whose state is inspected.

        Returns
        -------
        bool
            True if the loop should stop.
        """
        raise NotImplementedError
26+
27+
class MaxObservationsStoppingCriterion(BayesianOptimizationStoppingCriterion):
    """Stopping rule that caps the total number of collected observations."""

    def __init__(self, max_observations: int) -> None:
        """Create the criterion.

        Parameters
        ----------
        max_observations : int
            Upper bound on the number of observed x/y pairs.
        """
        self.max_observations = max_observations
        self.reason = "max_observations_reached"

    def should_stop(self, task_manager: "BayesianOptimizationTaskManager") -> bool:
        """Return True once the observation count reaches the cap."""
        observed = task_manager.bayesian_optimization_tool.xs_untransformed
        return int(observed.shape[0]) >= self.max_observations
1647
class BayesianOptimizationTaskManager(BaseTaskManager):
17-
48+
"""Task manager that runs an outer-loop Bayesian optimization workflow."""
49+
1850
def __init__(
1951
self,
2052
llm_config: LLMConfig = None,
@@ -23,10 +55,13 @@ def __init__(
2355
additional_tools: list[BaseTool] = (),
2456
initial_points: Optional[torch.Tensor] = None,
2557
n_initial_points: int = 20,
26-
objective_function: Callable = None,
58+
objective_function: BaseTool | Callable | None = None,
59+
objective_function_method: str | None = None,
60+
stopping_criteria: Optional[list[BayesianOptimizationStoppingCriterion]] = None,
2761
session_db_path: Optional[str] = "session.sqlite",
2862
build: bool = True,
29-
*args, **kwargs
63+
*args,
64+
**kwargs,
3065
) -> None:
3166
"""Bayesian optimization task manager.
3267
@@ -37,27 +72,28 @@ def __init__(
3772
memory_config : MemoryManagerConfig, optional
3873
Memory configuration forwarded to the agent.
3974
additional_tools : list[BaseTool], optional
40-
A list of tools for the agent (not including the
41-
`BayesianOptimizationTool`).
75+
Additional tools exposed to the task manager, excluding the
76+
Bayesian optimization tool and the objective tool.
4277
bayesian_optimization_tool : BayesianOptimizationTool
4378
The Bayesian optimization tool to use.
4479
initial_points : torch.Tensor, optional
45-
A (n_points, n_features) tensor giving the initial points where
46-
the objective function should be evaluated to initialize the
47-
Gaussian process model. If None, random initial points will be
48-
generated.
80+
Initial measurement points with shape ``(n_points, n_features)``.
81+
When omitted, random points are drawn from the optimization bounds.
4982
n_initial_points : int, optional
50-
The number of initial points to generate if `initial_points` is None.
51-
objective_function : Callable
52-
The objective function to be maximized. This function should take
53-
a single argument, which is a (n_points, n_features) tensor of
54-
points to evaluate the objective function at. It should return
55-
a (n_points, n_objectives) tensor of objective function values.
83+
Number of random initial points to draw when ``initial_points`` is not
84+
provided.
85+
objective_function : BaseTool | Callable
86+
Callable or tool used to evaluate points. The returned observations
87+
must have shape ``(n_samples, n_observations)``.
88+
objective_function_method : str | None, optional
89+
Method name to call when ``objective_function`` is a tool. When
90+
omitted, ``measure`` is preferred, then ``evaluate``, then a single
91+
exposed tool method.
92+
stopping_criteria : list[BayesianOptimizationStoppingCriterion], optional
93+
Additional stopping criteria checked after initialization and each
94+
update.
5695
session_db_path : Optional[str]
57-
If provided, the entire chat history will be stored in
58-
a SQLite database at the given path. This is essential
59-
if you want to use the WebUI, which polls the database
60-
for new messages.
96+
Optional SQLite path used by the shared chat/task-manager session.
6197
build : bool, optional
6298
Whether to build the internal state of the task manager.
6399
"""
@@ -68,59 +104,149 @@ def __init__(
68104
)
69105
if objective_function is None:
70106
raise ValueError("`objective_function` is required.")
71-
107+
72108
self.bayesian_optimization_tool = bayesian_optimization_tool
73-
74-
for tool in additional_tools:
109+
110+
tools = list(additional_tools)
111+
if isinstance(objective_function, BaseTool):
112+
tools.append(objective_function)
113+
for tool in tools:
75114
if isinstance(tool, BayesianOptimizationTool):
76115
raise ValueError(
77116
"`BayesianOptimizationTool` should not be included in `tools`. "
78117
"Instead, pass it to `bayesian_optimization_tool`."
79118
)
80-
119+
81120
self.objective_function = objective_function
82-
121+
self.objective_function_method = objective_function_method
83122
self.initial_points = initial_points
84123
self.n_initial_points = n_initial_points
85-
124+
self.stopping_criteria = list(stopping_criteria or [])
125+
self.stop_reason: str | None = None
126+
86127
super().__init__(
87128
llm_config=llm_config,
88129
memory_config=memory_config,
89-
tools=additional_tools,
130+
tools=tools,
90131
session_db_path=session_db_path,
91132
build=build,
92-
*args, **kwargs
133+
*args,
134+
**kwargs,
93135
)
94-
136+
95137
def run(
96-
self,
97-
n_iterations: int = 50,
98-
*args, **kwargs
138+
self,
139+
n_iterations: int = 50,
140+
*args,
141+
**kwargs,
99142
) -> None:
100-
"""Run Bayesian optimization. Upon the second or later call,
101-
this function continues from the last iteration.
102-
143+
"""Run Bayesian optimization.
144+
145+
When the task manager already contains observations, the optimization
146+
continues from the current state.
147+
103148
Parameters
104149
----------
105150
n_iterations : int, optional
106-
The number of iterations to run.
151+
Maximum number of BO iterations to execute in this call.
107152
"""
108153
if len(self.bayesian_optimization_tool.xs_untransformed) == 0:
109-
if self.initial_points is None:
110-
xs_init = self.bayesian_optimization_tool.get_random_initial_points(n_points=self.n_initial_points)
111-
else:
112-
xs_init = self.initial_points
113-
logger.info(f"Initial points (shape: {xs_init.shape}):\n{xs_init}")
114-
115-
for x in xs_init:
116-
x = x[None, :]
117-
y = self.objective_function(x)
118-
self.bayesian_optimization_tool.update(x, y)
119-
self.bayesian_optimization_tool.build()
120-
121-
for i in range(n_iterations):
154+
self.collect_initial_observations()
155+
156+
if self.should_stop():
157+
return
158+
159+
for _ in range(n_iterations):
122160
candidates = self.bayesian_optimization_tool.suggest(n_suggestions=1)
123-
logger.info(f"Candidate suggested: {candidates[0]}")
124-
y = self.objective_function(candidates)
125-
logger.info(f"Objective function value: {y.item()}")
161+
logger.info("Candidate suggested: %s", candidates[0])
162+
y = self.evaluate_objective(candidates)
163+
logger.info("Objective function value: %s", y.reshape(-1))
126164
self.bayesian_optimization_tool.update(candidates, y)
165+
self.configure_bayesian_optimization()
166+
if self.should_stop():
167+
break
168+
169+
def collect_initial_observations(self) -> None:
170+
"""Collect initial observations and build the GP model.
171+
172+
Initial ``x`` points are expected to have shape ``(n_points, n_features)``
173+
and the evaluated observations must have shape
174+
``(n_points, n_observations)``.
175+
"""
176+
if self.initial_points is None:
177+
xs_init = self.bayesian_optimization_tool.get_random_initial_points(
178+
n_points=self.n_initial_points
179+
)
180+
else:
181+
xs_init = to_tensor(self.initial_points)
182+
logger.info("Initial points (shape: %s):\n%s", xs_init.shape, xs_init)
183+
184+
for x in xs_init:
185+
x = x[None, :]
186+
y = self.evaluate_objective(x)
187+
self.bayesian_optimization_tool.update(x, y)
188+
self.bayesian_optimization_tool.build()
189+
self.configure_bayesian_optimization()
190+
191+
def evaluate_objective(self, x: torch.Tensor) -> torch.Tensor:
192+
"""Evaluate the objective function or objective tool at input points.
193+
194+
Parameters
195+
----------
196+
x : torch.Tensor
197+
Candidate locations with shape ``(n_samples, n_features)``.
198+
199+
Returns
200+
-------
201+
torch.Tensor
202+
Objective values with shape ``(n_samples, n_observations)``.
203+
"""
204+
if isinstance(self.objective_function, BaseTool):
205+
objective_callable = self.resolve_objective_tool_callable()
206+
y = objective_callable(x)
207+
else:
208+
y = self.objective_function(x)
209+
y = to_tensor(y)
210+
if not isinstance(y, torch.Tensor):
211+
y = torch.as_tensor(y)
212+
if y.ndim == 1:
213+
y = y[:, None]
214+
self.bayesian_optimization_tool.check_y_data(y)
215+
return y
216+
217+
def resolve_objective_tool_callable(self) -> Callable:
218+
"""Resolve the callable used to evaluate the objective tool."""
219+
tool = self.objective_function
220+
if not isinstance(tool, BaseTool):
221+
raise TypeError("`objective_function` is not a tool instance.")
222+
223+
candidate_names: list[str]
224+
if self.objective_function_method is not None:
225+
candidate_names = [self.objective_function_method]
226+
else:
227+
candidate_names = ["measure", "evaluate"]
228+
229+
for method_name in candidate_names:
230+
method = getattr(tool, method_name, None)
231+
if callable(method):
232+
return method
233+
234+
if len(tool.exposed_tools) == 1:
235+
return tool.exposed_tools[0].function
236+
raise ValueError(
237+
"Could not resolve a tool method for `objective_function`. "
238+
"Pass `objective_function_method` explicitly."
239+
)
240+
241+
def configure_bayesian_optimization(self) -> None:
242+
"""Hook for subclasses to update acquisition or stopping state."""
243+
return None
244+
245+
def should_stop(self) -> bool:
246+
"""Return whether any configured stopping criterion has triggered."""
247+
for criterion in self.stopping_criteria:
248+
if criterion.should_stop(self):
249+
self.stop_reason = criterion.reason
250+
logger.info("Stopping criterion triggered: %s", self.stop_reason)
251+
return True
252+
return False

0 commit comments

Comments
 (0)