1+ #!/usr/bin/env python3
2+ """
3+ plot_groupby_regression_optimized.py
4+ Config-driven plotting for the optimized GroupBy benchmark.
5+
6+ Reads:
7+ benchmarks/bench_out/benchmark_summary.csv
8+
9+ Writes (defaults, can be changed with CLI):
10+ benchmarks/bench_out/throughput_by_engine.png
11+ benchmarks/bench_out/speedup_v4_over_v2.png
12+ benchmarks/bench_out/scaling_groups.png
13+ benchmarks/bench_out/scaling_rows_per_group.png
14+ benchmarks/bench_out/scaling_n_jobs.png
15+ """
16+
17+ from __future__ import annotations
18+ from dataclasses import dataclass
19+ from typing import Literal , List
20+ from pathlib import Path
21+ import argparse
22+
23+ import pandas as pd
24+ import matplotlib .pyplot as plt
25+
26+
27+ # ------------------------------- Plot Config API -------------------------------
28+
29+ PlotKind = Literal ["bar" , "line" , "speedup_v4_over_v2" ]
30+
31+ # Colorblind-friendly palette (Wong 2011)
32+ COLORS = ['#0173B2' , '#DE8F05' , '#029E73' , '#CC78BC' , '#CA9161' ]
33+
34+ @dataclass
35+ class PlotConfig :
36+ """Configuration for a single plot."""
37+ # Subset of rows to consider (pandas .query() string)
38+ query : str
39+ # Output
40+ filename : str
41+ title : str
42+ # Semantics
43+ kind : PlotKind # "bar" | "line" | "speedup_v4_over_v2"
44+ x_axis : str # e.g., "scenario_id" | "n_groups" | "rows_per_group" | "n_jobs"
45+ y_axis : str = "groups_per_s" # metric to plot
46+ legend : str = "engine" # which column defines the series (legend)
47+ log_y : bool = False
48+ # Line-plot specifics
49+ agg : Literal ["median" , "mean" , "max" ] = "median"
50+ min_points : int = 3 # require at least N distinct x values
51+
52+
53+ # ----------------------------- Default configurations --------------------------
54+
55+ PLOT_CONFIGS : List [PlotConfig ] = [
56+ # 1) Throughput summary by engine (grouped bar)
57+ PlotConfig (
58+ query = "engine in ['v2','v3','v4']" ,
59+ filename = "throughput_by_engine.png" ,
60+ title = "Throughput by Engine (higher is better)" ,
61+ kind = "bar" ,
62+ x_axis = "scenario_id" ,
63+ log_y = True ,
64+ ),
65+ # 2) Speedup bar: v4 over v2 per scenario
66+ PlotConfig (
67+ query = "engine in ['v2','v4']" ,
68+ filename = "speedup_v4_over_v2.png" ,
69+ title = "Speedup of Numba v4 over v2 (higher is better)" ,
70+ kind = "speedup_v4_over_v2" ,
71+ x_axis = "scenario_id" ,
72+ ),
73+ # 3) Scaling vs n_groups (line)
74+ PlotConfig (
75+ query = "engine in ['v2','v3','v4']" ,
76+ filename = "scaling_groups.png" ,
77+ title = "Scaling vs n_groups" ,
78+ kind = "line" ,
79+ x_axis = "n_groups" ,
80+ log_y = True ,
81+ ),
82+ # 4) Scaling vs rows_per_group (line)
83+ PlotConfig (
84+ query = "engine in ['v2','v3','v4']" ,
85+ filename = "scaling_rows_per_group.png" ,
86+ title = "Scaling vs rows_per_group" ,
87+ kind = "line" ,
88+ x_axis = "rows_per_group" ,
89+ log_y = True ,
90+ ),
91+ # 5) Scaling vs n_jobs (line)
92+ PlotConfig (
93+ query = "engine in ['v2','v3','v4']" ,
94+ filename = "scaling_n_jobs.png" ,
95+ title = "Scaling vs n_jobs" ,
96+ kind = "line" ,
97+ x_axis = "n_jobs" ,
98+ log_y = True ,
99+ ),
100+ ]
101+
102+
103+ # ------------------------------- Helper functions ------------------------------
104+
105+ def parse_args ():
106+ p = argparse .ArgumentParser (description = "Plot optimized GroupBy benchmark results (config-driven)." )
107+ p .add_argument ("--csv" , type = str ,
108+ default = str (Path (__file__ ).resolve ().parent / "bench_out" / "benchmark_summary.csv" ),
109+ help = "Path to benchmark_summary.csv" )
110+ p .add_argument ("--outdir" , type = str ,
111+ default = str (Path (__file__ ).resolve ().parent / "bench_out" ),
112+ help = "Output directory for plots" )
113+ p .add_argument ("--fmt" , choices = ["png" , "svg" ], default = "png" , help = "Image format" )
114+ p .add_argument ("--dpi" , type = int , default = 140 , help = "DPI for PNG" )
115+ p .add_argument ("--only" , nargs = "*" , default = [],
116+ help = "Optional list of output filenames to generate (filters PLOT_CONFIGS)" )
117+ return p .parse_args ()
118+
119+
120+ def _safe_category_order (series : pd .Series ) -> list :
121+ """Stable order for categorical x (e.g., scenario_id)."""
122+ if pd .api .types .is_categorical_dtype (series ):
123+ return list (series .cat .categories )
124+ # preserve first-seen order
125+ seen , order = set (), []
126+ for v in series :
127+ if v not in seen :
128+ seen .add (v )
129+ order .append (v )
130+ return order
131+
132+
133+ def render_bar (df : pd .DataFrame , cfg : PlotConfig , outdir : Path , fmt : str , dpi : int ):
134+ """Render grouped bar chart."""
135+ try :
136+ d = df .query (cfg .query ).copy ()
137+ if d .empty :
138+ print (f"[skip] { cfg .filename } : no data after query '{ cfg .query } '" )
139+ return None
140+
141+ # Use median aggregation instead of "first"
142+ pv = d .pivot_table (index = cfg .x_axis , columns = cfg .legend ,
143+ values = cfg .y_axis , aggfunc = "median" )
144+
145+ scenarios = list (pv .index .astype (str ))
146+ engines = list (pv .columns .astype (str ))
147+
148+ n_sc = len (scenarios )
149+ n_eng = len (engines )
150+ if n_sc == 0 or n_eng == 0 :
151+ print (f"[skip] { cfg .filename } : empty pivot table" )
152+ return None
153+
154+ width = 0.22
155+ xs = list (range (n_sc ))
156+
157+ fig , ax = plt .subplots (figsize = (max (9 , n_sc * 0.6 ), 5.5 ))
158+ for j , eng in enumerate (engines ):
159+ xj = [x + (j - (n_eng - 1 )/ 2 )* width for x in xs ]
160+ ax .bar (xj , pv [eng ].values , width = width ,
161+ color = COLORS [j % len (COLORS )], label = eng )
162+
163+ ax .set_xticks (xs )
164+ ax .set_xticklabels (scenarios , rotation = 30 , ha = "right" )
165+ ax .set_ylabel (cfg .y_axis .replace ("_" , " " ))
166+ ax .set_title (cfg .title )
167+ if cfg .log_y :
168+ ax .set_yscale ("log" )
169+ ax .grid (axis = "y" , alpha = 0.2 )
170+ ax .legend ()
171+
172+ out = outdir / cfg .filename
173+ if fmt == "png" :
174+ fig .savefig (out , dpi = dpi , bbox_inches = "tight" )
175+ else :
176+ fig .savefig (out , bbox_inches = "tight" )
177+ plt .close (fig )
178+ return out
179+
180+ except Exception as e :
181+ print (f"[error] { cfg .filename } : { e } " )
182+ return None
183+
184+
185+ def render_speedup (df : pd .DataFrame , cfg : PlotConfig , outdir : Path , fmt : str , dpi : int ):
186+ """Render speedup comparison (v4 / v2)."""
187+ try :
188+ d = df .query (cfg .query ).copy ()
189+ if d .empty :
190+ print (f"[skip] { cfg .filename } : no data after query '{ cfg .query } '" )
191+ return None
192+
193+ X = cfg .x_axis
194+
195+ base = d [d [cfg .legend ] == "v2" ][[X , cfg .y_axis ]].rename (columns = {cfg .y_axis : "v2" })
196+ v4 = d [d [cfg .legend ] == "v4" ][[X , cfg .y_axis ]].rename (columns = {cfg .y_axis : "v4" })
197+ m = base .merge (v4 , on = X , how = "inner" )
198+
199+ if m .empty :
200+ print (f"[skip] { cfg .filename } : no matching v2/v4 data" )
201+ return None
202+
203+ m ["speedup" ] = m ["v4" ] / m ["v2" ]
204+
205+ scenarios = list (m [X ].astype (str ).values )
206+ vals = m ["speedup" ].values
207+
208+ fig , ax = plt .subplots (figsize = (max (9 , len (scenarios ) * 0.6 ), 5.0 ))
209+ ax .bar (range (len (scenarios )), vals , color = COLORS [0 ])
210+ ax .set_xticks (range (len (scenarios )))
211+ ax .set_xticklabels (scenarios , rotation = 30 , ha = "right" )
212+ ax .set_ylabel ("speedup (v4 ÷ v2)" )
213+ ax .set_title (cfg .title )
214+ ax .grid (axis = "y" , alpha = 0.2 )
215+
216+ # Label bars with speedup value
217+ for i , v in enumerate (vals ):
218+ if v > 5 :
219+ ax .text (i , v , f"{ v :.0f} ×" , ha = "center" , va = "bottom" , fontsize = 9 )
220+
221+ out = outdir / cfg .filename
222+ if fmt == "png" :
223+ fig .savefig (out , dpi = dpi , bbox_inches = "tight" )
224+ else :
225+ fig .savefig (out , bbox_inches = "tight" )
226+ plt .close (fig )
227+ return out
228+
229+ except Exception as e :
230+ print (f"[error] { cfg .filename } : { e } " )
231+ return None
232+
233+
234+ def render_line (df : pd .DataFrame , cfg : PlotConfig , outdir : Path , fmt : str , dpi : int ):
235+ """Render line plot (scaling analysis)."""
236+ try :
237+ d = df .query (cfg .query ).copy ()
238+ if d .empty :
239+ print (f"[skip] { cfg .filename } : no data after query '{ cfg .query } '" )
240+ return None
241+
242+ X = cfg .x_axis
243+ Y = cfg .y_axis
244+ L = cfg .legend
245+
246+ # Aggregate by (legend, X) with selected reducer
247+ reducer = {"median" : "median" , "mean" : "mean" , "max" : "max" }[cfg .agg ]
248+ g = d .groupby ([L , X ], as_index = False )[Y ].agg (reducer )
249+
250+ # Filter out too-short series
251+ counts = g .groupby (L )[X ].nunique ()
252+ keep = set (counts [counts >= cfg .min_points ].index )
253+ g = g [g [L ].isin (keep )]
254+
255+ if g .empty :
256+ print (f"[skip] { cfg .filename } : insufficient data points (need { cfg .min_points } )" )
257+ return None
258+
259+ # Order X
260+ try :
261+ x_sorted = sorted (g [X ].unique ())
262+ except Exception :
263+ x_sorted = _safe_category_order (g [X ])
264+
265+ fig , ax = plt .subplots (figsize = (max (9 , len (x_sorted ) * 0.6 ), 5.5 ))
266+ for idx , (key , part ) in enumerate (g .groupby (L )):
267+ # align to sorted X
268+ part = part .set_index (X ).reindex (x_sorted )
269+ ax .plot (x_sorted , part [Y ].values , marker = "o" ,
270+ color = COLORS [idx % len (COLORS )], label = str (key ))
271+
272+ ax .set_xlabel (X .replace ("_" , " " ))
273+ ax .set_ylabel (Y .replace ("_" , " " ))
274+ ax .set_title (cfg .title )
275+ if cfg .log_y :
276+ ax .set_yscale ("log" )
277+ ax .grid (True , alpha = 0.25 )
278+ ax .legend (title = L )
279+
280+ out = outdir / cfg .filename
281+ if fmt == "png" :
282+ fig .savefig (out , dpi = dpi , bbox_inches = "tight" )
283+ else :
284+ fig .savefig (out , bbox_inches = "tight" )
285+ plt .close (fig )
286+ return out
287+
288+ except Exception as e :
289+ print (f"[error] { cfg .filename } : { e } " )
290+ return None
291+
292+
293+ # ------------------------------------- Main ------------------------------------
294+
295+ def main ():
296+ args = parse_args ()
297+ csv_path = Path (args .csv )
298+ outdir = Path (args .outdir )
299+ outdir .mkdir (parents = True , exist_ok = True )
300+
301+ if not csv_path .exists ():
302+ print (f"[error] CSV not found: { csv_path } " )
303+ return
304+
305+ df = pd .read_csv (csv_path )
306+
307+ if df .empty :
308+ print ("[error] CSV is empty" )
309+ return
310+
311+ print (f"Loaded { len (df )} rows from { csv_path } " )
312+
313+ generated = []
314+ for cfg in PLOT_CONFIGS :
315+ if args .only and cfg .filename not in args .only :
316+ continue
317+
318+ if cfg .kind == "bar" :
319+ out = render_bar (df , cfg , outdir , args .fmt , args .dpi )
320+ elif cfg .kind == "line" :
321+ out = render_line (df , cfg , outdir , args .fmt , args .dpi )
322+ elif cfg .kind == "speedup_v4_over_v2" :
323+ out = render_speedup (df , cfg , outdir , args .fmt , args .dpi )
324+ else :
325+ out = None
326+
327+ if out :
328+ print (f"[wrote] { out } " )
329+ generated .append (out )
330+
331+ if generated :
332+ print (f"\n Generated { len (generated )} plot(s)" )
333+ else :
334+ print ("\n [warning] No plots generated" )
335+
336+
337+ if __name__ == "__main__" :
338+ main ()
0 commit comments