Fix Scikit-Learn API compliance and GridSearchCV support (#340, #482) #490

Open
VanshKharb wants to merge 9 commits into dswah:main from VanshKharb:fix/sklearn-compliance-and-gridsearchcv

Summary

This PR resolves several compatibility issues with Scikit-Learn estimators, specifically addressing #340 (sklearn.clone attribute mutation) and #482 (GridSearchCV failures). Following these changes, pyGAM models now pass the full sklearn.utils.estimator_checks.check_estimator suite.

Detailed Changes

  • Fix estimator cloning (#340, "Returning no terms when cloning the model"): Updated GAM.get_params() to return the original, unmutated __init__ arguments. Previously, fit() replaced plain string arguments with objects, which caused sklearn.clone() to fail.
  • Fix GridSearchCV compatibility (LogisticGAM not usable with GridSearchCV #482): LogisticGAM.fit() now defines the classes_ attribute, and predict() correctly maps outputs to these class labels instead of returning boolean masks.
  • Input validation improvements: Added targeted ValueError and TypeError exceptions in check_array(), check_X(), and check_y() so that 1D arrays, sparse matrices (dense input is required), and zero-feature arrays are handled according to Scikit-Learn guidelines.
  • Immutable defaults: Converted mutable default list arguments in __init__ (e.g., callbacks) to tuples.
  • Estimator state properties: Added n_iter_ tracking to GAM._pirls to satisfy the iteration-count checks.
  • Binary validation: Added explicit multiclass rejection to LogisticGAM.
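
The cloning fix in the first bullet can be illustrated with a minimal, pure-Python sketch. It is not pyGAM's implementation: `ToyGAM`, `_init_params`, and `simple_clone` are hypothetical names, and `simple_clone` only imitates the core idea of `sklearn.clone()` (rebuild an unfitted estimator from `get_params()`).

```python
import copy
import inspect


class ToyGAM:
    """Hypothetical stand-in for an estimator whose fit() rewrites an init argument."""

    def __init__(self, terms="auto", callbacks=("deviance",)):
        # Tuple default instead of a list, so the default cannot be mutated.
        self.terms = terms
        self.callbacks = callbacks

    @classmethod
    def _param_names(cls):
        # Constructor argument names, excluding `self`.
        return [n for n in inspect.signature(cls.__init__).parameters if n != "self"]

    def fit(self, X, y):
        # Snapshot the constructor arguments before fit() overwrites any of them.
        self._init_params = {
            name: copy.deepcopy(getattr(self, name)) for name in self._param_names()
        }
        # fit() swaps the string for richer objects (here: one tuple per feature).
        self.terms = [("spline", j) for j in range(len(X[0]))]
        return self

    def get_params(self, deep=True):
        # Prefer the snapshot, so callers see what was originally passed in.
        saved = getattr(self, "_init_params", None)
        if saved is not None:
            return dict(saved)
        return {name: getattr(self, name) for name in self._param_names()}


def simple_clone(estimator):
    """Rebuild an unfitted copy from get_params(); a simplification of sklearn.clone()."""
    return type(estimator)(**estimator.get_params(deep=False))


fitted = ToyGAM().fit([[0.0, 1.0], [1.0, 0.0]], [0, 1])
fresh = simple_clone(fitted)
```

Here `fitted.terms` holds the objects built during `fit()`, while `fresh.terms` is the original string again. A real implementation would also need to keep the snapshot in sync with `set_params()`, which this sketch leaves out.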

Testing

  • Added tests/test_sklearn_gridsearch.py to verify end-to-end GridSearchCV compatibility with LogisticGAM.
  • Confirmed all existing unit tests in pygam/tests pass.
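
An end-to-end check of this kind might look roughly like the sketch below. It does not reproduce tests/test_sklearn_gridsearch.py and does not import pyGAM; `ThresholdClassifier` is a hypothetical binary classifier used only to show the contract that GridSearchCV relies on: `fit()` sets `classes_`, `predict()` returns the original labels rather than a boolean mask, and multiclass targets are rejected.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class ThresholdClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical binary classifier: thresholds the first feature."""

    def __init__(self, threshold=0.0):
        self.threshold = threshold

    def fit(self, X, y):
        # Record the sorted unique labels and reject non-binary targets.
        self.classes_ = np.unique(y)
        if len(self.classes_) != 2:
            raise ValueError("only binary targets are supported")
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        mask = X[:, 0] > self.threshold
        # Map the boolean mask back onto the stored class labels.
        return self.classes_[mask.astype(int)]


# Synthetic data with string labels, separable at x = 0.25.
X = np.linspace(-1.0, 1.0, 40).reshape(-1, 1)
y = np.where(X[:, 0] > 0.25, "yes", "no")

search = GridSearchCV(
    ThresholdClassifier(),
    param_grid={"threshold": [-0.5, 0.25, 0.5]},
    cv=5,
)
search.fit(X, y)
predictions = search.predict(X)
```

Because the labels are strings, a `predict()` that returned booleans would score zero accuracy in every fold, which is one way such a test can expose the problem described in #482.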

- Save original initialization params before `GAM.fit()` mutates them
- Override `get_params(deep=False)` to restore original params for `sklearn.clone()`
- Add `__sklearn_clone__` to `Term` and `TermList` to prevent empty list duck-typing
- Remove redundant `gam.set_params` calls in internal bootstrap/gridsearch
- Add regression tests for cloning fitted GAMs with custom terms
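
As a rough illustration of the `__sklearn_clone__` hook mentioned above: scikit-learn 1.3 and later let an object define this method to control how `sklearn.base.clone` copies it. The sketch below uses hypothetical `ToyTermList` and `ToyModel` classes and is not the code in this PR.

```python
import copy

from sklearn.base import BaseEstimator, clone


class ToyTermList:
    """Hypothetical container of terms; not pyGAM's TermList."""

    def __init__(self, *terms):
        self.terms = list(terms)

    def __sklearn_clone__(self):
        # Used by sklearn.base.clone (scikit-learn 1.3+) in place of its default logic.
        return copy.deepcopy(self)


class ToyModel(BaseEstimator):
    """Hypothetical estimator that stores a ToyTermList as a parameter."""

    def __init__(self, terms=None):
        self.terms = terms


original = ToyModel(terms=ToyTermList("s(0)", "f(1)"))
copied = clone(original)
```

The clone carries an independent copy of the term list with the same contents, so a cloned model does not come back with its terms missing.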