3636"""
3737
3838from __future__ import annotations
39- import argparse , json , os , sys , time , uuid , platform , subprocess , inspect
39+ import argparse , json , os , sys , time , uuid , platform , subprocess
4040from dataclasses import dataclass , asdict
4141from pathlib import Path
4242from typing import Dict , Any , List , Tuple
@@ -84,11 +84,13 @@ def get_environment_info() -> Dict[str, Any]:
8484# ---------------- Imports (follow bench_comparison.py pattern) ----------------
8585def _import_implementations ():
8686 try :
87- from groupby_regression_optimized import (
87+ # Try package-relative import first
88+ from ..groupby_regression_optimized import (
8889 make_parallel_fit_v2 , make_parallel_fit_v3 , make_parallel_fit_v4
8990 )
9091 return ("package" , make_parallel_fit_v2 , make_parallel_fit_v3 , make_parallel_fit_v4 )
91- except Exception :
92+ except ImportError :
93+ # Fallback: add parent to path
9294 here = Path (__file__ ).resolve ()
9395 root = here .parent .parent
9496 sys .path .insert (0 , str (root ))
@@ -128,81 +130,21 @@ def _make_synthetic_data(n_groups: int, rows_per_group: int,
128130 })
129131 return df
130132
131- # ---------------- Signature-aware engine wrapper ----------------
132- _ALIAS_MAP = {
133- # canonical -> possible alternates
134- "gb_columns" : ["gb_columns" , "gbColumns" , "groupby_columns" ],
135- "fit_columns" : ["fit_columns" , "fitColumns" , "targets" ],
136- "linear_columns" : ["linear_columns" , "linearColumns" , "features" ],
137- "median_columns" : ["median_columns" , "medianColumns" ],
138- "weights" : ["weights" , "weight_column" ],
139- "suffix" : ["suffix" ],
140- "selection" : ["selection" , "mask" ],
141- "addPrediction" : ["addPrediction" , "add_prediction" ],
142- "n_jobs" : ["n_jobs" , "nThreads" , "n_workers" ],
143- "min_stat" : ["min_stat" , "minStat" ],
144- "fitter" : ["fitter" ],
145- "sigmaCut" : ["sigmaCut" , "sigma_cut" ],
146- "batch_size" : ["batch_size" , "batchSize" ],
147- }
148-
149- def _normalize_kwargs_for_signature (fun , kwargs : Dict [str , Any ]) -> Dict [str , Any ]:
150- """Map/limit kwargs to what `fun` actually accepts."""
151- sig = inspect .signature (fun )
152- params = set (sig .parameters .keys ())
153- out : Dict [str , Any ] = {}
154-
155- # Build reverse alias map keyed by actual parameter names present
156- alias_candidates = {}
157- for canonical , alts in _ALIAS_MAP .items ():
158- for alt in alts :
159- alias_candidates [alt ] = canonical
160-
161- # First pass: if kw already matches a param, keep
162- for k , v in kwargs .items ():
163- if k in params :
164- out [k ] = v
165-
166- # Second pass: try alias mapping for missing ones
167- for k , v in kwargs .items ():
168- if k in out :
169- continue
170- # map k -> canonical, then see if any alias for canonical matches a real param
171- canonical = alias_candidates .get (k , None )
172- if not canonical :
173- continue
174- for alt in _ALIAS_MAP .get (canonical , []):
175- if alt in params :
176- out [alt ] = v
177- break
178-
179- # Special case: if neither 'addPrediction' nor 'add_prediction' present, but one is required
180- # we rely on 'params' to decide; otherwise ignore.
181- return out
182-
183- def _call_engine (fun , df : pd .DataFrame , ** kwargs ):
184- filt = _normalize_kwargs_for_signature (fun , kwargs )
185- return fun (df , ** filt )
186-
187133# ---------------- Numba warm-up ----------------
188- def warm_up_numba (make_parallel_fit_v4 , * , verbose : bool = True ) -> None :
189- df = _make_synthetic_data ( n_groups = 10 , rows_per_group = 5 , seed = 123 )
134+ def warm_up_numba (v4_fun , verbose : bool = False ) :
135+ """Trigger Numba JIT compilation before benchmarking."""
190136 try :
191- _call_engine (
192- make_parallel_fit_v4 , df ,
137+ df_tiny = _make_synthetic_data (10 , 5 , seed = 999 )
138+ _ = v4_fun (
139+ df = df_tiny ,
193140 gb_columns = ["g0" ,"g1" ,"g2" ],
194- fit_columns = ["y1" , "y2" ],
141+ fit_columns = ["y1" ],
195142 linear_columns = ["x" ],
196143 median_columns = [],
197144 weights = "wFit" ,
198- suffix = "_warm" ,
199- selection = pd .Series (np .ones (len (df ), dtype = bool )),
200- addPrediction = False ,
201- n_jobs = 1 , # dropped automatically if v4 doesn't accept it
202- min_stat = [3 ,3 ],
203- fitter = "ols" ,
204- sigmaCut = 100 ,
205- batch_size = "auto"
145+ suffix = "_warmup" ,
146+ selection = pd .Series (np .ones (len (df_tiny ), dtype = bool )),
147+ min_stat = 3
206148 )
207149 if verbose :
208150 print ("[warm-up] Numba v4 compilation done." )
@@ -248,23 +190,22 @@ def full_scenarios() -> List[Scenario]:
248190
249191# ---------------- Core runner ----------------
250192def _run_once (engine_name : str , fun , df : pd .DataFrame , sc : Scenario ) -> Tuple [float , Dict [str , Any ]]:
193+ """Run one engine on one scenario and return timing + metadata."""
251194 t0 = time .perf_counter ()
252- df_out , dfGB = _call_engine (
253- fun , df ,
195+
196+ # Call engine directly with keyword arguments
197+ df_out , dfGB = fun (
198+ df = df ,
254199 gb_columns = ["g0" ,"g1" ,"g2" ],
255200 fit_columns = ["y1" ,"y2" ],
256201 linear_columns = ["x" ],
257202 median_columns = [],
258203 weights = "wFit" ,
259204 suffix = "_b" ,
260205 selection = pd .Series (np .ones (len (df ), dtype = bool )),
261- addPrediction = False ,
262- n_jobs = sc .n_jobs , # dropped for engines that don't accept it
263- min_stat = [3 ,3 ],
264- fitter = sc .fitter ,
265- sigmaCut = sc .sigmaCut ,
266- batch_size = "auto"
206+ min_stat = 3
267207 )
208+
268209 elapsed = time .perf_counter () - t0
269210
270211 rows_total = len (df )
@@ -339,11 +280,8 @@ def parse_args():
339280def main ():
340281 args = parse_args ()
341282
342- source , v2_raw , v3_raw , v4_raw = _import_implementations ()
343- # Wrap engines with signature-aware caller to guarantee safe kwargs handling.
344- def wrap (fun ):
345- return lambda df , ** kw : _call_engine (fun , df , ** kw )
346- v2 , v3 , v4 = wrap (v2_raw ), wrap (v3_raw ), wrap (v4_raw )
283+ # Import implementations
284+ source , v2 , v3 , v4 = _import_implementations ()
347285
348286 env = get_environment_info ()
349287 ts = time .strftime ("%Y-%m-%d %H:%M:%S" , time .localtime ())
@@ -352,8 +290,8 @@ def wrap(fun):
352290 out_dir = Path (args .out )
353291 out_dir .mkdir (parents = True , exist_ok = True )
354292
355- # Warm-up JIT (filtered call)
356- warm_up_numba (v4_raw , verbose = True )
293+ # Warm-up JIT
294+ warm_up_numba (v4 , verbose = True )
357295
358296 scenarios = full_scenarios () if args .full else quick_scenarios ()
359297 engines = [("v2" , v2 ), ("v3" , v3 ), ("v4" , v4 )]
@@ -400,4 +338,4 @@ def wrap(fun):
400338 print (" -" , csv_path )
401339
402340if __name__ == "__main__" :
403- main ()
341+ main ()