
Commit e0402b8

vivekmig authored and facebook-github-bot committed
Moving Docstrings (#309)
Summary: This moves algorithm descriptions from attribute method to the class in order to make API docs more intuitive.

Pull Request resolved: #309
Reviewed By: NarineK
Differential Revision: D20196721
Pulled By: vivekmig
fbshipit-source-id: 5edb26089327754e2eb2dc2c6278dab4758ad107
1 parent 358750e commit e0402b8

25 files changed (+1603, -1546 lines)

captum/attr/_core/deep_lift.py

Lines changed: 47 additions & 43 deletions
@@ -68,6 +68,37 @@ def is_output_cloned(output_fn, input_grad_fn) -> bool:


 class DeepLift(GradientAttribution):
+    r"""
+    Implements DeepLIFT algorithm based on the following paper:
+    Learning Important Features Through Propagating Activation Differences,
+    Avanti Shrikumar, et. al.
+    https://arxiv.org/abs/1704.02685
+
+    and the gradient formulation proposed in:
+    Towards better understanding of gradient-based attribution methods for
+    deep neural networks, Marco Ancona, et.al.
+    https://openreview.net/pdf?id=Sy21R9JAW
+
+    This implementation supports only Rescale rule. RevealCancel rule will
+    be supported in later releases.
+    In addition to that, in order to keep the implementation cleaner, DeepLIFT
+    for internal neurons and layers extends current implementation and is
+    implemented separately in LayerDeepLift and NeuronDeepLift.
+    Although DeepLIFT's(Rescale Rule) attribution quality is comparable with
+    Integrated Gradients, it runs significantly faster than Integrated
+    Gradients and is preferred for large datasets.
+
+    Currently we only support a limited number of non-linear activations
+    but the plan is to expand the list in the future.
+
+    Note: As we know, currently we cannot access the building blocks,
+    of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid.
+    Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs
+    with performance similar to built-in ones using TorchScript.
+    More details on how to build custom RNNs can be found here:
+    https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
+    """
+
     def __init__(self, model: Module) -> None:
         r"""
         Args:
@@ -116,35 +147,6 @@ def attribute( # type: ignore
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
     ]:
         r""""
-        Implements DeepLIFT algorithm based on the following paper:
-        Learning Important Features Through Propagating Activation Differences,
-        Avanti Shrikumar, et. al.
-        https://arxiv.org/abs/1704.02685
-
-        and the gradient formulation proposed in:
-        Towards better understanding of gradient-based attribution methods for
-        deep neural networks, Marco Ancona, et.al.
-        https://openreview.net/pdf?id=Sy21R9JAW
-
-        This implementation supports only Rescale rule. RevealCancel rule will
-        be supported in later releases.
-        In addition to that, in order to keep the implementation cleaner, DeepLIFT
-        for internal neurons and layers extends current implementation and is
-        implemented separately in LayerDeepLift and NeuronDeepLift.
-        Although DeepLIFT's(Rescale Rule) attribution quality is comparable with
-        Integrated Gradients, it runs significantly faster than Integrated
-        Gradients and is preferred for large datasets.
-
-        Currently we only support a limited number of non-linear activations
-        but the plan is to expand the list in the future.
-
-        Note: As we know, currently we cannot access the building blocks,
-        of PyTorch's built-in LSTM, RNNs and GRUs such as Tanh and Sigmoid.
-        Nonetheless, it is possible to build custom LSTMs, RNNS and GRUs
-        with performance similar to built-in ones using TorchScript.
-        More details on how to build custom RNNs can be found here:
-        https://pytorch.org/blog/optimizing-cuda-rnn-with-torchscript/
-
         Args:

             inputs (tensor or tuple of tensors): Input for which
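
For readers of the moved DeepLift docstring above, here is a brief usage sketch (an editorial addition, not part of the diff). The two-layer network, input sizes, and target index are illustrative assumptions; DeepLift, attribute(), and return_convergence_delta are the captum.attr API the docstring documents.

import torch
import torch.nn as nn
from captum.attr import DeepLift

# Assumed stand-in model; any module built from supported non-linearities
# such as nn.ReLU works the same way.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3))
model.eval()

inputs = torch.randn(4, 16, requires_grad=True)
baselines = torch.zeros(4, 16)  # zeros are only one possible reference choice

dl = DeepLift(model)
# Rescale-rule attributions for target class 1; shape matches `inputs` (4 x 16).
attributions, delta = dl.attribute(
    inputs, baselines=baselines, target=1, return_convergence_delta=True
)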
@@ -520,6 +522,22 @@ def has_convergence_delta(self) -> bool:


 class DeepLiftShap(DeepLift):
+    r"""
+    Extends DeepLift algorithm and approximates SHAP values using Deeplift.
+    For each input sample it computes DeepLift attribution with respect to
+    each baseline and averages resulting attributions.
+    More details about the algorithm can be found here:
+
+    http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
+
+    Note that the explanation model:
+        1. Assumes that input features are independent of one another
+        2. Is linear, meaning that the explanations are modeled through
+           the additive composition of feature effects.
+    Although, it assumes a linear model for each explanation, the overall
+    model across multiple explanations can be complex and non-linear.
+    """
+
     def __init__(self, model: Module) -> None:
         r"""
         Args:
@@ -573,20 +591,6 @@ def attribute( # type: ignore
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
     ]:
         r"""
-        Extends DeepLift algorithm and approximates SHAP values using Deeplift.
-        For each input sample it computes DeepLift attribution with respect to
-        each baseline and averages resulting attributions.
-        More details about the algorithm can be found here:
-
-        http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
-
-        Note that the explanation model:
-            1. Assumes that input features are independent of one another
-            2. Is linear, meaning that the explanations are modeled through
-               the additive composition of feature effects.
-        Although, it assumes a linear model for each explanation, the overall
-        model across multiple explanations can be complex and non-linear.
-
         Args:

             inputs (tensor or tuple of tensors): Input for which
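
A similarly hedged sketch for DeepLiftShap as described in its new class docstring: a distribution of baselines is supplied and DeepLift attributions are averaged over it. The model and the random baseline distribution below are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
from captum.attr import DeepLiftShap

# Assumed stand-in model.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3))
model.eval()

inputs = torch.randn(4, 16)
# A distribution of baselines (more than one example) so SHAP values can be
# approximated by averaging DeepLift attributions over the baselines.
baseline_dist = torch.randn(20, 16)

dl_shap = DeepLiftShap(model)
attributions = dl_shap.attribute(inputs, baselines=baseline_dist, target=0)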

captum/attr/_core/feature_ablation.py

Lines changed: 19 additions & 17 deletions
@@ -22,6 +22,25 @@


 class FeatureAblation(PerturbationAttribution):
+    r"""
+    A perturbation based approach to computing attribution, involving
+    replacing each input feature with a given baseline / reference, and
+    computing the difference in output. By default, each scalar value within
+    each input tensor is taken as a feature and replaced independently. Passing
+    a feature mask, allows grouping features to be ablated together. This can
+    be used in cases such as images, where an entire segment or region
+    can be ablated, measuring the importance of the segment (feature group).
+    Each input scalar in the group will be given the same attribution value
+    equal to the change in target as a result of ablating the entire feature
+    group.
+
+    The forward function can either return a scalar per example, or a single
+    scalar for the full batch. If a single scalar is returned for the batch,
+    `perturbations_per_eval` must be 1, and the returned attributions will have
+    first dimension 1, corresponding to feature importance across all
+    examples in the batch.
+    """
+
     def __init__(self, forward_func: Callable) -> None:
         r"""
         Args:
@@ -43,23 +62,6 @@ def attribute(
         **kwargs: Any
     ) -> TensorOrTupleOfTensorsGeneric:
         r""""
-        A perturbation based approach to computing attribution, involving
-        replacing each input feature with a given baseline / reference, and
-        computing the difference in output. By default, each scalar value within
-        each input tensor is taken as a feature and replaced independently. Passing
-        a feature mask, allows grouping features to be ablated together. This can
-        be used in cases such as images, where an entire segment or region
-        can be ablated, measuring the importance of the segment (feature group).
-        Each input scalar in the group will be given the same attribution value
-        equal to the change in target as a result of ablating the entire feature
-        group.
-
-        The forward function can either return a scalar per example, or a single
-        scalar for the full batch. If a single scalar is returned for the batch,
-        `perturbations_per_eval` must be 1, and the returned attributions will have
-        first dimension 1, corresponding to feature importance across all
-        examples in the batch.
-
         Args:

             inputs (tensor or tuple of tensors): Input for which ablation
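
To make the feature-mask grouping described in the docstring above concrete, here is a small hedged sketch (not part of the diff). The toy model and the 2x2 segment layout are assumptions; FeatureAblation and the feature_mask argument are the documented captum.attr API.

import torch
import torch.nn as nn
from captum.attr import FeatureAblation

# Assumed toy model over flattened 4x4 single-channel "images".
model = nn.Sequential(nn.Flatten(), nn.Linear(4 * 4, 3))
inputs = torch.randn(2, 1, 4, 4)

# Group the 4x4 image into four 2x2 segments; every scalar in a segment is
# ablated together and receives the same attribution value.
feature_mask = torch.tensor(
    [[0, 0, 1, 1],
     [0, 0, 1, 1],
     [2, 2, 3, 3],
     [2, 2, 3, 3]]
).reshape(1, 1, 4, 4)

ablator = FeatureAblation(model)
# Each segment is replaced by the scalar baseline 0 and the change in the
# target output (class index 2) is recorded as its attribution.
attributions = ablator.attribute(
    inputs, baselines=0, target=2, feature_mask=feature_mask
)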

captum/attr/_core/gradient_shap.py

Lines changed: 48 additions & 46 deletions
@@ -24,6 +24,41 @@


 class GradientShap(GradientAttribution):
+    r"""
+    Implements gradient SHAP based on the implementation from SHAP's primary
+    author. For reference, please, view:
+
+    https://github.com/slundberg/shap\
+    #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models
+
+    A Unified Approach to Interpreting Model Predictions
+    http://papers.nips.cc/paper\
+    7062-a-unified-approach-to-interpreting-model-predictions
+
+    GradientShap approximates SHAP values by computing the expectations of
+    gradients by randomly sampling from the distribution of baselines/references.
+    It adds white noise to each input sample `n_samples` times, selects a
+    random baseline from baselines' distribution and a random point along the
+    path between the baseline and the input, and computes the gradient of outputs
+    with respect to those selected random points. The final SHAP values represent
+    the expected values of gradients * (inputs - baselines).
+
+    GradientShap makes an assumption that the input features are independent
+    and that the explanation model is linear, meaning that the explanations
+    are modeled through the additive composition of feature effects.
+    Under those assumptions, SHAP value can be approximated as the expectation
+    of gradients that are computed for randomly generated `n_samples` input
+    samples after adding gaussian noise `n_samples` times to each input for
+    different baselines/references.
+
+    In some sense it can be viewed as an approximation of integrated gradients
+    by computing the expectations of gradients for different baselines.
+
+    Current implementation uses Smoothgrad from `NoiseTunnel` in order to
+    randomly draw samples from the distribution of baselines, add noise to input
+    samples and compute the expectation (smoothgrad).
+    """
+
     def __init__(self, forward_func: Callable) -> None:
         r"""
         Args:
@@ -79,39 +114,6 @@ def attribute(
         TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
     ]:
         r"""
-        Implements gradient SHAP based on the implementation from SHAP's primary
-        author. For reference, please, view:
-
-        https://github.com/slundberg/shap\
-        #deep-learning-example-with-gradientexplainer-tensorflowkeraspytorch-models
-
-        A Unified Approach to Interpreting Model Predictions
-        http://papers.nips.cc/paper\
-        7062-a-unified-approach-to-interpreting-model-predictions
-
-        GradientShap approximates SHAP values by computing the expectations of
-        gradients by randomly sampling from the distribution of baselines/references.
-        It adds white noise to each input sample `n_samples` times, selects a
-        random baseline from baselines' distribution and a random point along the
-        path between the baseline and the input, and computes the gradient of outputs
-        with respect to those selected random points. The final SHAP values represent
-        the expected values of gradients * (inputs - baselines).
-
-        GradientShap makes an assumption that the input features are independent
-        and that the explanation model is linear, meaning that the explanations
-        are modeled through the additive composition of feature effects.
-        Under those assumptions, SHAP value can be approximated as the expectation
-        of gradients that are computed for randomly generated `n_samples` input
-        samples after adding gaussian noise `n_samples` times to each input for
-        different baselines/references.
-
-        In some sense it can be viewed as an approximation of integrated gradients
-        by computing the expectations of gradients for different baselines.
-
-        Current implementation uses Smoothgrad from `NoiseTunnel` in order to
-        randomly draw samples from the distribution of baselines, add noise to input
-        samples and compute the expectation (smoothgrad).
-
         Args:

             inputs (tensor or tuple of tensors): Input for which SHAP attribution
@@ -222,19 +224,19 @@ def attribute(
                         The deltas are ordered by each input example and `n_samples`
                         noisy samples generated for it.

-            Examples::
-
-                >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
-                >>> # and returns an Nx10 tensor of class probabilities.
-                >>> net = ImageClassifier()
-                >>> gradient_shap = GradientShap(net)
-                >>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
-                >>> # choosing baselines randomly
-                >>> baselines = torch.randn(20, 3, 32, 32)
-                >>> # Computes gradient shap for the input
-                >>> # Attribution size matches input size: 3x3x32x32
-                >>> attribution = gradient_shap.attribute(input, baselines,
-                                                          target=5)
+        Examples::
+
+            >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
+            >>> # and returns an Nx10 tensor of class probabilities.
+            >>> net = ImageClassifier()
+            >>> gradient_shap = GradientShap(net)
+            >>> input = torch.randn(3, 3, 32, 32, requires_grad=True)
+            >>> # choosing baselines randomly
+            >>> baselines = torch.randn(20, 3, 32, 32)
+            >>> # Computes gradient shap for the input
+            >>> # Attribution size matches input size: 3x3x32x32
+            >>> attribution = gradient_shap.attribute(input, baselines,
+                                                      target=5)

         """
         # since `baselines` is a distribution, we can generate it using a function
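
As a companion to the doctest example retained in the diff above, the sketch below illustrates the convergence deltas mentioned in the docstring. It is an editorial addition: the small model and parameter values are assumptions, and `n_samples`, `stdevs`, and `return_convergence_delta` are taken to be the attribute() parameters the docstring refers to.

import torch
import torch.nn as nn
from captum.attr import GradientShap

# Assumed stand-in model with 10 outputs, mirroring the hypothetical
# ImageClassifier used in the docstring example above.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 10))
inputs = torch.randn(3, 16, requires_grad=True)
baselines = torch.randn(20, 16)  # a distribution of references, as above

gs = GradientShap(model)
attributions, delta = gs.attribute(
    inputs,
    baselines=baselines,
    n_samples=5,
    stdevs=0.09,
    target=5,
    return_convergence_delta=True,
)
# `delta` holds one value per (input example, noisy sample) pair: 3 * 5 = 15.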

captum/attr/_core/guided_backprop_deconvnet.py

Lines changed: 33 additions & 29 deletions
@@ -94,6 +94,21 @@ def _remove_hooks(self):


 class GuidedBackprop(ModifiedReluGradientAttribution):
+    r"""
+    Computes attribution using guided backpropagation. Guided backpropagation
+    computes the gradient of the target output with respect to the input,
+    but gradients of ReLU functions are overridden so that only
+    non-negative gradients are backpropagated.
+
+    More details regarding the guided backpropagation algorithm can be found
+    in the original paper here:
+    https://arxiv.org/abs/1412.6806
+
+    Warning: Ensure that all ReLU operations in the forward function of the
+    given model are performed using a module (nn.module.ReLU).
+    If nn.functional.ReLU is used, gradients are not overridden appropriately.
+    """
+
     def __init__(self, model: Module):
         r"""
         Args:
@@ -111,19 +126,6 @@ def attribute(
         additional_forward_args: Any = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r""""
-        Computes attribution using guided backpropagation. Guided backpropagation
-        computes the gradient of the target output with respect to the input,
-        but gradients of ReLU functions are overridden so that only
-        non-negative gradients are backpropagated.
-
-        More details regarding the guided backpropagation algorithm can be found
-        in the original paper here:
-        https://arxiv.org/abs/1412.6806
-
-        Warning: Ensure that all ReLU operations in the forward function of the
-        given model are performed using a module (nn.module.ReLU).
-        If nn.functional.ReLU is used, gradients are not overridden appropriately.
-
         Args:

             inputs (tensor or tuple of tensors): Input for which
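
A minimal hedged sketch of GuidedBackprop per the docstring above (editorial addition). The network uses an nn.ReLU module, as the warning requires; the model itself is only an illustrative assumption.

import torch
import torch.nn as nn
from captum.attr import GuidedBackprop

# Assumed stand-in model; ReLU is an nn.ReLU module, not nn.functional.relu,
# so its backward pass can be overridden as described above.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3))
inputs = torch.randn(4, 16, requires_grad=True)

gbp = GuidedBackprop(model)
# Gradients w.r.t. the input, with only non-negative ReLU gradients backpropagated.
attributions = gbp.attribute(inputs, target=0)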
@@ -197,6 +199,24 @@ def attribute(


 class Deconvolution(ModifiedReluGradientAttribution):
+    r"""
+    Computes attribution using deconvolution. Deconvolution
+    computes the gradient of the target output with respect to the input,
+    but gradients of ReLU functions are overridden so that the gradient
+    of the ReLU input is simply computed taking ReLU of the output gradient,
+    essentially only propagating non-negative gradients (without
+    dependence on the sign of the ReLU input).
+
+    More details regarding the deconvolution algorithm can be found
+    in these papers:
+    https://arxiv.org/abs/1311.2901
+    https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8
+
+    Warning: Ensure that all ReLU operations in the forward function of the
+    given model are performed using a module (nn.module.ReLU).
+    If nn.functional.ReLU is used, gradients are not overridden appropriately.
+    """
+
     def __init__(self, model: Module):
         r"""
         Args:
@@ -212,22 +232,6 @@ def attribute(
         additional_forward_args: Any = None,
     ) -> TensorOrTupleOfTensorsGeneric:
         r""""
-        Computes attribution using deconvolution. Deconvolution
-        computes the gradient of the target output with respect to the input,
-        but gradients of ReLU functions are overridden so that the gradient
-        of the ReLU input is simply computed taking ReLU of the output gradient,
-        essentially only propagating non-negative gradients (without
-        dependence on the sign of the ReLU input).
-
-        More details regarding the deconvolution algorithm can be found
-        in these papers:
-        https://arxiv.org/abs/1311.2901
-        https://link.springer.com/chapter/10.1007/978-3-319-46466-4_8
-
-        Warning: Ensure that all ReLU operations in the forward function of the
-        given model are performed using a module (nn.module.ReLU).
-        If nn.functional.ReLU is used, gradients are not overridden appropriately.
-
         Args:

             inputs (tensor or tuple of tensors): Input for which
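
And a matching hedged sketch for Deconvolution (editorial addition): the stand-in model is again an assumption and uses nn.ReLU modules as the warning requires.

import torch
import torch.nn as nn
from captum.attr import Deconvolution

# Assumed stand-in model with an nn.ReLU module.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 3))
inputs = torch.randn(4, 16, requires_grad=True)

deconv = Deconvolution(model)
# ReLU gradients are replaced by ReLU applied to the output gradient,
# per the class docstring above.
attributions = deconv.attribute(inputs, target=0)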
