"function is unified to return a scalar). The " f"`mode={user_mode!r}` argument was ignored.",
+ data = _toy_data() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + LimeTabular( + model=_toy_predict_fn, + data=data, + mode="classification", + ) + user_warnings = [w for w in caught if issubclass(w.category, UserWarning)] + assert len(user_warnings) == 1, [str(w.message) for w in caught] + assert "regression" in str(user_warnings[0].message).lower() + assert "classification" in str(user_warnings[0].message).lower() + + +def test_no_warning_when_mode_omitted(): + # The default path (no mode kwarg) must stay quiet — we don't want to + # nag every user who follows the documented API. + data = _toy_data() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + LimeTabular(model=_toy_predict_fn, data=data) + user_warnings = [ + w + for w in caught + if issubclass(w.category, UserWarning) and "mode=" in str(w.message) + ] + assert user_warnings == [] + + +def test_no_warning_when_mode_regression(): + # Explicitly passing the value we use internally must also stay quiet. + data = _toy_data() + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + LimeTabular(model=_toy_predict_fn, data=data, mode="regression") + user_warnings = [ + w + for w in caught + if issubclass(w.category, UserWarning) and "mode=" in str(w.message) + ] + assert user_warnings == []