Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion python/interpret-core/interpret/blackbox/_lime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2023 The InterpretML Contributors
# Distributed under the MIT software license

import warnings

import numpy as np

from ..core.base import LocalExplainer
Expand Down Expand Up @@ -46,8 +48,26 @@ def __init__(self, model, data, feature_names=None, feature_types=None, **kwargs
# so convert to np.float64 until we implement some automatic categorical handling
data = data.astype(np.float64, order="C", copy=False)

# rewrite these even if the user specified them
# `LimeTabular` always runs the underlying LIME explainer in
# "regression" mode regardless of whether the wrapped model is a
# regressor or a binary classifier — `unify_predict_fn` below
# turns classifier outputs into a scalar probability in [0, 1] so
# LIME treats both cases uniformly. The overrides for `mode` and
# `feature_names` are therefore intentional, but issue #477 showed
# they are surprising when a user sets `mode="classification"`
# explicitly. Warn if their value is being discarded so they know
# to read the docstring rather than wonder why nothing happened.
kwargs = kwargs.copy()
user_mode = kwargs.get("mode", "regression")
if user_mode != "regression":
warnings.warn(
"LimeTabular wraps LIME in 'regression' mode internally for "
"both regression and binary classification (the predict "
"function is unified to return a scalar). The "
f"`mode={user_mode!r}` argument was ignored.",
UserWarning,
stacklevel=2,
)
kwargs["mode"] = "regression"
kwargs["feature_names"] = self.feature_names_in_

Expand Down
79 changes: 79 additions & 0 deletions python/interpret-core/tests/blackbox/test_lime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2026 The InterpretML Contributors
# Distributed under the MIT software license

"""Tests for LimeTabular wrapper.

Issue #477: passing ``mode="classification"`` to ``LimeTabular`` was
silently overridden to ``"regression"`` because the wrapper unifies
classifier and regressor predict functions to a scalar. The override is
intentional (LIME wouldn't otherwise work with the unified predict
path), but a silent override is surprising — users were left wondering
why the keyword had no effect. The wrapper now emits a UserWarning when
a non-default ``mode`` is discarded.
"""

import warnings

import numpy as np
import pytest

pytest.importorskip("lime")

from interpret.blackbox import LimeTabular


def _toy_data():
rng = np.random.default_rng(0)
return rng.standard_normal((20, 3)).astype(np.float64)


def _toy_predict_fn(X):
# Probability of class 1 — what LimeTabular's unify_predict_fn yields
# for a binary classifier.
return 1.0 / (1.0 + np.exp(-X.sum(axis=1)))


def test_user_supplied_mode_classification_warns():
    """Passing ``mode="classification"`` must emit exactly one override warning.

    BEFORE: passing mode="classification" was silently overridden.
    AFTER: a UserWarning is emitted that mentions the override.
    """
    data = _toy_data()
    with warnings.catch_warnings(record=True) as caught:
        # Always trigger warnings so the default filters cannot drop any.
        warnings.simplefilter("always")
        LimeTabular(
            model=_toy_predict_fn,
            data=data,
            mode="classification",
        )
    # Count only warnings about the discarded ``mode=`` argument, using the
    # same filter as the "no warning" tests in this module. Any unrelated
    # UserWarning raised during construction would otherwise break the
    # exactly-one assertion below.
    mode_warnings = [
        w
        for w in caught
        if issubclass(w.category, UserWarning) and "mode=" in str(w.message)
    ]
    assert len(mode_warnings) == 1, [str(w.message) for w in caught]
    message = str(mode_warnings[0].message).lower()
    assert "regression" in message
    assert "classification" in message


def test_no_warning_when_mode_omitted():
    """Constructing ``LimeTabular`` without ``mode`` emits no mode-override warning.

    The default path must stay quiet so users who never pass ``mode`` are not
    nagged.
    """
    samples = _toy_data()
    with warnings.catch_warnings(record=True) as recorded:
        # Always trigger warnings so the default filters cannot drop any.
        warnings.simplefilter("always")
        LimeTabular(model=_toy_predict_fn, data=samples)

    def _is_mode_warning(entry):
        # A mode-override warning is a UserWarning whose text mentions "mode=".
        return issubclass(entry.category, UserWarning) and "mode=" in str(
            entry.message
        )

    mode_warnings = list(filter(_is_mode_warning, recorded))
    assert mode_warnings == []


def test_no_warning_when_mode_regression():
    """Passing ``mode="regression"`` explicitly emits no mode-override warning.

    That value is not discarded, so the wrapper has nothing to warn about.
    """
    samples = _toy_data()
    with warnings.catch_warnings(record=True) as recorded:
        # Always trigger warnings so the default filters cannot drop any.
        warnings.simplefilter("always")
        LimeTabular(model=_toy_predict_fn, data=samples, mode="regression")

    mode_warnings = []
    for entry in recorded:
        if not issubclass(entry.category, UserWarning):
            continue
        if "mode=" in str(entry.message):
            mode_warnings.append(entry)
    assert mode_warnings == []
Loading