
Commit 2820366

ltiao authored and facebook-github-bot committed
Refactor get_trace and extend get_opt_trace_by_steps to MOO/constrained (#4884)
Summary:
We have a method `get_opt_trace_by_steps` that was used extensively during our Ax 1.0 benchmarking campaign. It duplicates the basic logic of `get_trace` but differs in that it operates over `(trial_index, MAP_KEY)` pairs and respects ordering by timestamp (i.e., chronological order). However, it is limited to single-objective, unconstrained problems, and our current needs (multi-objective and/or constrained) have outgrown it.

We reconcile the two by extracting three core building blocks of `get_trace`:

1. `_pivot_data_with_feasibility`: Pivots data to wide format, with feasibility information and metric-completeness checks.
2. `_compute_trace_values`: Computes per-observation trace values (hypervolume for MOO, objective value for SOO), with cumulative-best support.
3. `_aggregate_and_cumulate_trace`: Aggregates values by groups and computes the cumulative best across groups.

These are implemented in a more general way that respects arbitrary groupings and orderings. We then refactor `get_trace` (and its helpers `_prepare_data_for_trace` and `get_trace_by_arm_pull_from_data`) to use these building blocks, and leverage them in `get_opt_trace_by_steps` to extend its support to multi-objective and constrained problems.

Additionally:

- The timestamp-based sorting in `get_opt_trace_by_steps` is preserved, which is critical for correct cumulative hypervolume computation (without it, observations would be processed in `(trial_index, arm_name, MAP_KEY)` order instead of chronological order).
- Tests are updated to replace the `NotImplementedError` checks with actual MOO and constrained test cases that verify the correctness of the new functionality.

Reviewed By: dme65

Differential Revision: D79581270
1 parent 71c9895 commit 2820366

4 files changed

Lines changed: 348 additions & 129 deletions
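
Editor's note: for orientation before the diffs, here is a toy, self-contained pandas sketch of the three-step pipeline the summary describes (pivot to wide format, compute per-observation values, aggregate and take the running best). The data, column names, and steps are fabricated placeholders that only mimic the roles of `_pivot_data_with_feasibility`, `_compute_trace_values`, and `_aggregate_and_cumulate_trace`; this is not the Ax implementation.

# Editor's sketch: three-step trace pipeline on fabricated single-objective data.
import pandas as pd

# Long-format data: one row per (trial, step, metric), loosely like `full_df`.
long_df = pd.DataFrame(
    {
        "trial_index": [0, 0, 1, 1],
        "step": [1, 2, 1, 2],
        "metric_name": ["objective"] * 4,
        "mean": [3.0, 2.5, 4.0, 1.0],
    }
)

# Step 1 (role of _pivot_data_with_feasibility): pivot to wide format, one column per metric.
wide = long_df.pivot_table(
    index=["trial_index", "step"], columns="metric_name", values="mean"
).reset_index()

# Step 2 (role of _compute_trace_values): per-observation value; here simply the
# objective (minimized), whereas for MOO it would be a hypervolume.
wide["value"] = wide["objective"]

# Step 3 (role of _aggregate_and_cumulate_trace): aggregate within each
# (trial_index, step) group, then take the running best across groups in order.
trace = wide.groupby(["trial_index", "step"], sort=False)["value"].min().cummin()
print(trace.to_numpy())  # running best: 3.0, 2.5, 2.5, 1.0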

ax/benchmark/benchmark.py

Lines changed: 72 additions & 51 deletions
@@ -54,9 +54,13 @@
 from ax.generation_strategy.generation_strategy import GenerationStrategy
 from ax.orchestration.orchestrator import Orchestrator
 from ax.service.utils.best_point import (
+    _aggregate_and_cumulate_trace,
+    _compute_trace_values,
+    _pivot_data_with_feasibility,
     _prepare_data_for_trace,
     derelativize_opt_config,
     get_trace,
+    is_row_feasible,
 )
 from ax.service.utils.best_point_mixin import BestPointMixin
 from ax.service.utils.orchestrator_options import OrchestratorOptions, TrialType
@@ -794,64 +798,80 @@ def get_opt_trace_by_steps(experiment: Experiment) -> npt.NDArray:
     that is in terms of steps, with one element added each time a step
     completes.
 
+    Supports single-objective, multi-objective, and constrained problems.
+    For multi-objective problems, the trace is in terms of hypervolume.
+
     Args:
         experiment: An experiment produced by `benchmark_replication`; it must
             have `BenchmarkTrialMetadata` (as produced by `BenchmarkRunner`) for
             each trial, and its data must have a "step" column.
     """
     optimization_config = none_throws(experiment.optimization_config)
+    full_df = experiment.lookup_data().full_df
 
-    if optimization_config.is_moo_problem:
-        raise NotImplementedError(
-            "Cumulative epochs only supported for single objective problems."
-        )
-    if len(optimization_config.outcome_constraints) > 0:
-        raise NotImplementedError(
-            "Cumulative epochs not supported for problems with outcome constraints."
-        )
+    full_df["row_feasible"] = is_row_feasible(
+        df=full_df,
+        optimization_config=optimization_config,
+        # For the sake of this function, we only care about feasible trials. The
+        # distinction between infeasible and undetermined is not important.
+        undetermined_value=False,
+    )
 
-    objective_name = optimization_config.objective.metric.name
-    data = experiment.lookup_data()
-    full_df = data.full_df
+    # Pivot to wide format with feasibility
+    df_wide = _pivot_data_with_feasibility(
+        df=full_df,
+        index=["trial_index", "arm_name", MAP_KEY],
+        optimization_config=optimization_config,
+    )
 
-    # Has timestamps; needs to be merged with full_df because it contains
-    # data on epochs that didn't actually run due to early stopping, and we need
-    # to know which actually ran
-    def _get_df(trial: Trial) -> pd.DataFrame:
+    def _get_timestamps(experiment: Experiment) -> pd.Series:
         """
-        Get the (virtual) time each epoch finished at.
+        Get the (virtual) time at which each training progression finished.
         """
-        metadata = trial.run_metadata["benchmark_metadata"]
-        backend_simulator = none_throws(metadata.backend_simulator)
-        # Data for the first metric, which is the only metric
-        df = next(iter(metadata.dfs.values()))
-        start_time = backend_simulator.get_sim_trial_by_index(
-            trial.index
-        ).sim_start_time
-        df["time"] = df["virtual runtime"] + start_time
-        return df
-
-    with_timestamps = pd.concat(
-        (
-            _get_df(trial=assert_is_instance(trial, Trial))
-            for trial in experiment.trials.values()
-        ),
-        axis=0,
-        ignore_index=True,
-    )[["trial_index", MAP_KEY, "time"]]
-
-    df = (
-        full_df.loc[
-            full_df["metric_name"] == objective_name,
-            ["trial_index", "arm_name", "mean", MAP_KEY],
-        ]
-        .merge(with_timestamps, how="left")
-        .sort_values("time", ignore_index=True)
+        frames = []
+        for trial in experiment.trials.values():
+            trial = assert_is_instance(trial, Trial)
+            metadata = trial.run_metadata["benchmark_metadata"]
+            backend_simulator = none_throws(metadata.backend_simulator)
+            sim_trial = backend_simulator.get_sim_trial_by_index(
+                trial_index=trial.index
+            )
+            start_time = sim_trial.sim_start_time
+            # timestamps are identical across all metrics, so just use the first one
+            frame = next(iter(metadata.dfs.values())).copy()
+            frame["time"] = frame["virtual runtime"] + start_time
+            frames.append(frame)
+        df = pd.concat(frames, axis=0, ignore_index=True).set_index(
+            ["trial_index", "arm_name", MAP_KEY]
+        )
+        return df["time"]
+
+    # Compute timestamps and join with df_wide *before* cumulative computations.
+    # This is critical because cumulative HV/objective calculations depend on
+    # the temporal ordering of observations.
+    timestamps = _get_timestamps(experiment=experiment)
+
+    # Merge timestamps and sort by time before cumulative computations
+    df_wide = df_wide.join(
+        timestamps, on=["trial_index", "arm_name", MAP_KEY], how="left"
+    ).sort_values(by="time", ascending=True, ignore_index=True)
+
+    # Compute per-evaluation (trial_index, MAP_KEY) cumulative values,
+    # with keep_order=True to preserve ordering by timestamp
+    df_wide["value"], maximize = _compute_trace_values(
+        df_wide=df_wide,
+        optimization_config=optimization_config,
+        use_cumulative_best=True,
     )
-    return (
-        df["mean"].cummin()
-        if optimization_config.objective.minimize
-        else df["mean"].cummax()
+    # Get a value for each (trial_index, arm_name, MAP_KEY) tuple
+    value_by_arm_pull = df_wide[["trial_index", "arm_name", MAP_KEY, "value"]]
+
+    # Aggregate by trial and step, then compute cumulative best
+    return _aggregate_and_cumulate_trace(
+        df=value_by_arm_pull,
+        by=["trial_index", MAP_KEY],
+        maximize=maximize,
+        keep_order=True,
     ).to_numpy()
 
 
@@ -870,15 +890,16 @@ def get_benchmark_result_with_cumulative_steps(
     opt_trace = get_opt_trace_by_steps(experiment=experiment)
     return replace(
         result,
-        optimization_trace=opt_trace,
-        cost_trace=np.arange(1, len(opt_trace) + 1, dtype=int),
+        optimization_trace=opt_trace.tolist(),
+        cost_trace=np.arange(1, len(opt_trace) + 1, dtype=int).tolist(),
         num_trials=list(range(1, len(opt_trace) + 1)),
         # Empty
-        oracle_trace=np.full(len(opt_trace), np.nan),
-        inference_trace=np.full(len(opt_trace), np.nan),
+        oracle_trace=np.full_like(opt_trace, np.nan).tolist(),
+        inference_trace=np.full_like(opt_trace, np.nan).tolist(),
+        is_feasible_trace=None,
         score_trace=compute_score_trace(
             optimization_trace=opt_trace,
             baseline_value=baseline_value,
             optimal_value=optimal_value,
-        ),
+        ).tolist(),
     )
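
Editor's note: as a quick sanity check of the list conversions added above (a toy example on a fabricated trace, not part of the commit), the new expressions behave as follows.

import numpy as np

opt_trace = np.array([0.0, 0.0, 1.0, 1.0, 2.0, 3.0])  # e.g. an output of get_opt_trace_by_steps
cost_trace = np.arange(1, len(opt_trace) + 1, dtype=int).tolist()  # [1, 2, 3, 4, 5, 6]
oracle_trace = np.full_like(opt_trace, np.nan).tolist()  # six NaN placeholders
print(cost_trace, oracle_trace)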

ax/benchmark/testing/benchmark_stubs.py

Lines changed: 47 additions & 7 deletions
@@ -313,15 +313,52 @@ def get_async_benchmark_problem(
     n_steps: int = 1,
     lower_is_better: bool = False,
     report_inference_value_as_trace: bool = False,
+    num_objectives: int = 1,
+    num_constraints: int = 0,
 ) -> BenchmarkProblem:
+    """
+    Create an early-stopping benchmark problem with MAP_KEY data.
+
+    Args:
+        map_data: Whether to use map metrics (required for early stopping).
+        step_runtime_fn: Optional runtime function for steps.
+        n_steps: Number of steps per trial.
+        lower_is_better: Whether lower values are better (for SOO).
+        report_inference_value_as_trace: Whether to report inference trace.
+        num_objectives: Number of objectives (1 for SOO, >1 for MOO).
+        num_constraints: Number of outcome constraints to add.
+
+    Returns:
+        A BenchmarkProblem suitable for early-stopping evaluation.
+    """
     search_space = get_discrete_search_space()
-    test_function = IdentityTestFunction(n_steps=n_steps)
-    optimization_config = get_soo_opt_config(
-        outcome_names=["objective"],
-        use_map_metric=map_data,
-        observe_noise_sd=True,
-        lower_is_better=lower_is_better,
-    )
+
+    # Create outcome names for objectives and constraints
+    objective_names = [f"objective_{i}" for i in range(num_objectives)]
+    constraint_names = [f"constraint_{i}" for i in range(num_constraints)]
+    outcome_names = [*objective_names, *constraint_names]
+
+    test_function = IdentityTestFunction(n_steps=n_steps, outcome_names=outcome_names)
+
+    if num_objectives == 1:
+        # Single-objective: first outcome is objective, rest are constraints
+        optimization_config = get_soo_opt_config(
+            outcome_names=outcome_names,
+            lower_is_better=lower_is_better,
+            observe_noise_sd=True,
+            use_map_metric=map_data,
+        )
+    else:
+        # Multi-objective: pass all outcomes (objectives + constraints)
+        # get_moo_opt_config will use the last num_constraints as constraints
+        optimization_config = get_moo_opt_config(
+            outcome_names=outcome_names,
+            ref_point=[1.0] * num_objectives,
+            num_constraints=num_constraints,
+            lower_is_better=lower_is_better,
+            observe_noise_sd=True,
+            use_map_metric=map_data,
+        )
 
     return BenchmarkProblem(
         name="test",
@@ -331,6 +368,9 @@ def get_async_benchmark_problem(
         num_trials=4,
         baseline_value=19 if lower_is_better else 0,
         optimal_value=0 if lower_is_better else 19,
+        worst_feasible_value=(19 if lower_is_better else 0)
+        if num_constraints > 0
+        else None,
         step_runtime_function=step_runtime_fn,
         report_inference_value_as_trace=report_inference_value_as_trace,
     )
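
Editor's note: the sketch below shows how the extended stub is meant to be called, mirroring the updated tests further down; the import path is inferred from the file location and may differ.

from ax.benchmark.testing.benchmark_stubs import get_async_benchmark_problem

# Two-objective early-stopping problem with five steps per trial; the runtime
# function keeps step completion times distinct so the resulting trace is
# deterministic. Pass num_constraints=1 instead for a constrained SOO problem.
problem = get_async_benchmark_problem(
    map_data=True,
    n_steps=5,
    num_objectives=2,
    step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]),
)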

ax/benchmark/tests/test_benchmark.py

Lines changed: 71 additions & 14 deletions
@@ -1214,28 +1214,85 @@ def test_get_opt_trace_by_cumulative_epochs(self) -> None:
         new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
         self.assertEqual(list(new_opt_trace), [0.0, 0.0, 1.0, 1.0, 2.0, 3.0])
 
-        method = get_sobol_benchmark_method()
-        with self.subTest("MOO"):
-            problem = get_multi_objective_benchmark_problem()
-
+        with self.subTest("Multi-objective"):
+            # Multi-objective problem with step data
+            problem = get_async_benchmark_problem(
+                map_data=True,
+                n_steps=5,
+                num_objectives=2,
+                # Ensure we don't have two finishing at the same time, for
+                # determinism
+                step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]),
+            )
             experiment = self.run_optimization_with_orchestrator(
                 problem=problem, method=method, seed=0
             )
-            with self.assertRaisesRegex(
-                NotImplementedError, "only supported for single objective"
-            ):
-                get_opt_trace_by_steps(experiment=experiment)
+            new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
+            self.assertListEqual(
+                new_opt_trace.tolist(),
+                [
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    1.0,
+                    1.0,
+                    1.0,
+                    1.0,
+                    1.0,
+                    1.0,
+                    4.0,
+                    4.0,
+                    4.0,
+                    4.0,
+                    4.0,
+                    4.0,
+                    4.0,
+                ],
+            )
 
         with self.subTest("Constrained"):
-            problem = get_benchmark_problem("constrained_gramacy_observed_noise")
+            # Constrained problem with step data.
+            problem = get_async_benchmark_problem(
+                map_data=True,
+                n_steps=5,
+                num_constraints=1,
+                # Ensure we don't have two finishing at the same time, for
+                # determinism
+                step_runtime_fn=lambda params: params["x0"] * (1 - 0.01 * params["x0"]),
+            )
             experiment = self.run_optimization_with_orchestrator(
                 problem=problem, method=method, seed=0
             )
-            with self.assertRaisesRegex(
-                NotImplementedError,
-                "not supported for problems with outcome constraints",
-            ):
-                get_opt_trace_by_steps(experiment=experiment)
+            new_opt_trace = get_opt_trace_by_steps(experiment=experiment)
+            self.assertListEqual(
+                new_opt_trace.tolist(),
+                [
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    0.0,
+                    1.0,
+                    1.0,
+                    2.0,
+                    2.0,
+                    2.0,
+                    2.0,
+                    2.0,
+                    2.0,
+                    3.0,
+                    3.0,
+                    3.0,
+                    3.0,
+                    3.0,
+                    3.0,
+                    3.0,
+                ],
+            )
 
     def test_get_benchmark_result_with_cumulative_steps(self) -> None:
         """See test_get_opt_trace_by_cumulative_epochs for more info."""
