@@ -88,10 +88,11 @@ def __init__(
8888 interpretable_model (Model): Model object to train interpretable model.
8989 A Model object provides a `fit` method to train the model,
9090 given a dataloader, with batches containing three tensors:
91- interpretable_inputs: Tensor
92- [2D num_samples x num_interp_features],
93- expected_outputs: Tensor [1D num_samples],
94- weights: Tensor [1D num_samples]
91+
92+ - interpretable_inputs: Tensor
93+ [2D num_samples x num_interp_features],
94+ - expected_outputs: Tensor [1D num_samples],
95+ - weights: Tensor [1D num_samples]
9596
9697 The model object must also provide a `representation` method to
9798 access the appropriate coefficients or representation of the
@@ -113,16 +114,16 @@ def __init__(
113114
114115 The expected signature of this callable is:
115116
116- similarity_func(
117- original_input: Tensor or tuple of Tensors,
118- perturbed_input: Tensor or tuple of Tensors,
119- perturbed_interpretable_input:
120- Tensor [2D 1 x num_interp_features],
121- **kwargs: Any
122- ) -> float or Tensor containing float scalar
117+ >>> similarity_func(
118+ >>> original_input: Tensor or tuple of Tensors,
119+ >>> perturbed_input: Tensor or tuple of Tensors,
120+ >>> perturbed_interpretable_input:
121+ >>> Tensor [2D 1 x num_interp_features],
122+ >>> **kwargs: Any
123+ >>> ) -> float or Tensor containing float scalar
123124
124125 perturbed_input and original_input will be the same type and
125- contain tensors of the same shape (regardless of whether
126+ contain tensors of the same shape (regardless of whether or not
126127 the sampling function returns inputs in the interpretable
127128 space). original_input is the same as the input provided
128129 when calling attribute.
@@ -139,10 +140,10 @@ def __init__(
139140
140141 The expected signature of this callable is:
141142
142- perturb_func(
143- original_input: Tensor or tuple of Tensors,
144- **kwargs: Any
145- ) -> Tensor or tuple of Tensors
143+ >>> perturb_func(
144+ >>> original_input: Tensor or tuple of Tensors,
145+ >>> **kwargs: Any
146+ >>> ) -> Tensor or tuple of Tensors
146147
147148 All kwargs passed to the attribute method are
148149 provided as keyword arguments (kwargs) to this callable.
@@ -175,11 +176,11 @@ def __init__(
175176
176177 The expected signature of this callable is:
177178
178- from_interp_rep_transform(
179- curr_sample: Tensor [2D 1 x num_interp_features]
180- original_input: Tensor or Tuple of Tensors,
181- **kwargs: Any
182- ) -> Tensor or tuple of Tensors
179+ >>> from_interp_rep_transform(
180+ >>> curr_sample: Tensor [2D 1 x num_interp_features]
181+ >>> original_input: Tensor or Tuple of Tensors,
182+ >>> **kwargs: Any
183+ >>> ) -> Tensor or tuple of Tensors
183184
184185 Returned sampled input should match the type of original_input
185186 and corresponding tensor shapes.
@@ -197,11 +198,11 @@ def __init__(
197198
198199 The expected signature of this callable is:
199200
200- to_interp_rep_transform(
201- curr_sample: Tensor or Tuple of Tensors,
202- original_input: Tensor or Tuple of Tensors,
203- **kwargs: Any
204- ) -> Tensor [2D 1 x num_interp_features]
201+ >>> to_interp_rep_transform(
202+ >>> curr_sample: Tensor or Tuple of Tensors,
203+ >>> original_input: Tensor or Tuple of Tensors,
204+ >>> **kwargs: Any
205+ >>> ) -> Tensor [2D 1 x num_interp_features]
205206
206207 curr_sample will match the type of original_input
207208 and corresponding tensor shapes.
@@ -640,9 +641,9 @@ class Lime(LimeBase):
640641 def __init__ (
641642 self ,
642643 forward_func : Callable ,
643- train_interpretable_model_func : Model = SkLearnLasso ( alpha = 1.0 ) ,
644- similarity_func : Callable = get_exp_kernel_similarity_function () ,
645- perturb_func : Callable = default_perturb_func ,
644+ interpretable_model : Optional [ Model ] = None ,
645+ similarity_func : Optional [ Callable ] = None ,
646+ perturb_func : Optional [ Callable ] = None ,
646647 ) -> None :
647648 r"""
648649
@@ -664,10 +665,11 @@ def __init__(
664665 Alternatively, a custom model object must provide a `fit` method to
665666 train the model, given a dataloader, with batches containing
666667 three tensors:
667- interpretable_inputs: Tensor
668- [2D num_samples x num_interp_features],
669- expected_outputs: Tensor [1D num_samples],
670- weights: Tensor [1D num_samples]
668+
669+ - interpretable_inputs: Tensor
670+ [2D num_samples x num_interp_features],
671+ - expected_outputs: Tensor [1D num_samples],
672+ - weights: Tensor [1D num_samples]
671673
672674 The model object must also provide a `representation` method to
673675 access the appropriate coefficients or representation of the
@@ -676,52 +678,72 @@ def __init__(
676678 Note that calling fit multiple times should retrain the
677679 interpretable model, each attribution call reuses
678680 the same given interpretable model object.
679- similarity_func (callable): Function which takes a single sample
681+ similarity_func (optional, callable): Function which takes a single sample
680682 along with its corresponding interpretable representation
681683 and returns the weight of the interpretable sample for
682684 training the interpretable model.
683685 This is often referred to as a similarity kernel.
684686
687+ This argument is optional and defaults to a function which
688+ applies an exponential kernel to the consine distance between
689+ the original input and perturbed input, with a kernel width
690+ of 1.0.
691+
692+ A similarity function applying an exponential
693+ kernel to cosine / euclidean distances can be constructed
694+ using the provided get_exp_kernel_similarity_function in
695+ captum.attr._core.lime.
696+
697+ Alternately, a custom callable can also be provided.
685698 The expected signature of this callable is:
686699
687- similarity_func(
688- original_input: Tensor or tuple of Tensors,
689- perturbed_input: Tensor or tuple of Tensors,
690- perturbed_interpretable_input:
691- Tensor [2D 1 x num_interp_features],
692- **kwargs: Any
693- ) -> float or Tensor containing float scalar
700+ >>> def similarity_func(
701+ >>> original_input: Tensor or tuple of Tensors,
702+ >>> perturbed_input: Tensor or tuple of Tensors,
703+ >>> perturbed_interpretable_input:
704+ >>> Tensor [2D 1 x num_interp_features],
705+ >>> **kwargs: Any
706+ >>> ) -> float or Tensor containing float scalar
694707
695708 perturbed_input and original_input will be the same type and
696709 contain tensors of the same shape, with original_input
697710 being the same as the input provided when calling attribute.
698711
699712 kwargs includes baselines, feature_mask, num_interp_features
700- (integer, determined from feature mask), and
701- alpha (for Lasso regression).
702- perturb_func (callable): Function which returns a single
713+ (integer, determined from feature mask).
714+ perturb_func (optional, callable): Function which returns a single
703715 sampled input, which is a binary vector of length
704- num_interp_features. The default function returns
716+ num_interp_features.
717+
718+ This function is optional, the default function returns
705719 a binary vector where each element is selected
706720 independently and uniformly at random. Custom
707721 logic for selecting sampled binary vectors can
708722 be implemented by providing a function with the
709723 following expected signature:
710724
711- perturb_func(
712- original_input: Tensor or tuple of Tensors,
713- **kwargs: Any
714- ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
725+ >>> perturb_func(
726+ >>> original_input: Tensor or tuple of Tensors,
727+ >>> **kwargs: Any
728+ >>> ) -> Tensor [Binary 2D Tensor 1 x num_interp_features]
715729
716730 kwargs includes baselines, feature_mask, num_interp_features
717- (integer, determined from feature mask), and
718- alpha (for Lasso regression).
731+ (integer, determined from feature mask).
719732
720733 """
734+ if interpretable_model is None :
735+ interpretable_model = SkLearnLasso (alpha = 1.0 )
736+
737+ if similarity_func is None :
738+ similarity_func = get_exp_kernel_similarity_function ()
739+
740+ if perturb_func is None :
741+ perturb_func = default_perturb_func
742+
721743 LimeBase .__init__ (
722744 self ,
723745 forward_func ,
724- train_interpretable_model_func ,
746+ interpretable_model ,
725747 similarity_func ,
726748 perturb_func ,
727749 True ,
@@ -788,27 +810,27 @@ def attribute( # type: ignore
788810 Baselines can be provided as:
789811
790812 - a single tensor, if inputs is a single tensor, with
791- exactly the same dimensions as inputs or the first
792- dimension is one and the remaining dimensions match
793- with inputs.
813+ exactly the same dimensions as inputs or the first
814+ dimension is one and the remaining dimensions match
815+ with inputs.
794816
795817 - a single scalar, if inputs is a single tensor, which will
796- be broadcasted for each input value in input tensor.
818+ be broadcasted for each input value in input tensor.
797819
798820 - a tuple of tensors or scalars, the baseline corresponding
799- to each tensor in the inputs' tuple can be:
821+ to each tensor in the inputs' tuple can be:
800822
801- - either a tensor with matching dimensions to
823+ - either a tensor with matching dimensions to
802824 corresponding tensor in the inputs' tuple
803825 or the first dimension is one and the remaining
804826 dimensions match with the corresponding
805827 input tensor.
806828
807- - or a scalar, corresponding to a tensor in the
829+ - or a scalar, corresponding to a tensor in the
808830 inputs' tuple. This scalar value is broadcasted
809831 for corresponding input tensor.
810832 In the cases when `baselines` is not provided, we internally
811- use zero corresponding to each input tensor.
833+ use zero scalar corresponding to each input tensor.
812834 Default: None
813835 target (int, tuple, tensor or list, optional): Output indices for
814836 which surrogate model is trained
@@ -886,10 +908,6 @@ def attribute( # type: ignore
886908 If the forward function returns a single scalar per batch,
887909 perturbations_per_eval must be set to 1.
888910 Default: 1
889- alpha (float, optional): Alpha used for training interpretable surrogate
890- model in Lasso Regression. This parameter is used only
891- if using default interpretable model trainer (Lasso).
892- Default: 1.0
893911 return_input_shape (bool, optional): Determines whether the returned
894912 tensor(s) only contain the coefficients for each interp-
895913 retable feature from the trained surrogate model, or
0 commit comments