|
19 | 19 | import numpy as np |
20 | 20 | import torch |
21 | 21 | from ax.adapter.registry import Cont_X_trans, Generators |
| 22 | +from ax.api.configs import ChoiceParameterConfig, RangeParameterConfig |
22 | 23 | from ax.core.arm import Arm |
23 | 24 | from ax.core.data import Data, MAP_KEY |
24 | 25 | from ax.core.generator_run import GeneratorRun |
@@ -1355,6 +1356,198 @@ def test_set_search_space(self) -> None: |
1355 | 1356 | [ParameterConstraint(inequality="x1 <= x2")], |
1356 | 1357 | ) |
1357 | 1358 |
|
| 1359 | + def test_update_parameters(self) -> None: |
| 1360 | + """Test that update_parameters correctly updates parameters and raises |
| 1361 | + appropriate errors.""" |
| 1362 | + ax_client = AxClient() |
| 1363 | + ax_client.create_experiment( |
| 1364 | + name="test_experiment", |
| 1365 | + parameters=[ |
| 1366 | + { |
| 1367 | + "name": "x1", |
| 1368 | + "type": "range", |
| 1369 | + "bounds": [0.0, 1.0], |
| 1370 | + "value_type": "float", |
| 1371 | + }, |
| 1372 | + { |
| 1373 | + "name": "x2", |
| 1374 | + "type": "range", |
| 1375 | + "bounds": [1, 10], |
| 1376 | + "value_type": "int", |
| 1377 | + }, |
| 1378 | + { |
| 1379 | + "name": "x3", |
| 1380 | + "type": "choice", |
| 1381 | + "values": ["a", "b", "c"], |
| 1382 | + }, |
| 1383 | + ], |
| 1384 | + is_test=True, |
| 1385 | + immutable_search_space_and_opt_config=False, |
| 1386 | + ) |
| 1387 | + |
| 1388 | + # --- sub-test 1: update RangeParameter bounds (float) --- |
| 1389 | + with self.subTest("update_float_range_parameter"): |
| 1390 | + ax_client.update_parameters( |
| 1391 | + parameters=[ |
| 1392 | + RangeParameterConfig( |
| 1393 | + name="x1", |
| 1394 | + bounds=(0.5, 2.0), |
| 1395 | + parameter_type="float", |
| 1396 | + ), |
| 1397 | + ] |
| 1398 | + ) |
| 1399 | + param = ax_client.experiment.search_space.parameters["x1"] |
| 1400 | + self.assertIsInstance(param, RangeParameter) |
| 1401 | + assert isinstance(param, RangeParameter) |
| 1402 | + self.assertEqual(param.lower, 0.5) |
| 1403 | + self.assertEqual(param.upper, 2.0) |
| 1404 | + |
| 1405 | + # --- sub-test 2: update RangeParameter bounds (int) --- |
| 1406 | + with self.subTest("update_int_range_parameter"): |
| 1407 | + ax_client.update_parameters( |
| 1408 | + parameters=[ |
| 1409 | + RangeParameterConfig( |
| 1410 | + name="x2", |
| 1411 | + bounds=(5, 20), |
| 1412 | + parameter_type="int", |
| 1413 | + ), |
| 1414 | + ] |
| 1415 | + ) |
| 1416 | + param = ax_client.experiment.search_space.parameters["x2"] |
| 1417 | + self.assertIsInstance(param, RangeParameter) |
| 1418 | + assert isinstance(param, RangeParameter) |
| 1419 | + self.assertEqual(param.lower, 5) |
| 1420 | + self.assertEqual(param.upper, 20) |
| 1421 | + |
| 1422 | + # --- sub-test 3: raises on missing parameter --- |
| 1423 | + with self.subTest("raises_on_missing_parameter"): |
| 1424 | + with self.assertRaisesRegex( |
| 1425 | + UserInputError, "Parameter nonexistent not found in search space" |
| 1426 | + ): |
| 1427 | + ax_client.update_parameters( |
| 1428 | + parameters=[ |
| 1429 | + RangeParameterConfig( |
| 1430 | + name="nonexistent", |
| 1431 | + bounds=(0.0, 1.0), |
| 1432 | + parameter_type="float", |
| 1433 | + ), |
| 1434 | + ] |
| 1435 | + ) |
| 1436 | + |
| 1437 | + # --- sub-test 4: raises on choice parameter --- |
| 1438 | + with self.subTest("raises_on_choice_parameter"): |
| 1439 | + with self.assertRaisesRegex( |
| 1440 | + UserInputError, "Choice parameters cannot be updated" |
| 1441 | + ): |
| 1442 | + ax_client.update_parameters( |
| 1443 | + parameters=[ |
| 1444 | + ChoiceParameterConfig( |
| 1445 | + name="x3", |
| 1446 | + values=["d", "e", "f"], |
| 1447 | + parameter_type="str", |
| 1448 | + ), |
| 1449 | + ] |
| 1450 | + ) |
| 1451 | + |
| 1452 | + def test_add_parameters(self) -> None: |
| 1453 | + """Test that add_parameters correctly adds new parameters to the |
| 1454 | + search space. |
| 1455 | + """ |
| 1456 | + ax_client = AxClient() |
| 1457 | + ax_client.create_experiment( |
| 1458 | + name="test_experiment", |
| 1459 | + parameters=[ |
| 1460 | + { |
| 1461 | + "name": "x1", |
| 1462 | + "type": "range", |
| 1463 | + "bounds": [0.0, 1.0], |
| 1464 | + "value_type": "float", |
| 1465 | + }, |
| 1466 | + ], |
| 1467 | + is_test=True, |
| 1468 | + immutable_search_space_and_opt_config=False, |
| 1469 | + ) |
| 1470 | + |
| 1471 | + ax_client.add_parameters( |
| 1472 | + parameters=[ |
| 1473 | + RangeParameterConfig( |
| 1474 | + name="x2", |
| 1475 | + bounds=(0.0, 10.0), |
| 1476 | + parameter_type="float", |
| 1477 | + ), |
| 1478 | + ChoiceParameterConfig( |
| 1479 | + name="x3", |
| 1480 | + values=["a", "b", "c"], |
| 1481 | + parameter_type="str", |
| 1482 | + ), |
| 1483 | + ], |
| 1484 | + backfill_values={"x2": 5.0, "x3": "a"}, |
| 1485 | + ) |
| 1486 | + |
| 1487 | + search_space = ax_client.experiment.search_space |
| 1488 | + self.assertIn("x1", search_space.parameters) |
| 1489 | + self.assertIn("x2", search_space.parameters) |
| 1490 | + self.assertIn("x3", search_space.parameters) |
| 1491 | + |
| 1492 | + param_x2 = search_space.parameters["x2"] |
| 1493 | + self.assertIsInstance(param_x2, RangeParameter) |
| 1494 | + assert isinstance(param_x2, RangeParameter) |
| 1495 | + self.assertEqual(param_x2.lower, 0.0) |
| 1496 | + self.assertEqual(param_x2.upper, 10.0) |
| 1497 | + |
| 1498 | + param_x3 = search_space.parameters["x3"] |
| 1499 | + self.assertIsInstance(param_x3, ChoiceParameter) |
| 1500 | + assert isinstance(param_x3, ChoiceParameter) |
| 1501 | + self.assertEqual(param_x3.values, ["a", "b", "c"]) |
| 1502 | + |
| 1503 | + def test_disable_parameters(self) -> None: |
| 1504 | + """Test that disable_parameters correctly disables parameters in the search |
| 1505 | + space.""" |
| 1506 | + ax_client = AxClient() |
| 1507 | + ax_client.create_experiment( |
| 1508 | + name="test_experiment", |
| 1509 | + parameters=[ |
| 1510 | + { |
| 1511 | + "name": "x1", |
| 1512 | + "type": "range", |
| 1513 | + "bounds": [0.0, 1.0], |
| 1514 | + "value_type": "float", |
| 1515 | + }, |
| 1516 | + { |
| 1517 | + "name": "x2", |
| 1518 | + "type": "range", |
| 1519 | + "bounds": [1, 10], |
| 1520 | + "value_type": "int", |
| 1521 | + }, |
| 1522 | + { |
| 1523 | + "name": "x3", |
| 1524 | + "type": "choice", |
| 1525 | + "values": ["a", "b", "c"], |
| 1526 | + }, |
| 1527 | + ], |
| 1528 | + is_test=True, |
| 1529 | + immutable_search_space_and_opt_config=False, |
| 1530 | + ) |
| 1531 | + |
| 1532 | + ax_client.disable_parameters(default_parameter_values={"x2": 5, "x3": "b"}) |
| 1533 | + |
| 1534 | + search_space = ax_client.experiment.search_space |
| 1535 | + self.assertIn("x1", search_space.parameters) |
| 1536 | + self.assertIn("x2", search_space.parameters) |
| 1537 | + self.assertIn("x3", search_space.parameters) |
| 1538 | + |
| 1539 | + param_x1 = search_space.parameters["x1"] |
| 1540 | + self.assertIsInstance(param_x1, RangeParameter) |
| 1541 | + self.assertFalse(param_x1.is_disabled) |
| 1542 | + |
| 1543 | + param_x2 = search_space.parameters["x2"] |
| 1544 | + self.assertTrue(param_x2.is_disabled) |
| 1545 | + self.assertEqual(param_x2.default_value, 5) |
| 1546 | + |
| 1547 | + param_x3 = search_space.parameters["x3"] |
| 1548 | + self.assertTrue(param_x3.is_disabled) |
| 1549 | + self.assertEqual(param_x3.default_value, "b") |
| 1550 | + |
1358 | 1551 | def test_create_moo_experiment(self) -> None: |
1359 | 1552 | """Test basic experiment creation.""" |
1360 | 1553 | ax_client = AxClient( |
|
0 commit comments