Skip to content

Commit 11c3bf8

Browse files
eonofreyfacebook-github-bot
authored andcommitted
TransferLearningAnalysis (facebook#4918)
Summary: Analysis card to show transferrable experiments with a default of 25% parameter overlap. Differential Revision: D92926519
1 parent 2c8dbe5 commit 11c3bf8

4 files changed

Lines changed: 334 additions & 0 deletions

File tree

ax/analysis/healthcheck/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
2525
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
2626
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
27+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
2728

2829
__all__ = [
2930
"create_healthcheck_analysis_card",
@@ -39,4 +40,5 @@
3940
"ComplexityRatingAnalysis",
4041
"PredictableMetricsAnalysis",
4142
"BaselineImprovementAnalysis",
43+
"TransferLearningAnalysis",
4244
]
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from unittest.mock import patch
9+
10+
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
11+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
12+
from ax.core.auxiliary import TransferLearningMetadata
13+
from ax.core.experiment import Experiment
14+
from ax.core.parameter import ParameterType, RangeParameter
15+
from ax.core.search_space import SearchSpace
16+
from ax.exceptions.core import UserInputError
17+
from ax.utils.common.testutils import TestCase
18+
19+
20+
def _make_experiment(
21+
param_names: list[str],
22+
experiment_type: str | None = None,
23+
) -> Experiment:
24+
"""Create a simple experiment with the given parameter names."""
25+
return Experiment(
26+
search_space=SearchSpace(
27+
parameters=[
28+
RangeParameter(
29+
name=name,
30+
parameter_type=ParameterType.FLOAT,
31+
lower=0.0,
32+
upper=1.0,
33+
)
34+
for name in param_names
35+
]
36+
),
37+
name="test_experiment",
38+
experiment_type=experiment_type,
39+
)
40+
41+
42+
_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments"
43+
44+
45+
class TestTransferLearningAnalysis(TestCase):
46+
def test_no_experiment_type_returns_pass(self) -> None:
47+
"""When no experiment_type is set and no experiment_types provided,
48+
return PASS."""
49+
experiment = _make_experiment(["x1", "x2"], experiment_type=None)
50+
analysis = TransferLearningAnalysis()
51+
card = analysis.compute(experiment=experiment)
52+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
53+
self.assertTrue(card.is_passing())
54+
self.assertIn("No experiment type set", card.subtitle)
55+
56+
@patch(_MOCK_TARGET, return_value={})
57+
def test_no_candidates_returns_pass(self, mock_identify: object) -> None:
58+
experiment = _make_experiment(["x1", "x2"], experiment_type="my_type")
59+
analysis = TransferLearningAnalysis()
60+
card = analysis.compute(experiment=experiment)
61+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
62+
self.assertTrue(card.is_passing())
63+
self.assertTrue(card.df.empty)
64+
65+
@patch(_MOCK_TARGET)
66+
def test_single_candidate_returns_warning(self, mock_identify: object) -> None:
67+
experiment = _make_experiment(
68+
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
69+
)
70+
mock_identify.return_value = { # pyre-ignore[16]
71+
"source_exp": TransferLearningMetadata(
72+
overlap_parameters=["x1", "x2", "x3", "x4"],
73+
),
74+
}
75+
analysis = TransferLearningAnalysis()
76+
card = analysis.compute(experiment=experiment)
77+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
78+
self.assertFalse(card.is_passing())
79+
self.assertIn("source_exp", card.subtitle)
80+
self.assertIn("80.0%", card.subtitle)
81+
self.assertEqual(len(card.df), 1)
82+
self.assertEqual(card.df.iloc[0]["Experiment"], "source_exp")
83+
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
84+
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 80.0)
85+
86+
@patch(_MOCK_TARGET)
87+
def test_multiple_candidates_sorted_by_count(self, mock_identify: object) -> None:
88+
experiment = _make_experiment(
89+
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
90+
)
91+
mock_identify.return_value = { # pyre-ignore[16]
92+
"exp_low": TransferLearningMetadata(
93+
overlap_parameters=["x1"],
94+
),
95+
"exp_high": TransferLearningMetadata(
96+
overlap_parameters=["x1", "x2", "x3", "x4"],
97+
),
98+
"exp_mid": TransferLearningMetadata(
99+
overlap_parameters=["x1", "x2", "x3"],
100+
),
101+
}
102+
analysis = TransferLearningAnalysis()
103+
card = analysis.compute(experiment=experiment)
104+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
105+
106+
# Verify sorted descending by overlap count
107+
self.assertEqual(card.df.iloc[0]["Experiment"], "exp_high")
108+
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
109+
self.assertEqual(card.df.iloc[1]["Experiment"], "exp_mid")
110+
self.assertEqual(card.df.iloc[1]["Overlapping Parameters"], 3)
111+
self.assertEqual(card.df.iloc[2]["Experiment"], "exp_low")
112+
self.assertEqual(card.df.iloc[2]["Overlapping Parameters"], 1)
113+
114+
# All experiments listed in subtitle
115+
self.assertIn("exp_high", card.subtitle)
116+
self.assertIn("exp_mid", card.subtitle)
117+
self.assertIn("exp_low", card.subtitle)
118+
self.assertIn("We found **3 eligible source experiment(s)**", card.subtitle)
119+
120+
@patch(_MOCK_TARGET)
121+
def test_percentage_calculation(self, mock_identify: object) -> None:
122+
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
123+
mock_identify.return_value = { # pyre-ignore[16]
124+
"exp_a": TransferLearningMetadata(
125+
overlap_parameters=["x1"],
126+
),
127+
}
128+
analysis = TransferLearningAnalysis()
129+
card = analysis.compute(experiment=experiment)
130+
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 33.3)
131+
132+
@patch(_MOCK_TARGET)
133+
def test_parameters_listed_alphabetically(self, mock_identify: object) -> None:
134+
experiment = _make_experiment(
135+
["alpha", "beta", "gamma", "delta"], experiment_type="my_type"
136+
)
137+
mock_identify.return_value = { # pyre-ignore[16]
138+
"exp_a": TransferLearningMetadata(
139+
overlap_parameters=["gamma", "alpha", "delta"],
140+
),
141+
}
142+
analysis = TransferLearningAnalysis()
143+
card = analysis.compute(experiment=experiment)
144+
self.assertEqual(card.df.iloc[0]["Parameters"], "alpha, delta, gamma")
145+
146+
def test_requires_experiment(self) -> None:
147+
analysis = TransferLearningAnalysis()
148+
with self.assertRaises(UserInputError):
149+
analysis.compute(experiment=None)
150+
151+
@patch(_MOCK_TARGET)
152+
def test_target_experiment_filtered_out(self, mock_identify: object) -> None:
153+
"""The target experiment should be excluded from the results."""
154+
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
155+
mock_identify.return_value = { # pyre-ignore[16]
156+
"test_experiment": TransferLearningMetadata(
157+
overlap_parameters=["x1", "x2", "x3"],
158+
),
159+
"other_exp": TransferLearningMetadata(
160+
overlap_parameters=["x1"],
161+
),
162+
}
163+
analysis = TransferLearningAnalysis()
164+
card = analysis.compute(experiment=experiment)
165+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
166+
self.assertEqual(len(card.df), 1)
167+
self.assertEqual(card.df.iloc[0]["Experiment"], "other_exp")
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from __future__ import annotations
9+
10+
import json
11+
from typing import final, TYPE_CHECKING
12+
13+
import markdown as md
14+
import pandas as pd
15+
from ax.adapter.base import Adapter
16+
from ax.analysis.analysis import Analysis
17+
from ax.analysis.healthcheck.healthcheck_analysis import (
18+
create_healthcheck_analysis_card,
19+
HealthcheckAnalysisCard,
20+
HealthcheckStatus,
21+
)
22+
from ax.core.experiment import Experiment
23+
from ax.exceptions.core import UserInputError
24+
from ax.generation_strategy.generation_strategy import GenerationStrategy
25+
from pyre_extensions import override
26+
27+
if TYPE_CHECKING:
28+
from ax.storage.sqa_store.sqa_config import SQAConfig
29+
30+
31+
class TransferLearningAnalysisCard(HealthcheckAnalysisCard):
32+
"""HealthcheckAnalysisCard with markdown-aware rendering for notebooks."""
33+
34+
def _body_html(self, depth: int) -> str:
35+
parts = [md.markdown(self.subtitle)]
36+
if not self.df.empty:
37+
parts.append(self.df.to_html(index=False))
38+
return f"<div class='content'>{''.join(parts)}</div>"
39+
40+
41+
@final
42+
class TransferLearningAnalysis(Analysis):
43+
def __init__(
44+
self,
45+
experiment_types: list[str] | None = None,
46+
overlap_threshold: float = 0.25,
47+
max_num_exps: int = 10,
48+
config: SQAConfig | None = None,
49+
) -> None:
50+
self.experiment_types = experiment_types
51+
self.overlap_threshold = overlap_threshold
52+
self.max_num_exps = max_num_exps
53+
self.config = config
54+
55+
@override
56+
def compute(
57+
self,
58+
experiment: Experiment | None = None,
59+
generation_strategy: GenerationStrategy | None = None,
60+
adapter: Adapter | None = None,
61+
) -> HealthcheckAnalysisCard:
62+
if experiment is None:
63+
raise UserInputError(
64+
"TransferLearningAnalysis requires a non-null experiment to compute "
65+
"overlap percentages. Please provide an experiment."
66+
)
67+
68+
# Determine experiment types to query for.
69+
experiment_types = self.experiment_types
70+
if experiment_types is None:
71+
if experiment.experiment_type is None:
72+
return create_healthcheck_analysis_card(
73+
name=self.__class__.__name__,
74+
title="Transfer Learning Eligibility",
75+
subtitle=(
76+
"No experiment type set on this experiment. "
77+
"Cannot search for transferable experiments."
78+
),
79+
df=pd.DataFrame(),
80+
status=HealthcheckStatus.PASS,
81+
)
82+
experiment_types = [experiment.experiment_type]
83+
84+
# Lazy import to avoid circular dependency (sqa_store depends on
85+
# healthcheck_analysis).
86+
from ax.storage.sqa_store.load import identify_transferable_experiments
87+
88+
transferable_experiments = identify_transferable_experiments(
89+
search_space=experiment.search_space,
90+
experiment_types=experiment_types,
91+
overlap_threshold=self.overlap_threshold,
92+
max_num_exps=self.max_num_exps,
93+
config=self.config,
94+
)
95+
96+
# Filter out the target experiment itself from results.
97+
transferable_experiments = {
98+
name: metadata
99+
for name, metadata in transferable_experiments.items()
100+
if name != experiment.name
101+
}
102+
103+
if not transferable_experiments:
104+
return create_healthcheck_analysis_card(
105+
name=self.__class__.__name__,
106+
title="Transfer Learning Eligibility",
107+
subtitle="No eligible source experiments found for transfer learning.",
108+
df=pd.DataFrame(),
109+
status=HealthcheckStatus.PASS,
110+
)
111+
112+
total_parameters = len(experiment.search_space.parameters)
113+
114+
rows = []
115+
for exp_name, metadata in transferable_experiments.items():
116+
overlap_count = len(metadata.overlap_parameters)
117+
overlap_pct = (
118+
(overlap_count / total_parameters * 100)
119+
if total_parameters > 0
120+
else 0.0
121+
)
122+
rows.append(
123+
{
124+
"Experiment": exp_name,
125+
"Overlapping Parameters": overlap_count,
126+
"Overlap (%)": round(overlap_pct, 1),
127+
"Parameters": ", ".join(sorted(metadata.overlap_parameters)),
128+
}
129+
)
130+
131+
# Sort by overlapping parameter count descending
132+
rows.sort(key=lambda r: r["Overlapping Parameters"], reverse=True)
133+
134+
df = pd.DataFrame(rows)
135+
136+
n = len(rows)
137+
exp_lines = "\n".join(
138+
f"- **{r['Experiment']}** ({r['Overlap (%)']:.1f}% parameter overlap)"
139+
for r in rows
140+
)
141+
subtitle = (
142+
"Transfer learning can improve optimization by leveraging data "
143+
"from similar past experiments. We found "
144+
f"**{n} eligible source experiment(s)** "
145+
"for transfer learning:\n\n"
146+
f"{exp_lines}\n\n"
147+
"Caution: Only use source experiments that are closely related "
148+
"to your current experiment. "
149+
"Using data from unrelated experiments can lead to negative "
150+
"transfer, which may hurt "
151+
"optimization performance. Review the overlapping parameters "
152+
"before enabling transfer learning."
153+
)
154+
155+
return TransferLearningAnalysisCard(
156+
name=self.__class__.__name__,
157+
title="Transfer Learning Eligibility",
158+
subtitle=subtitle,
159+
df=df,
160+
blob=json.dumps({"status": HealthcheckStatus.WARNING}),
161+
)

ax/analysis/overview.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
2525
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
2626
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
27+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
2728
from ax.analysis.insights import InsightsAnalysis
2829
from ax.analysis.results import ResultsAnalysis
2930
from ax.analysis.trials import AllTrialsAnalysis
@@ -114,6 +115,7 @@ def __init__(
114115
options: OrchestratorOptions | None = None,
115116
tier_metadata: dict[str, Any] | None = None,
116117
model_fit_threshold: float | None = None,
118+
sqa_config: Any = None,
117119
) -> None:
118120
super().__init__()
119121
self.can_generate = can_generate
@@ -124,6 +126,7 @@ def __init__(
124126
self.options = options
125127
self.tier_metadata = tier_metadata
126128
self.model_fit_threshold = model_fit_threshold
129+
self.sqa_config = sqa_config
127130

128131
@override
129132
def validate_applicable_state(
@@ -229,6 +232,7 @@ def compute(
229232
if not has_batch_trials
230233
else None,
231234
BaselineImprovementAnalysis() if not has_batch_trials else None,
235+
TransferLearningAnalysis(config=self.sqa_config),
232236
*[
233237
SearchSpaceAnalysis(trial_index=trial.index)
234238
for trial in candidate_trials

0 commit comments

Comments
 (0)