1+ """Defines the TabPFN Classifier model class."""
2+
3+ import logging
4+ from typing import Any , Dict , Optional
5+
6+ import numpy as np
7+ import pandas as pd
8+ import torch
9+ from sklearn .base import BaseEstimator , ClassifierMixin
10+ from skopt .space import Categorical , Integer , Real
11+
12+ from ml_grid .util import param_space
13+ from ml_grid .util .global_params import global_parameters
14+
# TabPFN is an optional dependency. Record whether it could be imported so
# the wrapper class can fail with a clear message at construction time
# instead of breaking the import of this module.
try:
    from tabpfn import TabPFNClassifier
    from tabpfn.constants import ModelVersion
except ImportError:
    TABPFN_AVAILABLE = False
    logging.getLogger("ml_grid").warning(
        "TabPFN not available. Install with: pip install tabpfn"
    )
else:
    # Both imports succeeded, so the wrapper below is usable.
    TABPFN_AVAILABLE = True

logging.getLogger("ml_grid").debug("Imported TabPFNClassifier class")
26+
27+
class TabPFNClassifierClass(BaseEstimator, ClassifierMixin):
    """TabPFN Classifier with support for hyperparameter tuning.

    TabPFN is a foundation model for tabular data that performs well on small
    to medium-sized datasets (up to 50,000 rows). It requires GPU for optimal
    performance on datasets larger than ~1000 samples.

    The wrapper exposes the tunable settings as scikit-learn style constructor
    parameters and builds the underlying ``TabPFNClassifier`` lazily in
    :meth:`fit`.

    Note: TabPFN-2.5 model weights require accepting license terms at:
    https://huggingface.co/Prior-Labs/tabpfn_2_5
    """

    # Model versions that ``fit`` has an explicit construction path for.
    _KNOWN_MODEL_VERSIONS = ("v2.5_default", "v2.5_synthetic", "v2")

    # Checkpoint passed as ``model_path`` when "v2.5_synthetic" is requested.
    # NOTE(review): confirm this file name against the installed TabPFN release.
    _SYNTHETIC_CHECKPOINT = "tabpfn-v2.5-classifier-v2.5_default-2.ckpt"

    def __init__(
        self,
        parameter_space_size: Optional[str] = None,
        # Hyperparameters for scikit-learn compatibility
        model_version: str = "v2.5_default",
        device: str = "cpu",
        n_estimators: int = 4,
        subsample_samples: Optional[int] = None,
        random_state: int = 42,
    ):
        """Initializes the TabPFNClassifierClass.

        Args:
            parameter_space_size (Optional[str]): Size of the parameter space for
                optimization. Defaults to None.
            model_version (str): The version of the TabPFN model to use
                ("v2.5_default", "v2.5_synthetic" or "v2").
            device (str): The device to run the model on ('cpu' or 'cuda').
            n_estimators (int): Number of ensemble members.
            subsample_samples (Optional[int]): If set, ``fit`` trains on a random
                subsample of at most this many rows.
            random_state (int): Random state for reproducibility.

        Raises:
            ImportError: If TabPFN is not installed.
        """
        if not TABPFN_AVAILABLE:
            raise ImportError(
                "TabPFN is not installed. Install with: pip install tabpfn"
            )

        # Store scikit-learn hyperparameters (names must match the signature so
        # that get_params/set_params/clone keep working).
        self.model_version = model_version
        self.device = device
        self.n_estimators = n_estimators
        self.subsample_samples = subsample_samples
        self.random_state = random_state

        global_params = global_parameters
        self.parameter_space_size = parameter_space_size

        self.algorithm_implementation = self  # The instance itself is the estimator
        self.method_name: str = "TabPFNClassifier"

        self.parameter_vector_space: param_space.ParamSpace = param_space.ParamSpace(
            parameter_space_size
        )
        self.parameter_space: Dict[str, Any]

        if global_params.bayessearch:
            # skopt search dimensions for Bayesian optimisation.
            self.parameter_space = {
                # Model version selection
                "model_version": Categorical([
                    "v2.5_default",  # Default: finetuned on real data
                    "v2.5_synthetic",  # Trained on synthetic data only
                    "v2",  # TabPFN v2
                ]),
                # Device selection - fit() falls back to CPU if CUDA is missing
                "device": Categorical(["cuda", "cpu"]),
                # Number of ensemble members (more = better but slower)
                "n_estimators": Integer(1, 8),
                # Training subsample size (for large datasets)
                "subsample_samples": Categorical([None, 5000, 10000, 20000]),
                # Random state for reproducibility
                "random_state": Categorical([42]),
            }
        else:
            # Plain value lists for grid / random search.
            self.parameter_space = {
                "model_version": ["v2.5_default", "v2.5_synthetic", "v2"],
                "device": ["cuda", "cpu"],
                "n_estimators": [1, 2, 4, 8],
                "subsample_samples": [None, 5000, 10000, 20000],
                "random_state": [42],
            }

    @staticmethod
    def _take_rows(data: Any, indices: np.ndarray) -> Any:
        """Select rows of ``data`` by integer position, for pandas or array-likes."""
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.iloc[indices]
        if isinstance(data, (list, tuple)):
            # Plain sequences do not support index-array selection.
            data = np.asarray(data)
        return data[indices]

    def _fitted_estimator(self) -> Any:
        """Return the fitted TabPFN estimator or raise ``NotFittedError``."""
        # Imported here so the module's top-level import block stays untouched.
        from sklearn.exceptions import NotFittedError

        estimator = getattr(self, "_estimator", None)
        if estimator is None:
            # NotFittedError subclasses AttributeError and ValueError, so code
            # that caught the former bare AttributeError keeps working.
            raise NotFittedError(
                "This TabPFNClassifierClass instance is not fitted yet. "
                "Call 'fit' before using this estimator."
            )
        return estimator

    def fit(self, X: pd.DataFrame, y: pd.Series) -> "TabPFNClassifierClass":
        """Fits the TabPFN model.

        This method uses the hyperparameters set on the instance to create
        and fit the underlying TabPFNClassifier.

        Args:
            X: Training features (pandas DataFrame or array-like).
            y: Training labels (pandas Series or array-like), aligned with X.

        Returns:
            TabPFNClassifierClass: The fitted instance (``self``).
        """
        log = logging.getLogger("ml_grid")

        # Apply subsampling if configured
        n_rows = len(X)
        if self.subsample_samples is not None and n_rows > self.subsample_samples:
            # Seeded RandomState keeps the subsample reproducible across fits.
            rng = np.random.RandomState(self.random_state)
            indices = rng.choice(n_rows, self.subsample_samples, replace=False)
            X = self._take_rows(X, indices)
            y = self._take_rows(y, indices)

        # Check for GPU availability and fallback if necessary. Only the value
        # passed to TabPFN changes; the stored hyperparameter is left alone.
        device = self.device
        if device == "cuda" and not torch.cuda.is_available():
            log.warning(
                "TabPFN device set to 'cuda' but no CUDA GPU found. Falling back to 'cpu'."
            )
            device = "cpu"

        model_version = self.model_version
        if model_version not in self._KNOWN_MODEL_VERSIONS:
            # Unknown values have always taken the default construction path;
            # surface that instead of doing it silently.
            log.warning(
                "Unknown TabPFN model_version %r; using the default TabPFN model.",
                model_version,
            )

        # Only these settings are forwarded to TabPFN; wrapper-only parameters
        # (parameter_space_size, subsample_samples, model_version) are not.
        tabpfn_kwargs: Dict[str, Any] = {
            "device": device,
            "n_estimators": self.n_estimators,
            "random_state": self.random_state,
        }

        if model_version == "v2.5_synthetic":
            tabpfn_kwargs["model_path"] = self._SYNTHETIC_CHECKPOINT

        if model_version == "v2":
            self._estimator = TabPFNClassifier.create_default_for_version(
                ModelVersion.V2, **tabpfn_kwargs
            )
        else:
            self._estimator = TabPFNClassifier(**tabpfn_kwargs)

        self._estimator.fit(X, y)
        # Expose the learned labels under the name scikit-learn expects.
        self.classes_ = self._estimator.classes_
        return self

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """Makes predictions using the fitted model.

        Args:
            X: Samples to classify.

        Returns:
            Predicted class labels, as returned by the underlying TabPFN
            estimator.

        Raises:
            sklearn.exceptions.NotFittedError: If ``fit`` has not been called.
        """
        return self._fitted_estimator().predict(X)

    def predict_proba(self, X: pd.DataFrame) -> np.ndarray:
        """Returns probability estimates for predictions.

        Args:
            X: Samples to score.

        Returns:
            Class probability estimates, as returned by the underlying TabPFN
            estimator.

        Raises:
            sklearn.exceptions.NotFittedError: If ``fit`` has not been called.
        """
        return self._fitted_estimator().predict_proba(X)
0 commit comments