@@ -438,7 +438,7 @@ def make_parallel_fit_v2(
438438import pandas as pd
439439import time
440440
441- def make_parallel_fit_fast (
441+ def make_parallel_fit_v3 (
442442 df : pd .DataFrame ,
443443 * ,
444444 gb_columns ,
@@ -449,7 +449,7 @@ def make_parallel_fit_fast(
449449 suffix : str = "_fast" ,
450450 selection = None ,
451451 addPrediction : bool = False ,
452- cast_dtype ,
452+ cast_dtype : Union [ str , None ] = "float32" ,
453453 diag : bool = True ,
454454 diag_prefix : str = "diag_" ,
455455 min_stat : 3 ,
@@ -617,3 +617,269 @@ def make_parallel_fit_fast(
617617 return df_out , dfGB .reset_index (drop = True )
618618
619619
620+
621+ # ======================================================================
622+ # Phase 4 — Numba-accelerated per-group OLS (weighted) — make_parallel_fit_v4
623+ # ======================================================================
624+
625+ # Numba import (safe; we fall back if absent)
626+ try :
627+ from numba import njit
628+ _NUMBA_OK = True
629+ except Exception :
630+ _NUMBA_OK = False
631+
632+
633+ if _NUMBA_OK :
634+ @njit (fastmath = True )
635+ def _ols_kernel_numba_weighted (X_all , Y_all , W_all , offsets , n_groups , n_feat , n_tgt , min_stat , out_beta ):
636+ """
637+ Weighted per-group OLS with intercept, compiled in nopython mode.
638+
639+ Parameters
640+ ----------
641+ X_all : (N, n_feat) float64
642+ Y_all : (N, n_tgt) float64
643+ W_all : (N,) float64 (weights; use 1.0 if unweighted)
644+ offsets : (G+1,) int32 (group start indices in sorted arrays)
645+ n_groups: int
646+ n_feat : int
647+ n_tgt : int
648+ min_stat: int
649+ out_beta: (G, n_feat+1, n_tgt) float64 (beta rows: [intercept, slopes...])
650+ """
651+ p = n_feat + 1 # intercept + features
652+ for g in range (n_groups ):
653+ i0 = offsets [g ]
654+ i1 = offsets [g + 1 ]
655+ m = i1 - i0
656+ if m < min_stat or m <= n_feat :
657+ # insufficient stats to solve (or underdetermined)
658+ continue
659+
660+ # Build X1 with intercept
661+ # X1 shape: (m, p)
662+ # X1[:,0] = 1
663+ # X1[:,1:] = X_all[i0:i1]
664+ X1 = np .ones ((m , p ))
665+ Xg = X_all [i0 :i1 ]
666+ for r in range (m ):
667+ for c in range (n_feat ):
668+ X1 [r , c + 1 ] = Xg [r , c ]
669+
670+ # Weighted normal equations:
671+ # XtX = Σ_r w_r * x_r x_r^T
672+ # XtY = Σ_r w_r * x_r y_r^T
673+ XtX = np .empty ((p , p ))
674+ for i in range (p ):
675+ for j in range (p ):
676+ s = 0.0
677+ for r in range (m ):
678+ wr = W_all [i0 + r ]
679+ s += wr * X1 [r , i ] * X1 [r , j ]
680+ XtX [i , j ] = s
681+
682+ Yg = Y_all [i0 :i1 ]
683+ XtY = np .empty ((p , n_tgt ))
684+ for i in range (p ):
685+ for t in range (n_tgt ):
686+ s = 0.0
687+ for r in range (m ):
688+ wr = W_all [i0 + r ]
689+ s += wr * X1 [r , i ] * Yg [r , t ]
690+ XtY [i , t ] = s
691+
692+ # Solve XtX * B = XtY via Gauss–Jordan with partial pivoting
693+ A = XtX .copy ()
694+ B = XtY .copy ()
695+
696+ for k in range (p ):
697+ # pivot search
698+ piv = k
699+ amax = abs (A [k , k ])
700+ for i in range (k + 1 , p ):
701+ v = abs (A [i , k ])
702+ if v > amax :
703+ amax = v
704+ piv = i
705+ # robust guard for near singular
706+ if amax < 1e-12 :
707+ # singular; leave zeros for this group
708+ for ii in range (p ):
709+ for tt in range (n_tgt ):
710+ out_beta [g , ii , tt ] = 0.0
711+ break
712+
713+ # row swap if needed
714+ if piv != k :
715+ for j in range (p ):
716+ tmp = A [k , j ]; A [k , j ] = A [piv , j ]; A [piv , j ] = tmp
717+ for tt in range (n_tgt ):
718+ tmp = B [k , tt ]; B [k , tt ] = B [piv , tt ]; B [piv , tt ] = tmp
719+
720+ pivval = A [k , k ]
721+ invp = 1.0 / pivval
722+ A [k , k ] = 1.0
723+ for j in range (k + 1 , p ):
724+ A [k , j ] *= invp
725+ for tt in range (n_tgt ):
726+ B [k , tt ] *= invp
727+
728+ for i in range (p ):
729+ if i == k :
730+ continue
731+ f = A [i , k ]
732+ if f != 0.0 :
733+ A [i , k ] = 0.0
734+ for j in range (k + 1 , p ):
735+ A [i , j ] -= f * A [k , j ]
736+ for tt in range (n_tgt ):
737+ B [i , tt ] -= f * B [k , tt ]
738+
739+ # write solution β
740+ for i in range (p ):
741+ for tt in range (n_tgt ):
742+ out_beta [g , i , tt ] = B [i , tt ]
743+
744+
def make_parallel_fit_v4(
    df: pd.DataFrame,
    *,
    gb_columns,
    fit_columns,
    linear_columns,
    median_columns=None,
    weights=None,
    suffix: str = "_v4",
    selection=None,
    addPrediction: bool = False,
    cast_dtype: str = "float64",
    diag: bool = True,
    diag_prefix: str = "diag_",
    min_stat: int = 3,
):
    """
    Phase 4 — Numba-accelerated per-group **weighted** OLS.

    - Same schema and user-facing behavior as v3 (intercept + slopes +
      optional predictions).
    - Supports 1 or multi-column group keys of any dtype (including strings).
    - If Numba is unavailable, falls back to a pure-NumPy weighted loop.

    Parameters
    ----------
    df : pd.DataFrame
        Input frame; one OLS fit is produced per group of ``gb_columns``.
    gb_columns : str or list of str
        Group-key column(s).
    fit_columns : list of str
        Target columns; one coefficient set per target.
    linear_columns : list of str
        Predictor columns (an intercept is always added).
    median_columns : list of str, optional
        Accepted for signature compatibility with v3 (unused here).
    weights : str, optional
        Column of per-row weights; 1.0 for every row when None.
    suffix : str
        Appended to every produced coefficient/prediction column name.
    selection : optional
        Row indexer applied via ``df.loc`` before fitting.
    addPrediction : bool
        When True, merge coefficients back and add ``<target>_pred<suffix>``.
    cast_dtype : str or None
        Dtype applied to float64 result columns (group keys excluded);
        None disables casting.
    diag : bool
        When True, add wall-time and group-count diagnostic columns.
    diag_prefix : str
        Prefix for diagnostic columns.
    min_stat : int (or list/tuple, reduced to its max)
        Minimum rows per group; smaller groups keep zero coefficients.

    Returns
    -------
    (df_out, dfGB)
        The (selection-filtered, optionally prediction-augmented) frame and
        the per-group coefficient table.
    """
    t0 = time.perf_counter()
    if median_columns is None:
        median_columns = []
    if isinstance(min_stat, (list, tuple)):
        min_stat = int(np.max(min_stat))

    # Selection
    df_use = df.loc[selection] if selection is not None else df

    # Stable sort by group columns to form contiguous blocks
    sort_keys = list(gb_columns) if isinstance(gb_columns, (list, tuple)) else [gb_columns]
    df_sorted = df_use.sort_values(sort_keys, kind="mergesort").reset_index(drop=True)

    # Group starts via duplicated(): one uniform path for single- and
    # multi-column keys, and it works for any key dtype. (The previous
    # np.unique-on-to_records approach raised for object/string columns.)
    start_mask = ~df_sorted.duplicated(subset=sort_keys).to_numpy()
    start_idx = np.flatnonzero(start_mask)
    n_groups = len(start_idx)
    group_offsets = np.empty(n_groups + 1, dtype=np.int64)
    group_offsets[:-1] = start_idx
    group_offsets[-1] = len(df_sorted)
    # First row of each block carries that block's key values.
    group_id_rows = {c: df_sorted[c].to_numpy()[start_idx] for c in sort_keys}

    # Flattened design/response matrices
    X_all = df_sorted[linear_columns].to_numpy(dtype=np.float64, copy=False)
    Y_all = df_sorted[fit_columns].to_numpy(dtype=np.float64, copy=False)

    # Weights: ones if not provided
    if weights is None:
        W_all = np.ones(len(df_sorted), dtype=np.float64)
    else:
        W_all = df_sorted[weights].to_numpy(dtype=np.float64, copy=False)

    n_feat = X_all.shape[1]
    n_tgt = Y_all.shape[1]
    beta = np.zeros((n_groups, n_feat + 1, n_tgt), dtype=np.float64)

    # Defensive: fall back cleanly if the Numba section of this module
    # was stripped and _NUMBA_OK is not defined.
    try:
        numba_ok = _NUMBA_OK
    except NameError:
        numba_ok = False

    if numba_ok:
        _ols_kernel_numba_weighted(
            X_all, Y_all, W_all, group_offsets.astype(np.int32),
            n_groups, n_feat, n_tgt, int(min_stat), beta,
        )
    else:
        # Pure NumPy fallback: weighted normal equations per group
        for g in range(n_groups):
            i0, i1 = group_offsets[g], group_offsets[g + 1]
            m = i1 - i0
            if m < min_stat or m <= n_feat:
                continue
            X1 = np.c_[np.ones(m), X_all[i0:i1]]   # (m, p) with intercept
            Wg = W_all[i0:i1]
            # XtX = X1^T * diag(W) * X1 ; XtY = X1^T * diag(W) * Yg
            XtX = (X1.T * Wg).dot(X1)              # (p, p)
            XtY = (X1.T * Wg).dot(Y_all[i0:i1])    # (p, n_tgt)
            try:
                beta[g, :, :] = np.linalg.solve(XtX, XtY)
            except np.linalg.LinAlgError:
                # singular group: leave zeros
                pass

    # Assemble dfGB (same schema/column order as v3), built column-wise
    # instead of one Python dict per group (O(G) row dicts).
    columns = dict(group_id_rows)
    for t_idx, tname in enumerate(fit_columns):
        columns[f"{tname}_intercept{suffix}"] = beta[:, 0, t_idx]
        for j, cname in enumerate(linear_columns, start=1):
            columns[f"{tname}_slope_{cname}{suffix}"] = beta[:, j, t_idx]
    dfGB = pd.DataFrame(columns)

    # Diagnostics (minimal; mirrors v3 style)
    if diag:
        dfGB[f"{diag_prefix}wall_ms"] = (time.perf_counter() - t0) * 1e3
        dfGB[f"{diag_prefix}n_groups"] = len(dfGB)

    # Optional cast (never the group key columns)
    if cast_dtype is not None and len(dfGB):
        dfGB = dfGB.astype({
            c: cast_dtype
            for c in dfGB.columns
            if c not in sort_keys and dfGB[c].dtype == "float64"
        })

    # Optional prediction join
    df_out = df_use.copy()
    if addPrediction and len(dfGB):
        df_out = df_out.merge(dfGB, on=sort_keys, how="left")
        for t in fit_columns:
            pred = df_out[f"{t}_intercept{suffix}"].copy()
            for cname in linear_columns:
                pred += df_out[f"{t}_slope_{cname}{suffix}"] * df_out[cname]
            df_out[f"{t}_pred{suffix}"] = pred

    return df_out, dfGB
0 commit comments