|
27 | 27 | from ax.core.multi_type_experiment import ( |
28 | 28 | filter_trials_by_type, |
29 | 29 | get_trial_indices_for_statuses, |
30 | | - MultiTypeExperiment, |
31 | 30 | ) |
32 | 31 | from ax.core.runner import Runner |
33 | 32 | from ax.core.trial import Trial |
@@ -367,17 +366,11 @@ def options(self, options: OrchestratorOptions) -> None: |
367 | 366 | def trial_type(self) -> str: |
368 | 367 | """Trial type for the experiment this Orchestrator is running. |
369 | 368 |
|
370 | | - This returns None if the experiment is not a MultitypeExperiment |
371 | | -
|
372 | 369 | Returns: |
373 | | - Trial type for the experiment this Orchestrator is running if the |
374 | | - experiment is a MultiTypeExperiment and None otherwise. |
| 370 | + Trial type for the experiment this Orchestrator is running. |
| 371 | + Defaults to Keys.DEFAULT_TRIAL_TYPE if not specified. |
375 | 372 | """ |
376 | | - if isinstance(self.experiment, MultiTypeExperiment): |
377 | | - return ( |
378 | | - self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value |
379 | | - ) |
380 | | - return Keys.DEFAULT_TRIAL_TYPE.value |
| 373 | + return self.options.mt_experiment_trial_type or Keys.DEFAULT_TRIAL_TYPE.value |
381 | 374 |
|
382 | 375 | @property |
383 | 376 | def running_trials(self) -> list[BaseTrial]: |
@@ -1619,23 +1612,14 @@ def _validate_options(self, options: OrchestratorOptions) -> None: |
1619 | 1612 | "will be unable to fetch intermediate results with which to " |
1620 | 1613 | "evaluate early stopping criteria." |
1621 | 1614 | ) |
1622 | | - if isinstance(self.experiment, MultiTypeExperiment): |
1623 | | - if options.mt_experiment_trial_type is None: |
1624 | | - raise UserInputError( |
1625 | | - "Must specify `mt_experiment_trial_type` for MultiTypeExperiment." |
1626 | | - ) |
| 1615 | + if options.mt_experiment_trial_type is not None: |
1627 | 1616 | if not self.experiment.supports_trial_type( |
1628 | 1617 | options.mt_experiment_trial_type |
1629 | 1618 | ): |
1630 | 1619 | raise ValueError( |
1631 | 1620 | "Experiment does not support trial type " |
1632 | 1621 | f"{options.mt_experiment_trial_type}." |
1633 | 1622 | ) |
1634 | | - elif options.mt_experiment_trial_type is not None: |
1635 | | - raise UserInputError( |
1636 | | - "`mt_experiment_trial_type` must be None unless the experiment is a " |
1637 | | - "MultiTypeExperiment." |
1638 | | - ) |
1639 | 1623 |
|
1640 | 1624 | def _get_max_pending_trials(self) -> int: |
1641 | 1625 | """Returns the maximum number of pending trials specified in the options, or |
|
0 commit comments