Skip to content

Commit 5c6946d

Browse files
mgrange1998 authored and facebook-github-bot committed
Rename get_max_parallelism to get_max_concurrency in AxClient (facebook#4923)
Summary: Renames `AxClient.get_max_parallelism()` to `get_max_concurrency()` and updates internal variable names, comments, and docstrings to use "concurrency" terminology. The old `get_max_parallelism` is preserved as a deprecated stub raising `NotImplementedError`. Also updates `get_recommended_max_parallelism` to point to the new name, and imports `MaxParallelismReachedException` / `MaxGenerationParallelism` under concurrency-named aliases. `get_max_parallelism` is only used directly in ad-hoc notebooks, making this a low-risk rename. Differential Revision: D93771849
1 parent c2f9aec commit 5c6946d

2 files changed

Lines changed: 47 additions & 36 deletions

File tree

ax/service/ax_client.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,14 @@
4242
UnsupportedPlotError,
4343
UserInputError,
4444
)
45-
from ax.exceptions.generation_strategy import MaxParallelismReachedException
45+
from ax.exceptions.generation_strategy import (
46+
MaxParallelismReachedException as MaxConcurrencyReachedException,
47+
)
4648
from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy
4749
from ax.generation_strategy.generation_strategy import GenerationStrategy
48-
from ax.generation_strategy.transition_criterion import MaxGenerationParallelism
50+
from ax.generation_strategy.transition_criterion import (
51+
MaxGenerationParallelism as MaxGenerationConcurrency,
52+
)
4953
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
5054
from ax.global_stopping.strategies.improvement import constraint_satisfaction
5155
from ax.plot.base import AxPlotConfig
@@ -570,7 +574,7 @@ def get_next_trial(
570574
),
571575
ttl_seconds=ttl_seconds,
572576
)
573-
except MaxParallelismReachedException as e:
577+
except MaxConcurrencyReachedException as e:
574578
if self._early_stopping_strategy is not None:
575579
e.message += ( # noqa: B306
576580
" When stopping trials early, make sure to call `stop_trial_early` "
@@ -836,39 +840,39 @@ def get_trials_data_frame(self) -> pd.DataFrame:
836840
"""
837841
return self.experiment.to_df()
838842

839-
def get_max_parallelism(self) -> list[tuple[int, int]]:
840-
"""Retrieves maximum number of trials that can be scheduled in parallel
843+
def get_max_concurrency(self) -> list[tuple[int, int]]:
844+
"""Retrieves maximum number of trials that can be scheduled concurrently
841845
at different stages of optimization.
842846
843847
Some optimization algorithms profit significantly from sequential
844848
optimization (i.e. suggest a few points, get updated with data for them,
845849
repeat, see https://ax.dev/docs/bayesopt.html).
846-
Parallelism setting indicates how many trials should be running simulteneously
850+
Concurrency setting indicates how many trials should be running simultaneously
847851
(generated, but not yet completed with data).
848852
849853
The output of this method is mapping of form
850-
{num_trials -> max_parallelism_setting}, where the max_parallelism_setting
851-
is used for num_trials trials. If max_parallelism_setting is -1, as
852-
many of the trials can be ran in parallel, as necessary. If num_trials
853-
in a tuple is -1, then the corresponding max_parallelism_setting
854+
{num_trials -> max_concurrency_setting}, where the max_concurrency_setting
855+
is used for num_trials trials. If max_concurrency_setting is -1, as
856+
many of the trials can be ran concurrently, as necessary. If num_trials
857+
in a tuple is -1, then the corresponding max_concurrency_setting
854858
should be used for all subsequent trials.
855859
856860
For example, if the returned list is [(5, -1), (12, 6), (-1, 3)],
857-
the schedule could be: run 5 trials with any parallelism, run 6 trials in
858-
parallel twice, run 3 trials in parallel for as long as needed. Here,
861+
the schedule could be: run 5 trials with any concurrency, run 6 trials
862+
concurrently twice, run 3 trials concurrently for as long as needed. Here,
859863
'running' a trial means obtaining a next trial from `AxClient` through
860864
get_next_trials and completing it with data when available.
861865
862866
Returns:
863-
Mapping of form {num_trials -> max_parallelism_setting}.
867+
Mapping of form {num_trials -> max_concurrency_setting}.
864868
"""
865-
parallelism_settings = []
869+
concurrency_settings = []
866870
for node in self.generation_strategy._nodes:
867-
# Extract max_parallelism from MaxGenerationParallelism criterion
868-
max_parallelism = None
871+
# Extract max_concurrency from MaxGenerationConcurrency criterion
872+
max_concurrency = None
869873
for tc in node.transition_criteria:
870-
if isinstance(tc, MaxGenerationParallelism):
871-
max_parallelism = tc.threshold
874+
if isinstance(tc, MaxGenerationConcurrency):
875+
max_concurrency = tc.threshold
872876
break
873877
# Try to get num_trials from the node. If there's no MinTrials
874878
# criterion (unlimited trials), num_trials will raise UserInputError.
@@ -877,13 +881,16 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
877881
num_trials = node.num_trials
878882
except UserInputError:
879883
num_trials = -1
880-
parallelism_settings.append(
884+
concurrency_settings.append(
881885
(
882886
num_trials,
883-
max_parallelism if max_parallelism is not None else num_trials,
887+
max_concurrency if max_concurrency is not None else num_trials,
884888
)
885889
)
886-
return parallelism_settings
890+
return concurrency_settings
891+
892+
def get_max_parallelism(self) -> list[tuple[int, int]]:
893+
raise NotImplementedError("Use `get_max_concurrency` instead.")
887894

888895
def get_optimization_trace(
889896
self, objective_optimum: float | None = None
@@ -1702,8 +1709,8 @@ def __repr__(self) -> str:
17021709
@staticmethod
17031710
def get_recommended_max_parallelism() -> None:
17041711
raise NotImplementedError(
1705-
"Use `get_max_parallelism` instead; parallelism levels are now "
1706-
"enforced in generation strategy, so max parallelism is no longer "
1712+
"Use `get_max_concurrency` instead; concurrency levels are now "
1713+
"enforced in generation strategy, so max concurrency is no longer "
17071714
"just recommended."
17081715
)
17091716

ax/service/tests/test_ax_client.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
UnsupportedPlotError,
5050
UserInputError,
5151
)
52-
from ax.exceptions.generation_strategy import MaxParallelismReachedException
52+
from ax.exceptions.generation_strategy import (
53+
MaxParallelismReachedException as MaxConcurrencyReachedException,
54+
)
5355
from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY
5456
from ax.generation_strategy.generation_strategy import (
5557
GenerationNode,
@@ -58,7 +60,7 @@
5860
)
5961
from ax.generation_strategy.generator_spec import GeneratorSpec
6062
from ax.generation_strategy.transition_criterion import (
61-
MaxGenerationParallelism,
63+
MaxGenerationParallelism as MaxGenerationConcurrency,
6264
MinTrials,
6365
)
6466
from ax.metrics.branin import branin, BraninMetric
@@ -1616,14 +1618,14 @@ def test_keep_generating_without_data(self) -> None:
16161618
self.assertTrue(len(node0_min_trials) > 0)
16171619
self.assertFalse(node0_min_trials[0].block_gen_if_met)
16181620

1619-
# Check that max_parallelism is None by verifying no MaxGenerationParallelism
1621+
# Check that max_concurrency is None by verifying no MaxGenerationConcurrency
16201622
# criterion exists on node 1
1621-
node1_max_parallelism = [
1623+
node1_max_concurrency = [
16221624
tc
16231625
for tc in ax_client.generation_strategy._nodes[1].transition_criteria
1624-
if isinstance(tc, MaxGenerationParallelism)
1626+
if isinstance(tc, MaxGenerationConcurrency)
16251627
]
1626-
self.assertEqual(len(node1_max_parallelism), 0)
1628+
self.assertEqual(len(node1_max_concurrency), 0)
16271629

16281630
for _ in range(10):
16291631
ax_client.get_next_trial()
@@ -1939,17 +1941,17 @@ def test_relative_oc_without_sq(self) -> None:
19391941
def test_recommended_parallelism(self) -> None:
19401942
ax_client = AxClient()
19411943
with self.assertRaisesRegex(AssertionError, "No generation strategy"):
1942-
ax_client.get_max_parallelism()
1944+
ax_client.get_max_concurrency()
19431945
ax_client.create_experiment(
19441946
parameters=[
19451947
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
19461948
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
19471949
],
19481950
)
1949-
self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)])
1951+
self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)])
19501952
self.assertEqual(
19511953
run_trials_using_recommended_parallelism(
1952-
ax_client, ax_client.get_max_parallelism(), 20
1954+
ax_client, ax_client.get_max_concurrency(), 20
19531955
),
19541956
0,
19551957
)
@@ -2320,6 +2322,8 @@ def test_deprecated_save_load_method_errors(self) -> None:
23202322
ax_client.load_experiment("test_experiment")
23212323
with self.assertRaises(NotImplementedError):
23222324
ax_client.get_recommended_max_parallelism()
2325+
with self.assertRaises(NotImplementedError):
2326+
ax_client.get_max_parallelism()
23232327

23242328
def test_find_last_trial_with_parameterization(self) -> None:
23252329
ax_client = AxClient()
@@ -2872,7 +2876,7 @@ def test_estimate_early_stopping_savings(self) -> None:
28722876

28732877
self.assertEqual(ax_client.estimate_early_stopping_savings(), 0)
28742878

2875-
def test_max_parallelism_exception_when_early_stopping(self) -> None:
2879+
def test_max_concurrency_exception_when_early_stopping(self) -> None:
28762880
ax_client = AxClient()
28772881
ax_client.create_experiment(
28782882
parameters=[
@@ -2882,7 +2886,7 @@ def test_max_parallelism_exception_when_early_stopping(self) -> None:
28822886
support_intermediate_data=True,
28832887
)
28842888

2885-
exception = MaxParallelismReachedException(step_index=1, num_running=10)
2889+
exception = MaxConcurrencyReachedException(step_index=1, num_running=10)
28862890

28872891
# pyre-fixme[53]: Captured variable `exception` is not annotated.
28882892
def fake_new_trial(*args: Any, **kwargs: Any) -> None:
@@ -2892,15 +2896,15 @@ def fake_new_trial(*args: Any, **kwargs: Any) -> None:
28922896
ax_client.experiment.new_trial = fake_new_trial
28932897

28942898
# Without early stopping.
2895-
with self.assertRaises(MaxParallelismReachedException) as cm:
2899+
with self.assertRaises(MaxConcurrencyReachedException) as cm:
28962900
ax_client.get_next_trial()
28972901
# Assert Exception's message is unchanged.
28982902
self.assertEqual(cm.exception.message, exception.message)
28992903

29002904
# With early stopping.
29012905
ax_client._early_stopping_strategy = DummyEarlyStoppingStrategy()
29022906
# Assert Exception's message is augmented to mention early stopping.
2903-
with self.assertRaisesRegex(MaxParallelismReachedException, ".*early.*stop"):
2907+
with self.assertRaisesRegex(MaxConcurrencyReachedException, ".*early.*stop"):
29042908
ax_client.get_next_trial()
29052909

29062910
def test_experiment_does_not_support_early_stopping(self) -> None:

0 commit comments

Comments
 (0)