Skip to content

Commit b5fcf42

Browse files
committed
feat: Add TabPFNClassifier integration
- Implemented the `TabPFNClassifierClass` wrapper with scikit-learn compatibility.
- Registered `TabPFNClassifierClass` in the pipeline and updated the YAML configs.
- Added unit tests for the TabPFN integration.
- Fixed a logging issue in `project_score_save.py` by filtering out large parameter objects.
- Updated `.gitignore` to exclude TabPFN model checkpoints.
1 parent bb1dd58 commit b5fcf42

7 files changed

Lines changed: 404 additions & 1 deletion

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,4 @@ notebooks/percent_missing_synthetic_data_generated.pkl
112112
percent_missing_synthetic_data_generated.pkl
113113
synthetic_data_generated.csv
114114
synthetic_data_generated.csv
115+
notebooks/tabpfn-v2.5-classifier-v2.5_default-2.ckpt

config_hyperopt.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ models:
6262
# Set the following to true if a GPU is available and configured
6363
kerasClassifier_class: false
6464
knn__gpu_wrapper_class: false
65+
TabPFNClassifierClass: false # requires hf token and agreement
66+
6567

6668
# This section defines the parameter search space for Hyperopt.
6769
# The structure uses lists of options, which will be parsed into hp.choice.

config_single_run.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ models:
4646
H2O_StackedEnsemble_class: false
4747
H2O_GAM_class: true
4848
knn__gpu_wrapper_class: false
49+
TabPFNClassifierClass: true # requires hf token and agreement
4950

5051
# This section defines a single set of parameters for a standalone run.
5152
run_params:
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""Defines the TabPFN Classifier model class."""
2+
3+
import logging
4+
from typing import Any, Dict, Optional
5+
6+
import numpy as np
7+
import pandas as pd
8+
import torch
9+
from sklearn.base import BaseEstimator, ClassifierMixin
10+
from skopt.space import Categorical, Integer, Real
11+
12+
from ml_grid.util import param_space
13+
from ml_grid.util.global_params import global_parameters
14+
15+
# TabPFN is an optional dependency: record its availability in a module-level
# flag instead of failing at import time, so the rest of the pipeline can be
# imported (and other models used) without TabPFN installed. The constructor
# of TabPFNClassifierClass raises ImportError if the flag is False.
try:
    from tabpfn import TabPFNClassifier
    from tabpfn.constants import ModelVersion

    TABPFN_AVAILABLE = True
except ImportError:
    TABPFN_AVAILABLE = False
    # Warn once at import time; the hard failure is deferred to instantiation.
    logging.getLogger("ml_grid").warning(
        "TabPFN not available. Install with: pip install tabpfn"
    )

logging.getLogger("ml_grid").debug("Imported TabPFNClassifier class")
26+
27+
28+
class TabPFNClassifierClass(BaseEstimator, ClassifierMixin):
    """TabPFN Classifier with support for hyperparameter tuning.

    TabPFN is a foundation model for tabular data that performs well on small
    to medium-sized datasets (up to 50,000 rows). It requires GPU for optimal
    performance on datasets larger than ~1000 samples.

    Note: TabPFN-2.5 model weights require accepting license terms at:
    https://huggingface.co/Prior-Labs/tabpfn_2_5
    """

    def __init__(
        self,
        parameter_space_size: Optional[str] = None,
        # Hyperparameters for scikit-learn compatibility
        model_version: str = "v2.5_default",
        device: str = "cpu",
        n_estimators: int = 4,
        subsample_samples: Optional[int] = None,
        random_state: int = 42,
    ):
        """Initializes the TabPFNClassifierClass.

        Args:
            parameter_space_size (Optional[str]): Size of the parameter space for
                optimization. Defaults to None.
            model_version (str): The version of the TabPFN model to use.
            device (str): The device to run the model on ('cpu' or 'cuda').
            n_estimators (int): Number of ensemble members.
            subsample_samples (Optional[int]): Subsample size for large datasets.
            random_state (int): Random state for reproducibility.

        Raises:
            ImportError: If TabPFN is not installed.
        """
        if not TABPFN_AVAILABLE:
            raise ImportError(
                "TabPFN is not installed. Install with: pip install tabpfn"
            )

        # Store scikit-learn hyperparameters exactly as received (required for
        # BaseEstimator.get_params()/set_params() and clone() to work).
        self.model_version = model_version
        self.device = device
        self.n_estimators = n_estimators
        self.subsample_samples = subsample_samples
        self.random_state = random_state

        global_params = global_parameters
        self.parameter_space_size = parameter_space_size

        self.algorithm_implementation = self  # The instance itself is the estimator
        self.method_name: str = "TabPFNClassifier"

        self.parameter_vector_space: param_space.ParamSpace = param_space.ParamSpace(
            parameter_space_size
        )
        self.parameter_space: Dict[str, Any]

        if global_params.bayessearch:
            # skopt search space for Bayesian optimization.
            self.parameter_space = {
                # Model version selection
                "model_version": Categorical([
                    "v2.5_default",   # Default: finetuned on real data
                    "v2.5_synthetic", # Trained on synthetic data only
                    "v2"              # TabPFN v2
                ]),
                # Device selection - can be optimized based on availability
                "device": Categorical(["cuda", "cpu"]),
                # Number of ensemble members (more = better but slower)
                "n_estimators": Integer(1, 8),
                # Training subsample size (for large datasets)
                "subsample_samples": Categorical([None, 5000, 10000, 20000]),
                # Random state for reproducibility
                "random_state": Categorical([42]),
            }
        else:
            # Plain grid of options for non-Bayesian (e.g. Hyperopt) search.
            self.parameter_space = {
                "model_version": ["v2.5_default", "v2.5_synthetic", "v2"],
                "device": ["cuda", "cpu"],
                "n_estimators": [1, 2, 4, 8],
                "subsample_samples": [None, 5000, 10000, 20000],
                "random_state": [42],
            }

    def _subsample(self, X, y):
        """Randomly subsample (X, y) down to ``subsample_samples`` rows.

        No-op when subsampling is disabled or the data is already small
        enough. Uses a seeded RandomState for reproducible draws.
        """
        if self.subsample_samples is None or len(X) <= self.subsample_samples:
            return X, y
        rng = np.random.RandomState(self.random_state)
        indices = rng.choice(len(X), self.subsample_samples, replace=False)
        # Handle DataFrame/Series or numpy arrays
        X = X.iloc[indices] if isinstance(X, pd.DataFrame) else X[indices]
        y = y.iloc[indices] if isinstance(y, pd.Series) else y[indices]
        return X, y

    def _build_estimator(self):
        """Constructs the underlying TabPFNClassifier from current hyperparameters.

        Returns:
            A fresh, unfitted TabPFNClassifier configured per ``get_params()``.
        """
        params = self.get_params()

        # Check for GPU availability and fall back to CPU if necessary so a
        # 'cuda' setting on a CPU-only machine does not crash the run.
        if params.get("device") == "cuda" and not torch.cuda.is_available():
            logging.getLogger("ml_grid").warning(
                "TabPFN device set to 'cuda' but no CUDA GPU found. Falling back to 'cpu'."
            )
            params["device"] = "cpu"

        model_version = params.pop("model_version", "v2.5_default")

        # Keep only arguments TabPFNClassifier itself accepts; get_params()
        # also returns wrapper-only keys (parameter_space_size, subsample_samples).
        valid_tabpfn_params = {"device", "n_estimators", "random_state"}
        params_copy = {k: v for k, v in params.items() if k in valid_tabpfn_params}

        if model_version == "v2.5_synthetic":
            # NOTE(review): this checkpoint filename says "v2.5_default", so the
            # synthetic variant appears to load the *default* weights — confirm
            # the correct synthetic checkpoint path with the TabPFN release.
            params_copy["model_path"] = "tabpfn-v2.5-classifier-v2.5_default-2.ckpt"

        if model_version == "v2":
            return TabPFNClassifier.create_default_for_version(
                ModelVersion.V2, **params_copy
            )
        return TabPFNClassifier(**params_copy)

    def fit(self, X: pd.DataFrame, y: pd.Series):
        """Fits the TabPFN model.

        Uses the hyperparameters set on the instance to create and fit the
        underlying TabPFNClassifier, optionally subsampling large datasets
        first.

        Args:
            X: Training features (DataFrame or array-like).
            y: Training labels (Series or array-like).

        Returns:
            self, per the scikit-learn fit contract.
        """
        X, y = self._subsample(X, y)
        self._estimator = self._build_estimator()
        self._estimator.fit(X, y)
        # Expose classes_ for scikit-learn classifier compatibility.
        self.classes_ = self._estimator.classes_
        return self

    def predict(self, X: pd.DataFrame) -> pd.Series:
        """Makes predictions using the fitted model. Must be called after fit()."""
        return self._estimator.predict(X)

    def predict_proba(self, X: pd.DataFrame) -> pd.DataFrame:
        """Returns probability estimates for predictions. Must be called after fit()."""
        return self._estimator.predict_proba(X)

ml_grid/pipeline/model_class_list.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
)
5656
from ml_grid.model_classes.svc_class import SVCClass
5757
from ml_grid.model_classes.xgb_classifier_class import XGBClassifierClass
58+
from ml_grid.model_classes.tabpfn_classifier_class import TabPFNClassifierClass
5859

5960

6061
# --- ROBUST MAPPING of config names to class objects ---
@@ -79,6 +80,7 @@
7980
"QuadraticDiscriminantAnalysisClass": QuadraticDiscriminantAnalysisClass,
8081
"SVCClass": SVCClass,
8182
"NeuralNetworkClassifier_class": NeuralNetworkClassifier_class, # Corrected mapping
83+
"TabPFNClassifierClass": TabPFNClassifierClass,
8284
# GPU specific
8385
"KerasClassifierClass": KerasClassifierClass,
8486
# "KNNGpuWrapperClass": KNNGpuWrapperClass, #deprecated by python 3.12 and simsig dependency
@@ -163,6 +165,7 @@ def get_model_class_list(ml_grid_object: pipe) -> List[Any]:
163165
"H2O_XGBoost_class": True, # H2O XGBoost
164166
"H2O_StackedEnsemble_class": True, # H2O Stacked Ensemble
165167
"H2O_GAM_class": True, # H2O Generalized Additive Models
168+
"TabPFNClassifierClass": False, # requires hf token and agreement
166169
}
167170

168171
# If running in a CI environment, explicitly disable resource-intensive models

ml_grid/util/project_score_save.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,16 @@ def update_score_log(
271271
f_list.append(current_f_vector)
272272

273273
row_data["algorithm_implementation"] = current_algorithm
274-
row_data["parameter_sample"] = current_algorithm.get_params()
274+
275+
# Filter out large data objects from parameters to prevent logging errors and bloat
276+
params = current_algorithm.get_params()
277+
safe_params = {}
278+
for k, v in params.items():
279+
# Skip data arguments and large pandas/numpy objects
280+
if k not in ['X', 'y', 'data', 'validation_frame', 'training_frame'] and \
281+
not isinstance(v, (pd.DataFrame, pd.Series, np.ndarray)):
282+
safe_params[k] = v
283+
row_data["parameter_sample"] = safe_params
275284
row_data["method_name"] = method_name
276285
row_data["nb_size"] = sum(np.array(current_f_vector))
277286
row_data["n_features"] = len(current_f_vector)

0 commit comments

Comments
 (0)