
Commit b1a9830

aobo-y authored and facebook-github-bot committed
Add DataloaderAttribution (#1155)
Summary: Implement DataloaderAttribution, an attribution wrapper designed to wrap existing perturbation attribution methods so that they can take a `torch.utils.data.DataLoader` as the input. This enables attributions with respect to corpus-level metrics. `perturbation_per_pass` is not supported in this diff; it will be added separately in the next diff.

Pull Request resolved: #1155

Reviewed By: vivekmig

Differential Revision: D46322232

fbshipit-source-id: 8cb5beede80c8203f5d7f3c9edef2ca1c85a2547
1 parent 3aed726 commit b1a9830
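
For context, a minimal usage sketch of the new wrapper based on the API added in this diff; the forward function, dataset, and dataloader below are hypothetical stand-ins, and only FeatureAblation is supported as the wrapped method at this point:

import torch
from torch.utils.data import DataLoader, TensorDataset

from captum.attr import FeatureAblation
from captum.attr._core.dataloader_attr import DataloaderAttribution

# hypothetical forward function returning one score per example
def forward_func(x):
    return x.sum(dim=1)

dataset = TensorDataset(torch.rand(16, 4))
dataloader = DataLoader(dataset, batch_size=4)

# wrap an existing perturbation-based attribution method
dl_attr = DataloaderAttribution(FeatureAblation(forward_func))

# with the default reduce, per-batch outputs are concatenated and attributed;
# the result is a tuple matching the dataloader's role-0 returns
attributions = dl_attr.attribute(dataloader)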

File tree

2 files changed: +710 -0 lines changed


captum/attr/_core/dataloader_attr.py

Lines changed: 377 additions & 0 deletions
@@ -0,0 +1,377 @@
#!/usr/bin/env python3
from collections import defaultdict
from copy import copy
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _format_baseline,
    _format_feature_mask,
    _format_output,
    _format_tensor_into_tuples,
    _get_max_feature_index,
    _run_forward,
)
from captum._utils.typing import BaselineType
from captum.attr import FeatureAblation
from captum.attr._utils.attribution import Attribution
from torch import Tensor

class InputRole:
    need_attr = 0
    need_forward = 1
    no_forward = 2


SUPPORTED_METHODS = {FeatureAblation}


# default reducer when reduce is None. Simply concat the outputs by the batch dimension
def _concat_tensors(accum, cur_output, _):
    return cur_output if accum is None else torch.cat([accum, cur_output])


def _convert_output_shape(
    unique_attr: Tensor,
    attr_inputs: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
    # unique_attr in shape(*output_dims, n_features)
    output_dims = unique_attr.shape[:-1]
    n_features = unique_attr.shape[-1]

    attr = []

    for inp, mask in zip(attr_inputs, feature_mask):
        # input in shape(batch_size, *inp_feature_dims)
        # attribute in shape(*output_dims, *inp_feature_dims)
        attr_shape = (*output_dims, *inp.shape[1:])

        expanded_feature_indices = mask.expand(attr_shape)

        if len(inp.shape) > 2:
            # exclude batch_size & the last dim of actual values
            extra_inp_dims = list(inp.shape[1:-1])

            # unsqueeze unique_attr to have the same number of dims as inp
            # (*output_dims, 1..., 1, n_features)
            # then broadcast to (*output_dims, *inp.shape[1:-1], n_features)
            n_extra_dims = len(extra_inp_dims)
            unsqueezed_shape = (*output_dims, *(1,) * n_extra_dims, n_features)
            expanded_shape = (*output_dims, *extra_inp_dims, n_features)
            expanded_unique_attr = unique_attr.reshape(unsqueezed_shape).expand(
                expanded_shape
            )
        else:
            expanded_unique_attr = unique_attr

        # gather from (*output_dims, *inp.shape[1:-1], n_features)
        inp_attr = torch.gather(expanded_unique_attr, -1, expanded_feature_indices)
        attr.append(inp_attr)

    return tuple(attr)

class DataloaderAttribution(Attribution):
    r"""
    Decorate a perturbation-based attribution algorithm to make it work with
    dataloaders. The decorated instance will calculate attribution in the
    same way as configured in the original attribution instance, but it will provide a
    new "attribute" function which accepts a pytorch "dataloader" instance as the input
    instead of a single batched "tensor" and supports customizing a "reduce" function to
    determine how the forward return of each iteration of the dataloader should be
    aggregated to a single metric tensor to attribute. This would
    be especially useful to attribute against some corpus-wise metrics,
    e.g., Precision & Recall.
    """

    def __init__(self, attr_method: Attribution) -> None:
        r"""
        Args:
            attr_method (Attribution): An instance of any attribution algorithm
                        of type `Attribution`. E.g. Integrated Gradients,
                        Conductance or Saliency.
        """

        assert (
            type(attr_method) in SUPPORTED_METHODS
        ), f"DataloaderAttribution does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # shallow copy is enough to avoid modifying the original instance
        self.attr_method = copy(attr_method)

        self.attr_method.forward_func = self._forward_with_dataloader

    def _forward_with_dataloader(
        self,
        perturbed_feature_indices,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Tuple[int],
        baselines: Tuple[Union[int, float, Tensor], ...],
        feature_mask: Tuple[Tensor, ...],
        reduce: Callable,
        to_metric: Optional[Callable],
        perturbation_per_pass: int,
        show_progress: bool,
        feature_idx_to_mask_idx: Dict[int, List[int]],
    ):
        # a set of input/mask indices that need perturbation
        perturbation_mask_indices = set()
        for i, v in enumerate(perturbed_feature_indices[0].tolist()):
            # value 0 means the feature has been perturbed
            if not v:
                perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])

        # create binary masks for the inputs & set a mask to None if no perturbation is needed
        perturbation_mask = tuple(
            perturbed_feature_indices[0][mask_elem]
            if i in perturbation_mask_indices
            else None
            for i, mask_elem in enumerate(feature_mask)
        )

        accum = None
        for inputs in dataloader:
            perturbed_inputs = []
            attr_inp_count = 0

            for inp, role in zip(inputs, input_roles):
                if role != InputRole.need_attr:
                    perturbed_inputs.append(inp)
                    continue

                pert_mask = perturbation_mask[attr_inp_count]

                # no perturbation is needed for this input
                if pert_mask is None:
                    perturbed_inputs.append(inp)
                else:
                    baseline = baselines[attr_inp_count]

                    perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
                    perturbed_inputs.append(perturbed_inp)

                attr_inp_count += 1

            perturbed_inputs = tuple(perturbed_inputs)

            # due to explicitly defined roles
            # we can keep inputs in their original order regardless of whether they need attr
            # instead of using additional_forward_args to always append them at the end
            forward_inputs = tuple(
                _
                for _, role in zip(perturbed_inputs, input_roles)
                if role != InputRole.no_forward
            )

            output = _run_forward(
                self.forward_func,
                forward_inputs,
            )

            accum = reduce(accum, output, perturbed_inputs)

        if to_metric is not None:
            return to_metric(accum)

        return accum

    def attribute(
        self,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Optional[Tuple[int, ...]] = None,
        baselines: BaselineType = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        reduce: Optional[Callable] = None,
        to_metric: Optional[Callable] = None,
        perturbation_per_pass: int = -1,
        show_progress: bool = False,
        return_input_shape: bool = True,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:

            dataloader (torch.utils.data.DataLoader): the dataloader to attribute,
                        which should return a tuple of consistent size for every
                        iteration
            input_roles (tuple[int, ...], optional): a tuple of integers to define the
                        role of each element returned from the dataloader. It should
                        have the same size as the return of the dataloader.
                        The available roles are:

                        0: the element is passed to forward_func and needs attribution.
                        It must be a tensor.
                        1: the element is passed to forward_func but does not need
                        attribution, like additional_forward_args.
                        2: the element is excluded from forward_func. A typical example
                        is the label.

            baselines (Union[Tensor, tuple[Tensor, ...]], optional): same as the
                        baselines in attribute. The same baseline will be
                        applied to the entire dataloader. The first dimension is
                        assumed to be batch size and it must be 1. Baselines should only
                        be specified for the dataloader's returns that need
                        attribution (role = 0)

            feature_mask (Union[Tensor, tuple[Tensor, ...]], optional): same as the
                        feature_mask in attribute. The same feature_mask will be
                        applied to the entire dataloader. The first dimension is
                        assumed to be batch size and it must be 1. Masks should only
                        be specified for the dataloader's returns that need
                        attribution (role = 0)
            reduce (Callable, optional): a function to accumulate the forward output of
                        each iteration of the dataloader. The function signature is:
                        ``reduce(accum, current_output, current_inputs) -> accum``,
                        where:

                        accum (Any): accumulated state, can be any type
                        current_output (Tensor): current output tensor from forward_func
                        current_inputs (tuple[Any, ...]): current inputs from dataloader

            to_metric (Callable, optional): an optional function to further convert
                        the results accumulated through "reduce" after traversing the
                        whole dataloader into a single tensor of metrics to calculate
                        attribution against. The function signature is:
                        ``to_metric(accum) -> metric``, where:

                        accum (Any): accumulated state from the reduce function
                        metric (Tensor): final result to be attributed, must be a Tensor

                        If None, will directly attribute w.r.t. the reduced ``accum``
            perturbation_per_pass (int, optional): the number of perturbations to
                        execute concurrently in each traverse of the dataloader. The
                        number of traverses is ceil(n_perturbations / perturbation_per_pass).
                        The parameter offers a control of the trade-off between memory
                        and efficiency. If the dataloader involves slow operations like
                        remote requests or file I/O, multiple traversals can be
                        inefficient. On the other hand, each perturbation needs to store
                        its accumulated outputs of the reduce function until the end of
                        the data traverse. If the value is -1, all perturbations are
                        concurrent in a single traverse.
            return_input_shape (bool, optional): if True, returns the attribution
                        following the input shapes given by the dataloader.
                        Otherwise, returns a single tensor for the attributions of
                        all the features, where the last dimension
                        is the number of features.

        Returns:
            **attributions** :
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        Attribution with respect to each input feature.
                        If return_input_shape is True, attributions will be
                        the same size as the given dataloader's returns that need
                        attribution (role = 0), with each value
                        providing the attribution of the corresponding input index.
                        If a single tensor is provided as inputs, a single tensor is
                        returned. If a tuple is provided for inputs, a tuple of
                        corresponding sized tensors is returned.
                        If return_input_shape is False, a single tensor is returned
                        where each index of the last dimension represents a feature
        """
        inputs = next(iter(dataloader))
        is_inputs_tuple = True

        if type(inputs) is list:
            # support list as it is a common return type for dataloader in torch
            inputs = tuple(inputs)
        elif type(inputs) is not tuple:
            is_inputs_tuple = False
            inputs = _format_tensor_into_tuples(inputs)

        if input_roles:
            assert len(input_roles) == len(inputs), (
                "input_roles must have the same size as the return of the dataloader,",
                f"length of input_roles is {len(input_roles)} ",
                f"whereas the length of dataloader return is {len(inputs)}",
            )

            assert any(role == InputRole.need_attr for role in input_roles), (
                "input_roles must contain at least one element that needs attribution"
                f"({InputRole.need_attr}), received input_roles: {input_roles}"
            )
        else:
            # by default, assume every element in the dataloader needs attribution
            input_roles = tuple(InputRole.need_attr for _ in inputs)

        attr_inputs = tuple(
            inp for role, inp in zip(input_roles, inputs) if role == InputRole.need_attr
        )

        baselines = _format_baseline(baselines, attr_inputs)

        assert len(attr_inputs) == len(baselines), (
            "Baselines must have the same size as the return of the dataloader ",
            "that need attribution",
            f"length of baseline is {len(baselines)} ",
            f'whereas the length of dataloader return with role "0" is {len(inputs)}',
        )

        for i, baseline in enumerate(baselines):
            if isinstance(baseline, Tensor):
                assert baseline.size(0) == 1, (
                    "If the baseline is a tensor, "
                    "its 1st dim must be 1 so it can be broadcasted to "
                    "any batch of the dataloader: "
                    f"baselines[{i}].shape = {baseline.shape}"
                )

        feature_mask = _format_feature_mask(feature_mask, attr_inputs)

        assert len(attr_inputs) == len(feature_mask), (
            "Feature mask must have the same size as the return of the dataloader ",
            "that need attribution",
            f"length of feature_mask is {len(feature_mask)} ",
            f'whereas the length of dataloader return with role "0" is {len(inputs)}',
        )

        for i, each_mask in enumerate(feature_mask):
            assert each_mask.size(0) == 1, (
                "The 1st dim of feature_mask must be 1 so it can be broadcasted to "
                "any batch of the dataloader: "
                f"feature_mask[{i}].shape = {each_mask.shape}"
            )

        # map to retrieve the masks that contain a given feature index
        feature_idx_to_mask_idx = defaultdict(list)
        for i, mask in enumerate(feature_mask):
            unique_feature_indices = torch.unique(mask).tolist()
            for feature_idx in unique_feature_indices:
                feature_idx_to_mask_idx[feature_idx].append(i)

        max_feature_idx = _get_max_feature_index(feature_mask)
        n_features = max_feature_idx + 1

        if reduce is None:
            reduce = _concat_tensors

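        # Note: the wrapped perturbation method below never sees the real inputs.
        # It perturbs this (1, n_features) indicator of ones instead, and
        # _forward_with_dataloader maps every zeroed indicator back to the
        # corresponding masked input regions, replacing them with the baselines
        # while iterating over the dataloader.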
        # onehot tensor for feature indices
        feature_indices = torch.ones((1, n_features), device=attr_inputs[0].device)

        # unique_attr in shape(*output_dims, n_features)
        unique_attr = self.attr_method.attribute(
            feature_indices,
            additional_forward_args=(
                dataloader,
                input_roles,
                baselines,
                feature_mask,
                reduce,
                to_metric,
                perturbation_per_pass,
                show_progress,
                feature_idx_to_mask_idx,
            ),
        )

        if not return_input_shape:
            return unique_attr
        else:
            attr = _convert_output_shape(
                unique_attr,
                attr_inputs,
                feature_mask,
            )

            return _format_output(is_inputs_tuple, attr)
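
To make the reduce / to_metric / input_roles contract documented above concrete, here is a hedged sketch that attributes features against a corpus-level accuracy metric. The model, data, and metric are illustrative assumptions, not part of this diff:

import torch
from torch.utils.data import DataLoader, TensorDataset

from captum.attr import FeatureAblation
from captum.attr._core.dataloader_attr import DataloaderAttribution

# hypothetical binary classifier producing one logit per example
model = torch.nn.Linear(4, 1)

def forward_func(x):
    return model(x).squeeze(-1)

# each iteration yields (features, labels); labels are not fed to forward_func
dataset = TensorDataset(torch.rand(16, 4), torch.randint(0, 2, (16,)))
dataloader = DataLoader(dataset, batch_size=4)

def reduce(accum, cur_output, cur_inputs):
    # accumulate (n_correct, n_total) across the whole dataloader
    _, labels = cur_inputs
    n_correct, n_total = accum if accum is not None else (0, 0)
    n_correct = n_correct + ((cur_output > 0).long() == labels).sum()
    return n_correct, n_total + labels.numel()

def to_metric(accum):
    # convert the accumulated state into the single metric tensor to attribute
    n_correct, n_total = accum
    return (n_correct / n_total).reshape(1)

dl_attr = DataloaderAttribution(FeatureAblation(forward_func))

attributions = dl_attr.attribute(
    dataloader,
    # role 0: attribute the features; role 2: exclude the labels from forward_func
    input_roles=(0, 2),
    reduce=reduce,
    to_metric=to_metric,
)

With return_input_shape left at its default True, the result follows the role-0 input shapes (here a single (1, 4) tensor: one attribution per feature for the one corpus-level metric); passing return_input_shape=False would instead return the flat (n_metrics, n_features) tensor described in the docstring.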
