Skip to content

Commit 3aa7602

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Fix minor bugs across api, benchmark, generators, metrics, runners, and other modules (#4899)
Summary: Fix bugs in smaller modules with 1-2 files each: - Client JSON snapshot crash when no GenerationStrategy exists: `"generation_strategy" in snapshot` is True for the key with a None value, causing a decode crash. Changed to `snapshot.get("generation_strategy") is not None` (api/client.py:1351) - Benchmark ternary discarding input in both branches: `{} if target_fidelity_and_task is None else {}` always produces an empty dict regardless of input (benchmark/benchmark.py:701) - String `"False"` default (truthy) preventing trials from being marked CANDIDATE: `t.run_metadata.get(STARTED_KEY, "False")` returns the string `"False"` which is truthy, so `not "False"` is always False (runners/map_replay.py:44) - Operator precedence making `noisy=False` set all means to 0.0: `item["mean"] + noise if noisy else 0.0` parses as `(item["mean"] + noise) if noisy else 0.0`, replacing the entire mean with 0.0 when `noisy=False` (metrics/noisy_function_map.py:86, metrics/branin_map.py:119) - Misleading "total number of trials" vs actual trial index (global_stopping/strategies/improvement.py:132) + test update - Typos: "trial rial" → "trial" (fb/tutorials/auto_tuning.py), "prefererd" → "preferred" and stray `f` in f-string (fb/utils/preference/preference.py), "acquistion" → "acquisition" (generators/torch/botorch_modular/utils.py + test) Reviewed By: ItsMrLin Differential Revision: D92879402
1 parent ec6b673 commit 3aa7602

10 files changed

Lines changed: 15 additions & 31 deletions

File tree

ax/api/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1338,7 +1338,7 @@ def _from_json_snapshot(
13381338
decoder_registry=decoder_registry,
13391339
class_decoder_registry=class_decoder_registry,
13401340
)
1341-
if "generation_strategy" in snapshot
1341+
if snapshot.get("generation_strategy") is not None
13421342
else None
13431343
)
13441344

ax/benchmark/benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,9 @@ def compute_baseline_value_from_sobol(
701701
n_repeats: Number of times to repeat the five Sobol trials.
702702
"""
703703
method = get_sobol_benchmark_method()
704-
target_fidelity_and_task = {} if target_fidelity_and_task is None else {}
704+
target_fidelity_and_task = (
705+
{} if target_fidelity_and_task is None else target_fidelity_and_task
706+
)
705707

706708
# set up a dummy problem so we can use `benchmark_replication`
707709
# MOO problems are always higher-is-better because they use hypervolume

ax/early_stopping/strategies/base.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010
from abc import ABC, abstractmethod
11-
from collections.abc import Iterable, Sequence
11+
from collections.abc import Iterable
1212
from logging import Logger
1313
from typing import cast
1414

@@ -665,19 +665,3 @@ def _lookup_and_validate_data(
665665
full_df = full_df[full_df[MAP_KEY] >= self.min_progression_modeling]
666666
map_data = Data(df=full_df)
667667
return map_data
668-
669-
def get_training_data(
670-
self,
671-
experiment: Experiment,
672-
map_data: Data,
673-
max_training_size: int | None = None,
674-
outcomes: Sequence[str] | None = None,
675-
parameters: list[str] | None = None,
676-
) -> None:
677-
# Deprecated in Ax 1.1.0, so should be removed in Ax 1.2.0+.
678-
raise DeprecationWarning(
679-
"`ModelBasedEarlyStoppingStrategy.get_training_data` is deprecated. "
680-
"Subclasses should either extract the training data manually, "
681-
"or rely on the fitted surrogates available in the current generation "
682-
"node that is passed into `should_stop_trials_early`."
683-
)

ax/generators/torch/botorch_modular/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def construct_acquisition_and_optimizer_options(
539539
if len(botorch_acqf_classes_with_options) > 1:
540540
warnings.warn(
541541
message="botorch_acqf_options are being ignored, due to using "
542-
"MultiAcquisition. Specify options for each acquistion function"
542+
"MultiAcquisition. Specify options for each acquisition function "
543543
"via botorch_acqf_classes_with_options.",
544544
category=AxWarning,
545545
stacklevel=4,

ax/generators/torch/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ def test_construct_acquisition_and_optimizer_options(self) -> None:
693693
self.assertEqual(
694694
str(warning.message),
695695
"botorch_acqf_options are being ignored, due to using "
696-
"MultiAcquisition. Specify options for each acquistion function"
696+
"MultiAcquisition. Specify options for each acquisition function "
697697
"via botorch_acqf_classes_with_options.",
698698
)
699699

ax/global_stopping/strategies/improvement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def _should_stop_optimization(
130130
trial_to_check = max_completed_trial
131131
elif trial_to_check > max_completed_trial:
132132
raise ValueError(
133-
"trial_to_check is larger than the total number of "
134-
f"trials (={max_completed_trial})."
133+
"trial_to_check is larger than the maximum completed "
134+
f"trial index (={max_completed_trial})."
135135
)
136136

137137
# Only counting the trials up to trial_to_check.

ax/global_stopping/tests/test_strategies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def test_base_cases(self) -> None:
8585
# Should raise ValueError if trying to check an invalid trial
8686
with self.assertRaisesRegex(
8787
ValueError,
88-
r"trial_to_check is larger than the total number of trials \(=4\).",
88+
r"trial_to_check is larger than the maximum completed trial index \(=4\).",
8989
):
9090
stop, message = gss.should_stop_optimization(
9191
experiment=exp, trial_to_check=5

ax/metrics/branin_map.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ def fetch_trial_data(
117117
"sem": self.noise_sd if noisy else 0.0,
118118
"trial_index": trial.index,
119119
"mean": [
120-
item["mean"] + self.noise_sd * np.random.randn()
121-
if noisy
122-
else 0.0
120+
item["mean"]
121+
+ (self.noise_sd * np.random.randn() if noisy else 0.0)
123122
for item in res
124123
],
125124
"metric_signature": self.signature,

ax/metrics/noisy_function_map.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,8 @@ def fetch_trial_data(
8484
"sem": self.noise_sd if noisy else 0.0,
8585
"trial_index": trial.index,
8686
"mean": [
87-
item["mean"] + self.noise_sd * np.random.randn()
88-
if noisy
89-
else 0.0
87+
item["mean"]
88+
+ (self.noise_sd * np.random.randn() if noisy else 0.0)
9089
for item in res
9190
],
9291
"metric_signature": self.signature,

ax/runners/map_replay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def poll_trial_status(
4141
# depending on whether or not there is more data available,
4242
# mark it either RUNNING or COMPLETED.
4343
for t in trials:
44-
if not t.run_metadata.get(STARTED_KEY, "False"):
44+
if not t.run_metadata.get(STARTED_KEY, False):
4545
result[TrialStatus.CANDIDATE].add(t.index)
4646
elif not self.replay_metric.has_trial_data(t.index):
4747
result[TrialStatus.ABANDONED].add(t.index)

0 commit comments

Comments
 (0)