Skip to content

Commit a1f07de

Browse files
vivekmigfacebook-github-bot
authored andcommitted
Linear Model Updates (#524)
Summary: Updating linear model interface for Lime and KernelSHAP to use new linear models. Pull Request resolved: #524 Reviewed By: NarineK Differential Revision: D24923543 Pulled By: vivekmig fbshipit-source-id: 06f9b8c171550f71ba3b885999948f42d03e6d9d
1 parent a8a49f7 commit a1f07de

File tree

6 files changed

+94
-119
lines changed

6 files changed

+94
-119
lines changed

captum/_utils/models/linear_model/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,9 @@ class SkLearnRidge(SkLearnLinearModel):
279279
def __init__(self, **kwargs):
280280
r"""
281281
Factory class. Trains a model with `sklearn.linear_model.Ridge`.
282+
283+
Any arguments provided to the sklearn constructor can be provided
284+
as kwargs here.
282285
"""
283286
super().__init__(**kwargs, sklearn_module="linear_model.Ridge")
284287

@@ -290,6 +293,9 @@ class SkLearnLinearRegression(SkLearnLinearModel):
290293
def __init__(self, **kwargs):
291294
r"""
292295
Factory class. Trains a model with `sklearn.linear_model.LinearRegression`.
296+
297+
Any arguments provided to the sklearn constructor can be provided
298+
as kwargs here.
293299
"""
294300
super().__init__(**kwargs, sklearn_module="linear_model.LinearRegression")
295301

captum/attr/_core/kernel_shap.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,12 @@
66
import torch
77
from torch import Tensor
88

9+
from captum._utils.models.linear_model import SkLearnLinearRegression
910
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
1011
from captum.attr._core.lime import Lime
1112
from captum.log import log_usage
1213

1314

14-
def linear_regression_interpretable_model_trainer(
15-
interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
16-
):
17-
try:
18-
from sklearn import linear_model
19-
except ImportError:
20-
raise AssertionError(
21-
"Requires sklearn for default interpretable model training with linear "
22-
"regression. Please install sklearn or use a custom interpretable model "
23-
"training function."
24-
)
25-
clf = linear_model.LinearRegression()
26-
clf.fit(
27-
interpretable_inputs.cpu().numpy(),
28-
expected_outputs.cpu().numpy(),
29-
weights.cpu().numpy(),
30-
)
31-
return torch.from_numpy(clf.coef_)
32-
33-
3415
def combination(n: int, k: int) -> int:
3516
try:
3617
# Combination only available in Python 3.8
@@ -86,7 +67,7 @@ def __init__(self, forward_func: Callable) -> None:
8667
Lime.__init__(
8768
self,
8869
forward_func,
89-
linear_regression_interpretable_model_trainer,
70+
SkLearnLinearRegression(),
9071
kernel_shap_similarity_kernel,
9172
)
9273

captum/attr/_core/lime.py

Lines changed: 55 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from torch import Tensor
99
from torch.nn import CosineSimilarity
10+
from torch.utils.data import DataLoader, TensorDataset
1011

1112
from captum._utils.common import (
1213
_expand_additional_forward_args,
@@ -18,6 +19,8 @@
1819
_reduce_list,
1920
_run_forward,
2021
)
22+
from captum._utils.models.linear_model import SkLearnLasso
23+
from captum._utils.models.model import Model
2124
from captum._utils.typing import (
2225
BaselineType,
2326
Literal,
@@ -66,7 +69,7 @@ class LimeBase(PerturbationAttribution):
6669
def __init__(
6770
self,
6871
forward_func: Callable,
69-
train_interpretable_model_func: Callable,
72+
interpretable_model: Model,
7073
similarity_func: Callable,
7174
perturb_func: Callable,
7275
perturb_interpretable_space: bool,
@@ -82,22 +85,25 @@ def __init__(
8285
modification of it. If a batch is provided as input for
8386
attribution, it is expected that forward_func returns a scalar
8487
representing the entire batch.
85-
train_interpretable_model_func (callable): Function which trains
86-
an interpretable model and returns some representation of the
87-
interpretable model. The return type of this will match the
88-
returned type when calling attribute.
89-
The expected signature of this callable is:
90-
91-
train_interpretable_model_func(
88+
interpretable_model (Model): Model object to train interpretable model.
89+
A Model object provides a `fit` method to train the model,
90+
given a dataloader, with batches containing three tensors:
9291
interpretable_inputs: Tensor
9392
[2D num_samples x num_interp_features],
9493
expected_outputs: Tensor [1D num_samples],
9594
weights: Tensor [1D num_samples]
96-
**kwargs: Any
97-
) -> Any (Representation of interpretable model)
9895
99-
All kwargs passed to the attribute method are
100-
provided as keyword arguments (kwargs) to this callable.
96+
The model object must also provide a `representation` method to
97+
access the appropriate coefficients or representation of the
98+
interpretable model after fitting.
99+
Some predefined interpretable linear models are provided in
100+
captum._utils.models.linear_model including wrappers around
101+
SkLearn linear models as well as SGD-based PyTorch linear
102+
models.
103+
104+
Note that calling fit multiple times should retrain the
105+
interpretable model, each attribution call reuses
106+
the same given interpretable model object.
101107
similarity_func (callable): Function which takes a single sample
102108
along with its corresponding interpretable representation
103109
and returns the weight of the interpretable sample for
@@ -204,7 +210,7 @@ def __init__(
204210
provided as keyword arguments (kwargs) to this callable.
205211
"""
206212
PerturbationAttribution.__init__(self, forward_func)
207-
self.train_interpretable_model_func = train_interpretable_model_func
213+
self.interpretable_model = interpretable_model
208214
self.similarity_func = similarity_func
209215
self.perturb_func = perturb_func
210216
self.perturb_interpretable_space = perturb_interpretable_space
@@ -230,7 +236,7 @@ def attribute(
230236
n_perturb_samples: int = 50,
231237
perturbations_per_eval: int = 1,
232238
**kwargs
233-
) -> TensorOrTupleOfTensorsGeneric:
239+
) -> Tensor:
234240
r"""
235241
This method attributes the output of the model with given target index
236242
(in case it is provided, otherwise it assumes that output is a
@@ -342,20 +348,12 @@ def attribute(
342348
>>> # score of the target class.
343349
>>>
344350
>>> # For interpretable model training, we will use sklearn
345-
>>> # in this example
346-
>>> from sklearn import linear_model
347-
>>>
348-
>>> # Define interpretable model training function
349-
>>> def linear_regression_interpretable_model_trainer(
350-
>>> interpretable_inputs: Tensor,
351-
>>> expected_outputs: Tensor,
352-
>>> weights: Tensor, **kwargs):
353-
>>> clf = linear_model.LinearRegression()
354-
>>> clf.fit(
355-
>>> interpretable_inputs.cpu().numpy(),
356-
>>> expected_outputs.cpu().numpy(),
357-
>>> weights.cpu().numpy())
358-
>>> return clf.coef_
351+
>>> # linear model in this example. We have provided wrappers
352+
>>> # around sklearn linear models to fit the Model interface.
353+
>>> # Any arguments provided to the sklearn constructor can also
354+
>>> # be provided to the wrapper, e.g.:
355+
>>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
356+
>>> from captum._utils.models.linear_model import SkLearnLinearModel
359357
>>>
360358
>>>
361359
>>> # Define similarity kernel (exponential kernel based on L2 norm)
@@ -387,7 +385,7 @@ def attribute(
387385
>>> input = torch.randn(2, 5)
388386
>>> # Defining LimeBase interpreter
389387
>>> lime_attr = LimeBase(net,
390-
linear_regression_interpretable_model_trainer,
388+
SkLearnLinearModel("linear_model.Ridge"),
391389
similarity_func=similarity_kernel,
392390
perturb_func=perturb_func,
393391
perturb_interpretable_space=False,
@@ -477,10 +475,13 @@ def attribute(
477475
if len(similarities[0].shape) > 0
478476
else torch.stack(similarities)
479477
)
480-
interp_model = self.train_interpretable_model_func(
481-
combined_interp_inps, combined_outputs, combined_sim, **kwargs
478+
dataset = TensorDataset(
479+
combined_interp_inps, combined_outputs, combined_sim
482480
)
483-
return interp_model
481+
self.interpretable_model.fit(
482+
DataLoader(dataset, batch_size=n_perturb_samples)
483+
)
484+
return self.interpretable_model.representation()
484485

485486
def _evaluate_batch(
486487
self,
@@ -516,32 +517,6 @@ def multiplies_by_inputs(self):
516517
# for Lime child implementation.
517518

518519

519-
def lasso_interpretable_model_trainer(
520-
interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
521-
):
522-
try:
523-
import sklearn
524-
from sklearn import linear_model
525-
526-
assert (
527-
sklearn.__version__ >= "0.23.0"
528-
), "Must have sklearn version 0.23.0 or higher to use "
529-
"sample_weight in Lasso regression."
530-
except ImportError:
531-
raise AssertionError(
532-
"Requires sklearn for default interpretable model training with"
533-
" Lasso regression. Please install sklearn or use a custom interpretable"
534-
" model training function."
535-
)
536-
clf = linear_model.Lasso(alpha=kwargs["alpha"] if "alpha" in kwargs else 1.0)
537-
clf.fit(
538-
interpretable_inputs.cpu().numpy(),
539-
expected_outputs.cpu().numpy(),
540-
weights.cpu().numpy(),
541-
)
542-
return torch.from_numpy(clf.coef_)
543-
544-
545520
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
546521
assert (
547522
"feature_mask" in kwargs
@@ -665,7 +640,7 @@ class Lime(LimeBase):
665640
def __init__(
666641
self,
667642
forward_func: Callable,
668-
train_interpretable_model_func: Callable = lasso_interpretable_model_trainer,
643+
train_interpretable_model_func: Model = SkLearnLasso(alpha=1.0),
669644
similarity_func: Callable = get_exp_kernel_similarity_function(),
670645
perturb_func: Callable = default_perturb_func,
671646
) -> None:
@@ -676,31 +651,31 @@ def __init__(
676651
677652
forward_func (callable): The forward function of the model or any
678653
modification of it
679-
train_interpretable_model_func (optional, callable): Function which
680-
trains an interpretable model and returns coefficients
681-
of the interpretable model.
682-
This function is optional, and the default function trains
683-
an interpretable model using Lasso regression, using the
684-
alpha parameter provided when calling attribute.
685-
Using the default function requires having sklearn version
686-
0.23.0 or higher installed.
687-
688-
If a custom function is provided, the expected signature of this
689-
callable is:
690-
691-
train_interpretable_model_func(
654+
interpretable_model (optional, Model): Model object to train
655+
interpretable model.
656+
657+
This argument is optional and defaults to SkLearnLasso(alpha=1.0),
658+
which is a wrapper around the Lasso linear model in SkLearn.
659+
This requires having sklearn version >= 0.23 available.
660+
661+
Other predefined interpretable linear models are provided in
662+
captum._utils.models.linear_model.
663+
664+
Alternatively, a custom model object must provide a `fit` method to
665+
train the model, given a dataloader, with batches containing
666+
three tensors:
692667
interpretable_inputs: Tensor
693668
[2D num_samples x num_interp_features],
694669
expected_outputs: Tensor [1D num_samples],
695670
weights: Tensor [1D num_samples]
696-
**kwargs: Any
697-
) -> Tensor [1D num_interp_features]
698-
The return type must be a 1D tensor containing the importance
699-
or attribution of each input feature.
700671
701-
kwargs includes baselines, feature_mask, num_interp_features
702-
(integer, determined from feature mask), and
703-
alpha (for Lasso regression).
672+
The model object must also provide a `representation` method to
673+
access the appropriate coefficients or representation of the
674+
interpretable model after fitting.
675+
676+
Note that calling fit multiple times should retrain the
677+
interpretable model, each attribution call reuses
678+
the same given interpretable model object.
704679
similarity_func (callable): Function which takes a single sample
705680
along with its corresponding interpretable representation
706681
and returns the weight of the interpretable sample for
@@ -764,7 +739,6 @@ def attribute( # type: ignore
764739
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
765740
n_perturb_samples: int = 25,
766741
perturbations_per_eval: int = 1,
767-
alpha: float = 1.0,
768742
return_input_shape: bool = True,
769743
) -> TensorOrTupleOfTensorsGeneric:
770744
r"""
@@ -1058,7 +1032,6 @@ def attribute( # type: ignore
10581032
if is_inputs_tuple
10591033
else curr_feature_mask[0],
10601034
num_interp_features=num_interp_features,
1061-
alpha=alpha,
10621035
)
10631036
if return_input_shape:
10641037
output_list.append(
@@ -1095,7 +1068,6 @@ def attribute( # type: ignore
10951068
baselines=baselines if is_inputs_tuple else baselines[0],
10961069
feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
10971070
num_interp_features=num_interp_features,
1098-
alpha=alpha,
10991071
)
11001072
if return_input_shape:
11011073
return self._convert_output_shape(
@@ -1138,6 +1110,7 @@ def _convert_output_shape(
11381110
num_interp_features: int,
11391111
is_inputs_tuple: bool,
11401112
) -> Union[Tensor, Tuple[Tensor, ...]]:
1113+
coefs = coefs.flatten()
11411114
attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp]
11421115
for tensor_ind in range(len(formatted_inp)):
11431116
for single_feature in range(num_interp_features):

sphinx/source/utilities.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,24 @@ Token Reference Base
2424

2525
.. autoclass:: captum.attr.TokenReferenceBase
2626
:members:
27+
28+
29+
Linear Models
30+
^^^^^^^^^^^^^^^^^^^^^
31+
32+
.. autoclass:: captum._utils.models.model.Model
33+
:members:
34+
.. autoclass:: captum._utils.models.linear_model.SkLearnLinearModel
35+
:members:
36+
.. autoclass:: captum._utils.models.linear_model.SkLearnLinearRegression
37+
:members:
38+
.. autoclass:: captum._utils.models.linear_model.SkLearnLasso
39+
:members:
40+
.. autoclass:: captum._utils.models.linear_model.SkLearnRidge
41+
:members:
42+
.. autoclass:: captum._utils.models.linear_model.SGDLinearModel
43+
:members:
44+
.. autoclass:: captum._utils.models.linear_model.SGDLasso
45+
:members:
46+
.. autoclass:: captum._utils.models.linear_model.SGDRidge
47+
:members:

tests/attr/test_kernel_shap.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def setUp(self) -> None:
2323
super().setUp()
2424
try:
2525
import sklearn # noqa: F401
26-
except ImportError:
27-
raise unittest.SkipTest(
28-
"Skipping Kernel Shap tests, sklearn not available."
29-
)
26+
27+
assert (
28+
sklearn.__version__ >= "0.23.0"
29+
), "Must have sklearn version 0.23.0 or higher"
30+
except (ImportError, AssertionError):
31+
raise unittest.SkipTest("Skipping KernelShap tests, sklearn not available.")
3032

3133
def test_linear_kernel_shap(self) -> None:
3234
net = BasicModel_MultiLayer()

0 commit comments

Comments
 (0)