Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions ax/adapter/adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
get_weighted_mc_objective_and_objective_thresholds,
pareto_frontier_evaluator,
)
from ax.utils.common.constants import Keys
from ax.utils.common.hash_utils import get_current_lilo_hash
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import (
assert_is_instance_of_tuple,
Expand Down Expand Up @@ -1270,6 +1272,57 @@ def process_contextual_datasets(
return contextual_datasets


def _get_fresh_pairwise_trial_indices(
experiment: Experiment,
) -> set[int] | None:
"""Return trial indices whose pairwise labels match current experiment state.

LILO (Language-in-the-Loop) trials are stamped with a hash of the
experiment state (metric data + LLM messages) at labeling time. When
the experiment state changes (new data, updated LLM messages), old labels
become stale and should be excluded from PairwiseGP model fitting.

Design note: we intentionally compare each trial's stamped hash against
the *current* experiment state rather than the most-recently-stamped LILO
hash. This is because the LLM prompt includes a full experiment summary
(via ``get_llm_messages_with_experiment_summary``), so any change to
input metric data -- even from non-LILO trials -- alters the context
under which labels would be produced and warrants relabeling.

Returns:
A set of trial indices whose LILO input hash matches the current
experiment state, or ``None`` if hash-based filtering is not
applicable (e.g., no trials have a LILO input hash -- the experiment
uses BOPE or another non-LILO pairwise workflow).
"""
# Collect trials that have been stamped with a LILO input hash.
stamped_trials = {
idx: trial
for idx, trial in experiment.trials.items()
if Keys.LILO_INPUT_HASH in trial._properties
}
if not stamped_trials:
# Not a LILO experiment -- no filtering needed.
return None

current_hash = get_current_lilo_hash(experiment)
if current_hash is None:
return None

fresh_indices: set[int] = set()
for idx, trial in experiment.trials.items():
trial_hash = trial._properties.get(Keys.LILO_INPUT_HASH)
if trial_hash is None:
# Trial without hash (non-LILO trial) -- always include.
fresh_indices.add(idx)
elif trial_hash == current_hash:
# Hash matches -- labels are fresh.
fresh_indices.add(idx)
# else: stale hash -- excluded.

return fresh_indices


def prep_pairwise_data(
X: Tensor,
Y: Tensor,
Expand Down
85 changes: 85 additions & 0 deletions ax/adapter/tests/test_adapter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@


import numpy as np
import pandas as pd
import torch
from ax.adapter.adapter_utils import (
_get_adapter_training_data,
_get_fresh_pairwise_trial_indices,
arm_to_np_array,
can_map_to_binary,
extract_objective_weight_matrix,
Expand All @@ -25,6 +27,9 @@
from ax.adapter.torch import TorchAdapter
from ax.adapter.transforms.choice_encode import ChoiceToNumericChoice
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.derived_metric import DerivedMetric
from ax.core.experiment import Experiment
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
Expand All @@ -34,6 +39,8 @@
from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.utils.common.constants import Keys
from ax.utils.common.hash_utils import compute_lilo_input_hash
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_experiment_with_observations,
Expand Down Expand Up @@ -555,3 +562,81 @@ def test_extract_objective_weight_matrix(self) -> None:
)
result = extract_objective_weight_matrix(multi, outcomes)
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])

def test_get_fresh_pairwise_trial_indices(self) -> None:
"""Verify _get_fresh_pairwise_trial_indices hash-based filtering."""
search_space = get_search_space_for_range_values()
exp = Experiment(name="test", search_space=search_space)

# Register a DerivedMetric with pairwise name so the function can
# look up input_metric_names.
pairwise_metric = DerivedMetric(
name=Keys.PAIRWISE_PREFERENCE_QUERY.value,
input_metric_names=["latency"],
)
exp.add_tracking_metric(pairwise_metric)

# Helper to create trial data.
def _attach(
trial_index: int, arms: dict[str, float], exp: Experiment = exp
) -> None:
rows = [
{
"trial_index": trial_index,
"arm_name": name,
"metric_name": "latency",
"metric_signature": "latency",
"mean": val,
"sem": 0.1,
}
for name, val in arms.items()
]
exp.attach_data(Data(df=pd.DataFrame(rows)))

# Create two trials with data.
for i in range(2):
trial = exp.new_batch_trial()
trial.add_arm(Arm(name=f"{i}_0", parameters={"x": float(i)}))
trial.mark_running(no_runner_required=True)
trial.mark_completed()
_attach(i, {f"{i}_0": float(i + 1)})

with self.subTest("no_hashes_returns_none"):
# No trials have LILO_INPUT_HASH -- not a LILO experiment.
result = _get_fresh_pairwise_trial_indices(exp)
self.assertIsNone(result)

# Stamp trial 0 with the current hash.
current_hash = compute_lilo_input_hash(exp, ["latency"])
exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash

with self.subTest("fresh_hash_included"):
result = _get_fresh_pairwise_trial_indices(exp)
assert result is not None
self.assertIn(0, result)
# Trial 1 has no hash -- always included.
self.assertIn(1, result)

# Stamp trial 1 with a stale hash.
exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash_value"

with self.subTest("stale_hash_excluded"):
result = _get_fresh_pairwise_trial_indices(exp)
assert result is not None
self.assertIn(0, result)
self.assertNotIn(1, result)

with self.subTest("all_stale"):
# Make both hashes stale by adding new data.
trial2 = exp.new_batch_trial()
trial2.add_arm(Arm(name="2_0", parameters={"x": 10.0}))
trial2.mark_running(no_runner_required=True)
trial2.mark_completed()
_attach(2, {"2_0": 999.0})
# Now both trial 0 and trial 1 have stale hashes.
result = _get_fresh_pairwise_trial_indices(exp)
assert result is not None
# Trial 0 and 1 are stale, trial 2 has no hash -- included.
self.assertNotIn(0, result)
self.assertNotIn(1, result)
self.assertIn(2, result)
21 changes: 21 additions & 0 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy.typing as npt
import torch
from ax.adapter.adapter_utils import (
_get_fresh_pairwise_trial_indices,
arm_to_np_array,
array_to_observation_data,
extract_objective_thresholds,
Expand Down Expand Up @@ -468,6 +469,26 @@ def _convert_experiment_data(
Yvar = torch.from_numpy(sem).double().square().view(-1, 1)
group_indices = torch.from_numpy(trial_indices_np[to_keep])
if outcome == Keys.PAIRWISE_PREFERENCE_QUERY.value:
# Filter out stale LILO trials whose input hash no longer
# matches the current experiment state.
fresh_indices = _get_fresh_pairwise_trial_indices(
experiment=self._experiment,
)
if fresh_indices is not None:
fresh_mask = torch.tensor(
[int(gi.item()) in fresh_indices for gi in group_indices],
dtype=torch.bool,
)
X = X[fresh_mask]
Y = Y[fresh_mask]
group_indices = group_indices[fresh_mask]
# Narrow the NaN-filtered to_keep mask further so
# candidate_metadata stays aligned.
to_keep_indices = np.where(to_keep)[0]
fresh_mask_np = fresh_mask.numpy()
to_keep = np.zeros_like(to_keep)
to_keep[to_keep_indices[fresh_mask_np]] = True

dataset = prep_pairwise_data(
X=X.to(device=self.device),
Y=Y.to(dtype=torch.long, device=self.device),
Expand Down
Loading
Loading