Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions ax/analysis/plotly/parallel_coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ParallelCoordinatesPlot(Analysis):
- **PARAMETER_NAME: The value of said parameter for the arm, for each parameter
"""

def __init__(self, metric_name: str | None = None) -> None:
def __init__(self, metric_name: str | None = None, parameters_names: list[str] | None = None) -> None:
"""
Args:
metric_name: The name of the metric to plot. If not specified the objective
Expand All @@ -49,6 +49,7 @@ def __init__(self, metric_name: str | None = None) -> None:
"""

self.metric_name = metric_name
self.parameters_names = parameters_names

@override
def validate_applicable_state(
Expand Down Expand Up @@ -77,7 +78,7 @@ def compute(

metric_name = self.metric_name or select_metric(experiment=experiment)

df = _prepare_data(experiment=experiment, metric=metric_name)
df = _prepare_data(experiment=experiment, metric=metric_name, parameters_names=self.parameters_names)
fig = _prepare_plot(df=df, metric_name=metric_name)

return create_plotly_analysis_card(
Expand All @@ -99,7 +100,7 @@ def compute(
)


def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
def _prepare_data(experiment: Experiment, metric: str, parameters_names: list[str] | None = None) -> pd.DataFrame:
data_df = experiment.lookup_data().df
filtered_df = data_df.loc[data_df["metric_name"] == metric]

Expand All @@ -111,7 +112,7 @@ def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
{
"trial_index": trial.index,
"arm_name": arm.name,
**arm.parameters,
**(arm.parameters if parameters_names is None else {k: arm.parameters[k] for k in parameters_names}),
metric: _find_mean(
df=filtered_df, trial_index=trial.index, arm_name=arm.name
),
Expand Down