Skip to content

Commit e6bf117

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Fix minor bugs across api, benchmark, generators, metrics, runners, and other modules
Summary: Fix bugs in smaller modules with 1-2 files each: - Client JSON snapshot crash when no GenerationStrategy exists: `"generation_strategy" in snapshot` is True for the key with a None value, causing a decode crash. Changed to `snapshot.get("generation_strategy") is not None` (api/client.py:1351) - Benchmark ternary discarding input in both branches: `{} if target_fidelity_and_task is None else {}` always produces an empty dict regardless of input (benchmark/benchmark.py:701) - String `"False"` default (truthy) preventing trials from being marked CANDIDATE: `t.run_metadata.get(STARTED_KEY, "False")` returns the string `"False"` which is truthy, so `not "False"` is always False (runners/map_replay.py:44) - Operator precedence making `noisy=False` set all means to 0.0: `item["mean"] + noise if noisy else 0.0` parses as `(item["mean"] + noise) if noisy else 0.0`, replacing the entire mean with 0.0 when `noisy=False` (metrics/noisy_function_map.py:86, metrics/branin_map.py:119) - Misleading "total number of trials" vs actual trial index (global_stopping/strategies/improvement.py:132) + test update - Typos: "trial rial" → "trial" (fb/tutorials/auto_tuning.py), "prefererd" → "preferred" and stray `f` in f-string (fb/utils/preference/preference.py), "acquistion" → "acquisition" (generators/torch/botorch_modular/utils.py + test) Differential Revision: D92879402
1 parent e5c85eb commit e6bf117

10 files changed

Lines changed: 15 additions & 31 deletions

File tree

ax/api/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1350,7 +1350,7 @@ def _from_json_snapshot(
13501350
decoder_registry=decoder_registry,
13511351
class_decoder_registry=class_decoder_registry,
13521352
)
1353-
if "generation_strategy" in snapshot
1353+
if snapshot.get("generation_strategy") is not None
13541354
else None
13551355
)
13561356

ax/benchmark/benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,9 @@ def compute_baseline_value_from_sobol(
698698
n_repeats: Number of times to repeat the five Sobol trials.
699699
"""
700700
method = get_sobol_benchmark_method()
701-
target_fidelity_and_task = {} if target_fidelity_and_task is None else {}
701+
target_fidelity_and_task = (
702+
{} if target_fidelity_and_task is None else target_fidelity_and_task
703+
)
702704

703705
# set up a dummy problem so we can use `benchmark_replication`
704706
# MOO problems are always higher-is-better because they use hypervolume

ax/early_stopping/strategies/base.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Iterable, Sequence
11+
from collections.abc import Iterable
1212
from logging import Logger
1313
from typing import cast
1414

@@ -665,19 +665,3 @@ def _lookup_and_validate_data(
665665
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
666666
map_data = Data(df=full_df)
667667
return map_data
668-
669-
def get_training_data(
670-
self,
671-
experiment: Experiment,
672-
map_data: Data,
673-
max_training_size: int | None = None,
674-
outcomes: Sequence[str] | None = None,
675-
parameters: list[str] | None = None,
676-
) -> None:
677-
# Deprecated in Ax 1.1.0, so should be removed in Ax 1.2.0+.
678-
raise DeprecationWarning(
679-
"`ModelBasedEarlyStoppingStrategy.get_training_data` is deprecated. "
680-
"Subclasses should either extract the training data manually, "
681-
"or rely on the fitted surrogates available in the current generation "
682-
"node that is passed into `should_stop_trials_early`."
683-
)

ax/generators/torch/botorch_modular/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,7 @@ def construct_acquisition_and_optimizer_options(
466466
if len(botorch_acqf_classes_with_options) > 1:
467467
warnings.warn(
468468
message="botorch_acqf_options are being ignored, due to using "
469-
"MultiAcquisition. Specify options for each acquistion function"
469+
"MultiAcquisition. Specify options for each acquisition function "
470470
"via botorch_acqf_classes_with_options.",
471471
category=AxWarning,
472472
stacklevel=4,

ax/generators/torch/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def test_construct_acquisition_and_optimizer_options(self) -> None:
465465
self.assertEqual(
466466
str(warning.message),
467467
"botorch_acqf_options are being ignored, due to using "
468-
"MultiAcquisition. Specify options for each acquistion function"
468+
"MultiAcquisition. Specify options for each acquisition function "
469469
"via botorch_acqf_classes_with_options.",
470470
)
471471

ax/global_stopping/strategies/improvement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def _should_stop_optimization(
130130
trial_to_check = max_completed_trial
131131
elif trial_to_check > max_completed_trial:
132132
raise ValueError(
133-
"trial_to_check is larger than the total number of "
134-
f"trials (={max_completed_trial})."
133+
"trial_to_check is larger than the maximum completed "
134+
f"trial index (={max_completed_trial})."
135135
)
136136

137137
# Only counting the trials up to trial_to_check.

ax/global_stopping/tests/test_strategies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_base_cases(self) -> None:
8585
# Should raise ValueError if trying to check an invalid trial
8686
with self.assertRaisesRegex(
8787
ValueError,
88-
r"trial_to_check is larger than the total number of trials \(=4\).",
88+
r"trial_to_check is larger than the maximum completed trial index \(=4\).",
8989
):
9090
stop, message = gss.should_stop_optimization(
9191
experiment=exp, trial_to_check=5

ax/metrics/branin_map.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ def fetch_trial_data(
117117
"sem": self.noise_sd if noisy else 0.0,
118118
"trial_index": trial.index,
119119
"mean": [
120-
item["mean"] + self.noise_sd * np.random.randn()
121-
if noisy
122-
else 0.0
120+
item["mean"]
121+
+ (self.noise_sd * np.random.randn() if noisy else 0.0)
123122
for item in res
124123
],
125124
"metric_signature": self.signature,

ax/metrics/noisy_function_map.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,8 @@ def fetch_trial_data(
8484
"sem": self.noise_sd if noisy else 0.0,
8585
"trial_index": trial.index,
8686
"mean": [
87-
item["mean"] + self.noise_sd * np.random.randn()
88-
if noisy
89-
else 0.0
87+
item["mean"]
88+
+ (self.noise_sd * np.random.randn() if noisy else 0.0)
9089
for item in res
9190
],
9291
"metric_signature": self.signature,

ax/runners/map_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def poll_trial_status(
4141
# depending on whether or not there is more data available,
4242
# mark it either RUNNING or COMPLETED.
4343
for t in trials:
44-
if not t.run_metadata.get(STARTED_KEY, "False"):
44+
if not t.run_metadata.get(STARTED_KEY, False):
4545
result[TrialStatus.CANDIDATE].add(t.index)
4646
elif not self.replay_metric.has_trial_data(t.index):
4747
result[TrialStatus.ABANDONED].add(t.index)

0 commit comments

Comments
 (0)