Skip to content

Commit e90fb76

Browse files
mgrange1998 authored and facebook-github-bot committed
Rename parallelism to concurrency in Client and AxClient APIs
Summary: Renames the `parallelism` parameter to `concurrency` in `Client.run_trials()`, keeping a deprecated `parallelism` alias for backward compatibility, and renames `AxClient.get_max_parallelism()` to `get_max_concurrency()`, keeping the old method as a deprecated wrapper. Both paths emit `DeprecationWarning`s guiding callers to the new names, and `run_trials()` validates that the old and new parameters are not specified simultaneously. Differential Revision: D93771849
1 parent eab1b83 commit e90fb76

3 files changed

Lines changed: 55 additions & 29 deletions

File tree

ax/api/client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# pyre-strict
77

88
import json
9+
import warnings
910
from collections.abc import Iterable, Sequence
1011
from logging import Logger
1112
from typing import Any, Literal, Self
@@ -43,7 +44,7 @@
4344
BaseEarlyStoppingStrategy,
4445
PercentileEarlyStoppingStrategy,
4546
)
46-
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError
47+
from ax.exceptions.core import ObjectNotFoundError, UnsupportedError, UserInputError
4748
from ax.generation_strategy.generation_strategy import GenerationStrategy
4849
from ax.orchestration.orchestrator import Orchestrator, OrchestratorOptions
4950
from ax.service.utils.best_point_mixin import BestPointMixin
@@ -710,9 +711,11 @@ def mark_trial_early_stopped(self, trial_index: int) -> None:
710711
def run_trials(
711712
self,
712713
max_trials: int,
713-
parallelism: int = 1,
714+
concurrency: int = 1,
714715
tolerated_trial_failure_rate: float = 0.5,
715716
initial_seconds_between_polls: int = 1,
717+
# Deprecated argument for backwards compatibility.
718+
parallelism: int | None = None,
716719
) -> None:
717720
"""
718721
Run maximum_trials trials in a loop by creating an ephemeral Orchestrator under
@@ -721,12 +724,25 @@ def run_trials(
721724
722725
Saves to database on completion if ``storage_config`` is present.
723726
"""
727+
# Handle deprecated `parallelism` argument.
728+
if parallelism is not None:
729+
warnings.warn(
730+
"`parallelism` is deprecated and will be removed in Ax 1.4. "
731+
"Use `concurrency` instead.",
732+
DeprecationWarning,
733+
stacklevel=2,
734+
)
735+
if concurrency != 1:
736+
raise UserInputError(
737+
"Cannot specify both `parallelism` and `concurrency`."
738+
)
739+
concurrency = parallelism
724740

725741
orchestrator = Orchestrator(
726742
experiment=self._experiment,
727743
generation_strategy=self._generation_strategy_or_choose(),
728744
options=OrchestratorOptions(
729-
max_pending_trials=parallelism,
745+
max_pending_trials=concurrency,
730746
tolerated_trial_failure_rate=tolerated_trial_failure_rate,
731747
init_seconds_between_polls=initial_seconds_between_polls,
732748
),

ax/service/ax_client.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -836,39 +836,39 @@ def get_trials_data_frame(self) -> pd.DataFrame:
836836
"""
837837
return self.experiment.to_df()
838838

839-
def get_max_parallelism(self) -> list[tuple[int, int]]:
840-
"""Retrieves maximum number of trials that can be scheduled in parallel
839+
def get_max_concurrency(self) -> list[tuple[int, int]]:
840+
"""Retrieves maximum number of trials that can be scheduled concurrently
841841
at different stages of optimization.
842842
843843
Some optimization algorithms profit significantly from sequential
844844
optimization (i.e. suggest a few points, get updated with data for them,
845845
repeat, see https://ax.dev/docs/bayesopt.html).
846-
Parallelism setting indicates how many trials should be running simulteneously
846+
Concurrency setting indicates how many trials should be running simultaneously
847847
(generated, but not yet completed with data).
848848
849849
The output of this method is mapping of form
850-
{num_trials -> max_parallelism_setting}, where the max_parallelism_setting
851-
is used for num_trials trials. If max_parallelism_setting is -1, as
852-
many of the trials can be ran in parallel, as necessary. If num_trials
853-
in a tuple is -1, then the corresponding max_parallelism_setting
850+
{num_trials -> max_concurrency_setting}, where the max_concurrency_setting
851+
is used for num_trials trials. If max_concurrency_setting is -1, as
852+
many of the trials can be ran concurrently, as necessary. If num_trials
853+
in a tuple is -1, then the corresponding max_concurrency_setting
854854
should be used for all subsequent trials.
855855
856856
For example, if the returned list is [(5, -1), (12, 6), (-1, 3)],
857-
the schedule could be: run 5 trials with any parallelism, run 6 trials in
858-
parallel twice, run 3 trials in parallel for as long as needed. Here,
857+
the schedule could be: run 5 trials with any concurrency, run 6 trials
858+
concurrently twice, run 3 trials concurrently for as long as needed. Here,
859859
'running' a trial means obtaining a next trial from `AxClient` through
860860
get_next_trials and completing it with data when available.
861861
862862
Returns:
863-
Mapping of form {num_trials -> max_parallelism_setting}.
863+
Mapping of form {num_trials -> max_concurrency_setting}.
864864
"""
865-
parallelism_settings = []
865+
concurrency_settings = []
866866
for node in self.generation_strategy._nodes:
867-
# Extract max_parallelism from MaxGenerationParallelism criterion
868-
max_parallelism = None
867+
# Extract max_concurrency from MaxGenerationParallelism criterion
868+
max_concurrency = None
869869
for tc in node.transition_criteria:
870870
if isinstance(tc, MaxGenerationParallelism):
871-
max_parallelism = tc.threshold
871+
max_concurrency = tc.threshold
872872
break
873873
# Try to get num_trials from the node. If there's no MinTrials
874874
# criterion (unlimited trials), num_trials will raise UserInputError.
@@ -877,13 +877,23 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
877877
num_trials = node.num_trials
878878
except UserInputError:
879879
num_trials = -1
880-
parallelism_settings.append(
880+
concurrency_settings.append(
881881
(
882882
num_trials,
883-
max_parallelism if max_parallelism is not None else num_trials,
883+
max_concurrency if max_concurrency is not None else num_trials,
884884
)
885885
)
886-
return parallelism_settings
886+
return concurrency_settings
887+
888+
def get_max_parallelism(self) -> list[tuple[int, int]]:
889+
"""Deprecated. Use `get_max_concurrency` instead."""
890+
warnings.warn(
891+
"`get_max_parallelism` is deprecated and will be removed in Ax 1.4. "
892+
"Use `get_max_concurrency` instead.",
893+
DeprecationWarning,
894+
stacklevel=2,
895+
)
896+
return self.get_max_concurrency()
887897

888898
def get_optimization_trace(
889899
self, objective_optimum: float | None = None

ax/service/tests/test_ax_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
UserInputError,
5151
)
5252
from ax.exceptions.generation_strategy import MaxParallelismReachedException
53-
from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM
53+
from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_CONCURRENCY
5454
from ax.generation_strategy.generation_strategy import (
5555
GenerationNode,
5656
GenerationStep,
@@ -511,7 +511,7 @@ def test_default_generation_strategy_continuous(self) -> None:
511511
if i < 5:
512512
self.assertEqual(gen_limit, 5 - i)
513513
else:
514-
self.assertEqual(gen_limit, DEFAULT_BAYESIAN_PARALLELISM)
514+
self.assertEqual(gen_limit, DEFAULT_BAYESIAN_CONCURRENCY)
515515
parameterization, trial_index = ax_client.get_next_trial()
516516
x, y = parameterization.get("x"), parameterization.get("y")
517517
ax_client.complete_trial(
@@ -1616,14 +1616,14 @@ def test_keep_generating_without_data(self) -> None:
16161616
self.assertTrue(len(node0_min_trials) > 0)
16171617
self.assertFalse(node0_min_trials[0].block_gen_if_met)
16181618

1619-
# Check that max_parallelism is None by verifying no MaxGenerationParallelism
1619+
# Check that max_concurrency is None by verifying no MaxGenerationParallelism
16201620
# criterion exists on node 1
1621-
node1_max_parallelism = [
1621+
node1_max_concurrency = [
16221622
tc
16231623
for tc in ax_client.generation_strategy._nodes[1].transition_criteria
16241624
if isinstance(tc, MaxGenerationParallelism)
16251625
]
1626-
self.assertEqual(len(node1_max_parallelism), 0)
1626+
self.assertEqual(len(node1_max_concurrency), 0)
16271627

16281628
for _ in range(10):
16291629
ax_client.get_next_trial()
@@ -1939,17 +1939,17 @@ def test_relative_oc_without_sq(self) -> None:
19391939
def test_recommended_parallelism(self) -> None:
19401940
ax_client = AxClient()
19411941
with self.assertRaisesRegex(AssertionError, "No generation strategy"):
1942-
ax_client.get_max_parallelism()
1942+
ax_client.get_max_concurrency()
19431943
ax_client.create_experiment(
19441944
parameters=[
19451945
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
19461946
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
19471947
],
19481948
)
1949-
self.assertEqual(ax_client.get_max_parallelism(), [(5, 5), (-1, 3)])
1949+
self.assertEqual(ax_client.get_max_concurrency(), [(5, 5), (-1, 3)])
19501950
self.assertEqual(
19511951
run_trials_using_recommended_parallelism(
1952-
ax_client, ax_client.get_max_parallelism(), 20
1952+
ax_client, ax_client.get_max_concurrency(), 20
19531953
),
19541954
0,
19551955
)
@@ -2872,7 +2872,7 @@ def test_estimate_early_stopping_savings(self) -> None:
28722872

28732873
self.assertEqual(ax_client.estimate_early_stopping_savings(), 0)
28742874

2875-
def test_max_parallelism_exception_when_early_stopping(self) -> None:
2875+
def test_max_concurrency_exception_when_early_stopping(self) -> None:
28762876
ax_client = AxClient()
28772877
ax_client.create_experiment(
28782878
parameters=[

0 commit comments

Comments (0)