 import torch
 from torch import Tensor
 from torch.nn import CosineSimilarity
+from torch.utils.data import DataLoader, TensorDataset

 from captum._utils.common import (
     _expand_additional_forward_args,
     _reduce_list,
     _run_forward,
 )
+from captum._utils.models.linear_model import SkLearnLasso
+from captum._utils.models.model import Model
 from captum._utils.typing import (
     BaselineType,
     Literal,
@@ -66,7 +69,7 @@ class LimeBase(PerturbationAttribution):
     def __init__(
         self,
         forward_func: Callable,
-        train_interpretable_model_func: Callable,
+        interpretable_model: Model,
         similarity_func: Callable,
         perturb_func: Callable,
         perturb_interpretable_space: bool,
@@ -82,22 +85,25 @@ def __init__(
                     modification of it. If a batch is provided as input for
                     attribution, it is expected that forward_func returns a scalar
                     representing the entire batch.
-            train_interpretable_model_func (callable): Function which trains
-                    an interpretable model and returns some representation of the
-                    interpretable model. The return type of this will match the
-                    returned type when calling attribute.
-                    The expected signature of this callable is:
-
-                    train_interpretable_model_func(
+            interpretable_model (Model): Model object used to train the
+                    interpretable model. A Model object provides a `fit` method
+                    to train the model, given a dataloader whose batches contain
+                    three tensors:
                         interpretable_inputs: Tensor
                         [2D num_samples x num_interp_features],
                         expected_outputs: Tensor [1D num_samples],
                         weights: Tensor [1D num_samples]
-                        **kwargs: Any
-                    ) -> Any (Representation of interpretable model)

-                    All kwargs passed to the attribute method are
-                    provided as keyword arguments (kwargs) to this callable.
+                    The model object must also provide a `representation` method
+                    which returns the appropriate coefficients or representation
+                    of the interpretable model after fitting.
+                    Some predefined interpretable linear models are provided in
+                    captum._utils.models.linear_model, including wrappers around
+                    SkLearn linear models as well as SGD-based PyTorch linear
+                    models.
+
+                    Note that calling fit multiple times should retrain the
+                    interpretable model, since each attribution call reuses
+                    the same interpretable model object.
             similarity_func (callable): Function which takes a single sample
                     along with its corresponding interpretable representation
                     and returns the weight of the interpretable sample for
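
Note on the new contract: the `fit`/`representation` interface documented above is all that `LimeBase` relies on, so any object satisfying it can serve as the surrogate. Below is a minimal sketch of a custom implementation, assuming only the `Model` base class and the batch layout shown in this diff; the closed-form weighted least-squares solve and all names are illustrative, not part of this commit.

```python
import torch
from torch import Tensor
from torch.utils.data import DataLoader

from captum._utils.models.model import Model


class WeightedLeastSquares(Model):
    # Hypothetical surrogate: exact weighted least squares, no regularization.
    def __init__(self) -> None:
        self.coefs: Tensor = torch.tensor([])

    def fit(self, train_data: DataLoader):
        # Gather all batches; each batch holds (inputs, outputs, weights)
        # in the layout documented in the docstring above.
        xs, ys, ws = [], [], []
        for interpretable_inputs, expected_outputs, weights in train_data:
            xs.append(interpretable_inputs)
            ys.append(expected_outputs)
            ws.append(weights)
        x = torch.cat(xs).double()  # num_samples x num_interp_features
        y = torch.cat(ys).double()  # num_samples
        w = torch.cat(ws).double()  # num_samples
        # Solve (X^T W X) beta = X^T W y, refitting from scratch on every call,
        # as the docstring requires for repeated attribution calls.
        xtw = x.t() * w.unsqueeze(0)
        self.coefs = torch.linalg.solve(xtw @ x, xtw @ y.unsqueeze(1)).squeeze(1)
        return {}

    def representation(self) -> Tensor:
        # 1D tensor of learned coefficients, one per interpretable feature.
        return self.coefs
```
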
@@ -204,7 +210,7 @@ def __init__(
                     provided as keyword arguments (kwargs) to this callable.
         """
         PerturbationAttribution.__init__(self, forward_func)
-        self.train_interpretable_model_func = train_interpretable_model_func
+        self.interpretable_model = interpretable_model
         self.similarity_func = similarity_func
         self.perturb_func = perturb_func
         self.perturb_interpretable_space = perturb_interpretable_space
@@ -230,7 +236,7 @@ def attribute(
         n_perturb_samples: int = 50,
         perturbations_per_eval: int = 1,
         **kwargs
-    ) -> TensorOrTupleOfTensorsGeneric:
+    ) -> Tensor:
         r"""
         This method attributes the output of the model with given target index
         (in case it is provided, otherwise it assumes that output is a
@@ -342,20 +348,12 @@ def attribute(
             >>> # score of the target class.
             >>>
             >>> # For interpretable model training, we will use sklearn
-            >>> # in this example
-            >>> from sklearn import linear_model
-            >>>
-            >>> # Define interpretable model training function
-            >>> def linear_regression_interpretable_model_trainer(
-            >>>     interpretable_inputs: Tensor,
-            >>>     expected_outputs: Tensor,
-            >>>     weights: Tensor, **kwargs):
-            >>>     clf = linear_model.LinearRegression()
-            >>>     clf.fit(
-            >>>         interpretable_inputs.cpu().numpy(),
-            >>>         expected_outputs.cpu().numpy(),
-            >>>         weights.cpu().numpy())
-            >>>     return clf.coef_
+            >>> # linear model in this example. We have provided wrappers
+            >>> # around sklearn linear models to fit the Model interface.
+            >>> # Any arguments provided to the sklearn constructor can also
+            >>> # be provided to the wrapper, e.g.:
+            >>> # SkLearnLinearModel("linear_model.Ridge", alpha=2.0)
+            >>> from captum._utils.models.linear_model import SkLearnLinearModel
             >>>
             >>>
             >>> # Define similarity kernel (exponential kernel based on L2 norm)
@@ -387,7 +385,7 @@ def attribute(
             >>> input = torch.randn(2, 5)
             >>> # Defining LimeBase interpreter
             >>> lime_attr = LimeBase(net,
-                                     linear_regression_interpretable_model_trainer,
+                                     SkLearnLinearModel("linear_model.Ridge"),
                                      similarity_func=similarity_kernel,
                                      perturb_func=perturb_func,
                                      perturb_interpretable_space=False,
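
The hunk ends mid-call; the remaining constructor arguments are unchanged context. For orientation, a hedged sketch of driving the updated example, where `kernel_width` is simply whatever kwarg the similarity kernel consumes:

```python
>>> # Fit the interpretable model and return its representation; for the
>>> # SkLearn wrappers this is a tensor of linear-model coefficients.
>>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)
```
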
@@ -477,10 +475,13 @@ def attribute(
                 if len(similarities[0].shape) > 0
                 else torch.stack(similarities)
             )
-        interp_model = self.train_interpretable_model_func(
-            combined_interp_inps, combined_outputs, combined_sim, **kwargs
+        dataset = TensorDataset(
+            combined_interp_inps, combined_outputs, combined_sim
         )
-        return interp_model
+        self.interpretable_model.fit(
+            DataLoader(dataset, batch_size=n_perturb_samples)
+        )
+        return self.interpretable_model.representation()

     def _evaluate_batch(
         self,
@@ -516,32 +517,6 @@ def multiplies_by_inputs(self):
 # for Lime child implementation.


-def lasso_interpretable_model_trainer(
-    interpretable_inputs: Tensor, expected_outputs: Tensor, weights: Tensor, **kwargs
-):
-    try:
-        import sklearn
-        from sklearn import linear_model
-
-        assert (
-            sklearn.__version__ >= "0.23.0"
-        ), "Must have sklearn version 0.23.0 or higher to use "
-        "sample_weight in Lasso regression."
-    except ImportError:
-        raise AssertionError(
-            "Requires sklearn for default interpretable model training with"
-            " Lasso regression. Please install sklearn or use a custom interpretable"
-            " model training function."
-        )
-    clf = linear_model.Lasso(alpha=kwargs["alpha"] if "alpha" in kwargs else 1.0)
-    clf.fit(
-        interpretable_inputs.cpu().numpy(),
-        expected_outputs.cpu().numpy(),
-        weights.cpu().numpy(),
-    )
-    return torch.from_numpy(clf.coef_)
-
-
 def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
     assert (
         "feature_mask" in kwargs
@@ -665,7 +640,7 @@ class Lime(LimeBase):
     def __init__(
         self,
         forward_func: Callable,
-        train_interpretable_model_func: Callable = lasso_interpretable_model_trainer,
+        interpretable_model: Model = SkLearnLasso(alpha=1.0),
         similarity_func: Callable = get_exp_kernel_similarity_function(),
         perturb_func: Callable = default_perturb_func,
     ) -> None:
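
With the new default, constructing `Lime` needs no trainer function at all. A minimal sketch, where `net` stands for any PyTorch model:

```python
from captum.attr import Lime
from captum._utils.models.linear_model import SkLearnLasso

lime = Lime(net)  # uses the default SkLearnLasso(alpha=1.0) surrogate
# A sparser surrogate, swapping in a stronger Lasso penalty:
lime_sparse = Lime(net, interpretable_model=SkLearnLasso(alpha=3.0))
```

Choosing `alpha` now happens at construction time rather than through the `alpha` kwarg of `attribute`, which this commit removes below.
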
@@ -676,31 +651,31 @@ def __init__(

             forward_func (callable): The forward function of the model or any
                     modification of it
-            train_interpretable_model_func (optional, callable): Function which
-                    trains an interpretable model and returns coefficients
-                    of the interpretable model.
-                    This function is optional, and the default function trains
-                    an interpretable model using Lasso regression, using the
-                    alpha parameter provided when calling attribute.
-                    Using the default function requires having sklearn version
-                    0.23.0 or higher installed.
-
-                    If a custom function is provided, the expected signature of this
-                    callable is:
-
-                    train_interpretable_model_func(
+            interpretable_model (optional, Model): Model object used to train
+                    the interpretable model.
+
+                    This argument is optional and defaults to SkLearnLasso(alpha=1.0),
+                    which is a wrapper around the Lasso linear model in SkLearn.
+                    This requires having sklearn version >= 0.23 available.
+
+                    Other predefined interpretable linear models are provided in
+                    captum._utils.models.linear_model.
+
+                    Alternatively, a custom model object must provide a `fit` method
+                    to train the model, given a dataloader whose batches contain
+                    three tensors:
                         interpretable_inputs: Tensor
                         [2D num_samples x num_interp_features],
                         expected_outputs: Tensor [1D num_samples],
                         weights: Tensor [1D num_samples]
-                        **kwargs: Any
-                    ) -> Tensor [1D num_interp_features]
-                    The return type must be a 1D tensor containing the importance
-                    or attribution of each input feature.

-                    kwargs includes baselines, feature_mask, num_interp_features
-                    (integer, determined from feature mask), and
-                    alpha (for Lasso regression).
+                    The model object must also provide a `representation` method
+                    which returns the appropriate coefficients or representation
+                    of the interpretable model after fitting.
+
+                    Note that calling fit multiple times should retrain the
+                    interpretable model, since each attribution call reuses
+                    the same interpretable model object.
             similarity_func (callable): Function which takes a single sample
                     along with its corresponding interpretable representation
                     and returns the weight of the interpretable sample for
@@ -764,7 +739,6 @@ def attribute( # type: ignore
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         n_perturb_samples: int = 25,
         perturbations_per_eval: int = 1,
-        alpha: float = 1.0,
         return_input_shape: bool = True,
     ) -> TensorOrTupleOfTensorsGeneric:
         r"""
@@ -1058,7 +1032,6 @@ def attribute( # type: ignore
                     if is_inputs_tuple
                     else curr_feature_mask[0],
                     num_interp_features=num_interp_features,
-                    alpha=alpha,
                 )
                 if return_input_shape:
                     output_list.append(
@@ -1095,7 +1068,6 @@ def attribute( # type: ignore
                 baselines=baselines if is_inputs_tuple else baselines[0],
                 feature_mask=feature_mask if is_inputs_tuple else feature_mask[0],
                 num_interp_features=num_interp_features,
-                alpha=alpha,
             )
             if return_input_shape:
                 return self._convert_output_shape(
@@ -1138,6 +1110,7 @@ def _convert_output_shape(
         num_interp_features: int,
         is_inputs_tuple: bool,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
+        coefs = coefs.flatten()
         attr = [torch.zeros_like(single_inp) for single_inp in formatted_inp]
         for tensor_ind in range(len(formatted_inp)):
             for single_feature in range(num_interp_features):