Skip to content

Commit a88c03e

Browse files
andycylmetafacebook-github-bot
authored andcommitted
Search space editing
Summary: Currently, when updating search spaces for an experiment, the entire search space needs to be specified. There are feedbacks mentioning this is hard to use, especially when the search space has many parameters. This diff ads a method "udpate_parameters" which only updates the values for an existing range parameters in the the search space. Differential Revision: D93766951
1 parent 052a554 commit a88c03e

2 files changed

Lines changed: 300 additions & 0 deletions

File tree

ax/service/ax_client.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import pandas as pd
2020
import torch
2121
from ax.adapter.prediction_utils import predict_by_features
22+
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig
23+
from ax.api.utils.instantiation.from_config import parameter_from_config
2224
from ax.core.arm import Arm
2325
from ax.core.base_trial import BaseTrial
2426
from ax.core.evaluations_to_data import raw_evaluations_to_data
@@ -517,6 +519,111 @@ def set_search_space(
517519
experiment=self.experiment,
518520
)
519521

522+
def add_parameters(
523+
self,
524+
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
525+
backfill_values: TParameterization,
526+
status_quo_values: TParameterization | None = None,
527+
) -> None:
528+
"""
529+
Add new parameters to the experiment's search space. This allows extending
530+
the search space after the experiment has run some trials.
531+
532+
Backfill values must be provided for all new parameters to ensure existing
533+
trials in the experiment remain valid within the expanded search space. The
534+
backfill values represent the parameter values that were used in the existing
535+
trials.
536+
537+
Args:
538+
parameters: A sequence of parameter configurations to add to the search
539+
space.
540+
backfill_values: Parameter values to assign to existing trials for the
541+
new parameters being added. All new parameter names must have
542+
corresponding backfill values provided.
543+
status_quo_values: Optional parameter values for the new parameters to
544+
use in the status quo (baseline) arm, if one is defined. If None,
545+
the backfill values will be used for the status quo.
546+
"""
547+
parameters_to_add = [
548+
parameter_from_config(parameter_config) for parameter_config in parameters
549+
]
550+
parameter_names = {parameter.name for parameter in parameters_to_add}
551+
missing_backfill_values = parameter_names - backfill_values.keys()
552+
if missing_backfill_values:
553+
raise UserInputError(
554+
"You must provide backfill values for all parameters being added. "
555+
f"Missing values for parameters: {missing_backfill_values}."
556+
)
557+
extra_backfill_values = backfill_values.keys() - parameter_names
558+
if extra_backfill_values:
559+
logger.warning(
560+
"Backfill values provided for parameters not being added: "
561+
f"{extra_backfill_values}. Will ignore these values."
562+
)
563+
for parameter in parameters_to_add:
564+
if parameter.name in backfill_values:
565+
parameter._backfill_value = backfill_values[parameter.name]
566+
self.experiment.add_parameters_to_search_space(
567+
parameters=parameters_to_add,
568+
status_quo_values=status_quo_values,
569+
)
570+
self._save_experiment_to_db_if_possible(experiment=self.experiment)
571+
572+
def disable_parameters(
573+
self,
574+
default_parameter_values: TParameterization,
575+
) -> None:
576+
"""
577+
Disable parameters in the experiment. This allows narrowing the search space
578+
after the experiment has run some trials.
579+
580+
When parameters are disabled, they are effectively removed from the search
581+
space for future trial generation. Existing trials remain valid, and the
582+
disabled parameters are replaced with fixed default values for all subsequent
583+
trials.
584+
585+
Args:
586+
default_parameter_values: Fixed values to use for the disabled parameters
587+
in all future trials. These values will be used for the parameter in
588+
all subsequent trials.
589+
"""
590+
self.experiment.disable_parameters_in_search_space(
591+
default_parameter_values=default_parameter_values
592+
)
593+
self._save_experiment_to_db_if_possible(experiment=self.experiment)
594+
595+
def update_parameters(
596+
self,
597+
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
598+
) -> None:
599+
"""Update parameters in the experiment's search space.
600+
601+
This allows modifying the search space after the experiment has run some
602+
trials.
603+
604+
Args:
605+
parameters: A sequence of parameter configurations to update in the
606+
search space. Only ``RangeParameterConfig`` is supported.
607+
608+
Raises:
609+
UserInputError: If a parameter is not found in the search space or
610+
if a ``ChoiceParameterConfig`` is provided.
611+
"""
612+
search_space = self.experiment.search_space
613+
for parameter in parameters:
614+
if parameter.name not in search_space.parameters:
615+
raise UserInputError(
616+
f"Parameter {parameter.name} not found in search space."
617+
)
618+
if isinstance(parameter, ChoiceParameterConfig):
619+
raise UserInputError("Choice parameters cannot be updated.")
620+
621+
for parameter in parameters:
622+
search_space.update_parameter(parameter=parameter_from_config(parameter))
623+
self._save_experiment_to_db_if_possible(
624+
experiment=self.experiment,
625+
)
626+
520627
@retry_on_exception(
521628
logger=logger,
522629
exception_types=(RuntimeError,),

ax/service/tests/test_ax_client.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import torch
2121
from ax.adapter.registry import Cont_X_trans, Generators
22+
from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig
2223
from ax.core.arm import Arm
2324
from ax.core.data import Data, MAP_KEY
2425
from ax.core.generator_run import GeneratorRun
@@ -1355,6 +1356,198 @@ def test_set_search_space(self) -> None:
13551356
[ParameterConstraint(inequality="x1 <= x2")],
13561357
)
13571358

1359+
def test_update_parameters(self) -> None:
1360+
"""Test that update_parameters correctly updates parameters and raises
1361+
appropriate errors."""
1362+
ax_client = AxClient()
1363+
ax_client.create_experiment(
1364+
name="test_experiment",
1365+
parameters=[
1366+
{
1367+
"name": "x1",
1368+
"type": "range",
1369+
"bounds": [0.0, 1.0],
1370+
"value_type": "float",
1371+
},
1372+
{
1373+
"name": "x2",
1374+
"type": "range",
1375+
"bounds": [1, 10],
1376+
"value_type": "int",
1377+
},
1378+
{
1379+
"name": "x3",
1380+
"type": "choice",
1381+
"values": ["a", "b", "c"],
1382+
},
1383+
],
1384+
is_test=True,
1385+
immutable_search_space_and_opt_config=False,
1386+
)
1387+
1388+
# --- sub-test 1: update RangeParameter bounds (float) ---
1389+
with self.subTest("update_float_range_parameter"):
1390+
ax_client.update_parameters(
1391+
parameters=[
1392+
RangeParameterConfig(
1393+
name="x1",
1394+
bounds=(0.5, 2.0),
1395+
parameter_type="float",
1396+
),
1397+
]
1398+
)
1399+
param = ax_client.experiment.search_space.parameters["x1"]
1400+
self.assertIsInstance(param, RangeParameter)
1401+
assert isinstance(param, RangeParameter)
1402+
self.assertEqual(param.lower, 0.5)
1403+
self.assertEqual(param.upper, 2.0)
1404+
1405+
# --- sub-test 2: update RangeParameter bounds (int) ---
1406+
with self.subTest("update_int_range_parameter"):
1407+
ax_client.update_parameters(
1408+
parameters=[
1409+
RangeParameterConfig(
1410+
name="x2",
1411+
bounds=(5, 20),
1412+
parameter_type="int",
1413+
),
1414+
]
1415+
)
1416+
param = ax_client.experiment.search_space.parameters["x2"]
1417+
self.assertIsInstance(param, RangeParameter)
1418+
assert isinstance(param, RangeParameter)
1419+
self.assertEqual(param.lower, 5)
1420+
self.assertEqual(param.upper, 20)
1421+
1422+
# --- sub-test 3: raises on missing parameter ---
1423+
with self.subTest("raises_on_missing_parameter"):
1424+
with self.assertRaisesRegex(
1425+
UserInputError, "Parameter nonexistent not found in search space"
1426+
):
1427+
ax_client.update_parameters(
1428+
parameters=[
1429+
RangeParameterConfig(
1430+
name="nonexistent",
1431+
bounds=(0.0, 1.0),
1432+
parameter_type="float",
1433+
),
1434+
]
1435+
)
1436+
1437+
# --- sub-test 4: raises on choice parameter ---
1438+
with self.subTest("raises_on_choice_parameter"):
1439+
with self.assertRaisesRegex(
1440+
UserInputError, "Choice parameters cannot be updated"
1441+
):
1442+
ax_client.update_parameters(
1443+
parameters=[
1444+
ChoiceParameterConfig(
1445+
name="x3",
1446+
values=["d", "e", "f"],
1447+
parameter_type="str",
1448+
),
1449+
]
1450+
)
1451+
1452+
def test_add_parameters(self) -> None:
1453+
"""Test that add_parameters correctly adds new parameters to the
1454+
search space.
1455+
"""
1456+
ax_client = AxClient()
1457+
ax_client.create_experiment(
1458+
name="test_experiment",
1459+
parameters=[
1460+
{
1461+
"name": "x1",
1462+
"type": "range",
1463+
"bounds": [0.0, 1.0],
1464+
"value_type": "float",
1465+
},
1466+
],
1467+
is_test=True,
1468+
immutable_search_space_and_opt_config=False,
1469+
)
1470+
1471+
ax_client.add_parameters(
1472+
parameters=[
1473+
RangeParameterConfig(
1474+
name="x2",
1475+
bounds=(0.0, 10.0),
1476+
parameter_type="float",
1477+
),
1478+
ChoiceParameterConfig(
1479+
name="x3",
1480+
values=["a", "b", "c"],
1481+
parameter_type="str",
1482+
),
1483+
],
1484+
backfill_values={"x2": 5.0, "x3": "a"},
1485+
)
1486+
1487+
search_space = ax_client.experiment.search_space
1488+
self.assertIn("x1", search_space.parameters)
1489+
self.assertIn("x2", search_space.parameters)
1490+
self.assertIn("x3", search_space.parameters)
1491+
1492+
param_x2 = search_space.parameters["x2"]
1493+
self.assertIsInstance(param_x2, RangeParameter)
1494+
assert isinstance(param_x2, RangeParameter)
1495+
self.assertEqual(param_x2.lower, 0.0)
1496+
self.assertEqual(param_x2.upper, 10.0)
1497+
1498+
param_x3 = search_space.parameters["x3"]
1499+
self.assertIsInstance(param_x3, ChoiceParameter)
1500+
assert isinstance(param_x3, ChoiceParameter)
1501+
self.assertEqual(param_x3.values, ["a", "b", "c"])
1502+
1503+
def test_disable_parameters(self) -> None:
1504+
"""Test that disable_parameters correctly disables parameters in the search
1505+
space."""
1506+
ax_client = AxClient()
1507+
ax_client.create_experiment(
1508+
name="test_experiment",
1509+
parameters=[
1510+
{
1511+
"name": "x1",
1512+
"type": "range",
1513+
"bounds": [0.0, 1.0],
1514+
"value_type": "float",
1515+
},
1516+
{
1517+
"name": "x2",
1518+
"type": "range",
1519+
"bounds": [1, 10],
1520+
"value_type": "int",
1521+
},
1522+
{
1523+
"name": "x3",
1524+
"type": "choice",
1525+
"values": ["a", "b", "c"],
1526+
},
1527+
],
1528+
is_test=True,
1529+
immutable_search_space_and_opt_config=False,
1530+
)
1531+
1532+
ax_client.disable_parameters(default_parameter_values={"x2": 5, "x3": "b"})
1533+
1534+
search_space = ax_client.experiment.search_space
1535+
self.assertIn("x1", search_space.parameters)
1536+
self.assertIn("x2", search_space.parameters)
1537+
self.assertIn("x3", search_space.parameters)
1538+
1539+
param_x1 = search_space.parameters["x1"]
1540+
self.assertIsInstance(param_x1, RangeParameter)
1541+
self.assertFalse(param_x1.is_disabled)
1542+
1543+
param_x2 = search_space.parameters["x2"]
1544+
self.assertTrue(param_x2.is_disabled)
1545+
self.assertEqual(param_x2.default_value, 5)
1546+
1547+
param_x3 = search_space.parameters["x3"]
1548+
self.assertTrue(param_x3.is_disabled)
1549+
self.assertEqual(param_x3.default_value, "b")
1550+
13581551
def test_create_moo_experiment(self) -> None:
13591552
"""Test basic experiment creation."""
13601553
ax_client = AxClient(

0 commit comments

Comments
 (0)