#!/usr/bin/env python3
from collections import defaultdict
from copy import copy
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _format_baseline,
    _format_feature_mask,
    _format_output,
    _format_tensor_into_tuples,
    _get_max_feature_index,
    _run_forward,
)
from captum._utils.typing import BaselineType
from captum.attr import FeatureAblation
from captum.attr._utils.attribution import Attribution
from torch import Tensor


class InputRole:
    # the element is passed to forward_func and needs attribution
    need_attr = 0
    # the element is passed to forward_func but does not need attribution
    need_forward = 1
    # the element is not passed to forward_func at all, e.g., labels
    no_forward = 2


SUPPORTED_METHODS = {FeatureAblation}
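# NOTE: FeatureAblation is currently the only method verified to work with
# this wrapper; other perturbation-based methods have not been validated.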


# default reducer when reduce is None: simply concatenate the outputs along
# the batch dimension
def _concat_tensors(accum, cur_output, _):
    return cur_output if accum is None else torch.cat([accum, cur_output])

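
# An alternative reducer, shown as an illustrative sketch only (not part of
# the library API): keep a running sum of the forward outputs, e.g. to
# attribute against a corpus-level total loss. `accum` may be any type, as
# long as the reducer (or `to_metric`) eventually yields a Tensor.
def _sum_tensors_example(accum, cur_output, _):
    batch_total = cur_output.sum().unsqueeze(0)
    return batch_total if accum is None else accum + batch_total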


def _convert_output_shape(
    unique_attr: Tensor,
    attr_inputs: Tuple[Tensor, ...],
    feature_mask: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
    # unique_attr in shape (*output_dims, n_features)
    output_dims = unique_attr.shape[:-1]
    n_features = unique_attr.shape[-1]

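    # illustrative walkthrough: with unique_attr of shape (1, 3) (one output
    # row, 3 unique feature ids) and an input of shape (batch=16, seq=5,
    # emb=8) whose mask maps positions to those ids, the attribution returned
    # below for that input has shape (1, 5, 8)
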
    attr = []

    for inp, mask in zip(attr_inputs, feature_mask):
        # input in shape (batch_size, *inp_feature_dims)
        # attribution in shape (*output_dims, *inp_feature_dims)
        attr_shape = (*output_dims, *inp.shape[1:])

        expanded_feature_indices = mask.expand(attr_shape)

        if len(inp.shape) > 2:
            # exclude batch_size & the last dim of actual values
            extra_inp_dims = list(inp.shape[1:-1])

            # unsqueeze unique_attr to have the same number of dims as inp
            # (*output_dims, 1..., 1, n_features)
            # then broadcast to (*output_dims, *inp.shape[1:-1], n_features)
            n_extra_dims = len(extra_inp_dims)
            unsqueezed_shape = (*output_dims, *(1,) * n_extra_dims, n_features)
            expanded_shape = (*output_dims, *extra_inp_dims, n_features)
            expanded_unique_attr = unique_attr.reshape(unsqueezed_shape).expand(
                expanded_shape
            )
        else:
            expanded_unique_attr = unique_attr

        # gather from (*output_dims, *inp.shape[1:-1], n_features)
        inp_attr = torch.gather(expanded_unique_attr, -1, expanded_feature_indices)
        attr.append(inp_attr)

    return tuple(attr)


class DataloaderAttribution(Attribution):
    r"""
    Decorate a perturbation-based attribution algorithm to make it work with
    dataloaders. The decorated instance calculates attribution in the same way
    as configured in the original attribution instance, but it provides a new
    "attribute" function that accepts a PyTorch "dataloader" instance as input
    instead of a single batched "tensor", and supports customizing a "reduce"
    function to determine how the forward return of each iteration of the
    dataloader should be aggregated into a single metric tensor to attribute
    against. This is especially useful for attributing against corpus-wise
    metrics, e.g., Precision & Recall.
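
    Example (an illustrative sketch; ``forward_func`` and ``dataloader`` stand
    in for a user-provided model function and a dataloader yielding
    ``(inputs, labels)`` batches)::

        >>> fa = FeatureAblation(forward_func)
        >>> dl_attr = DataloaderAttribution(fa)
        >>> # role 0: attribute the inputs; role 2: exclude labels from forward
        >>> attributions = dl_attr.attribute(dataloader, input_roles=(0, 2))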
    """

    def __init__(self, attr_method: Attribution) -> None:
        r"""
        Args:
            attr_method (Attribution): An instance of a supported attribution
                        algorithm. Currently only ``FeatureAblation`` is
                        supported.
        """

        assert (
            type(attr_method) in SUPPORTED_METHODS
        ), f"DataloaderAttribution does not support {type(attr_method)}"

        super().__init__(attr_method.forward_func)

        # a shallow copy is enough to avoid modifying the original instance
        self.attr_method = copy(attr_method)

        self.attr_method.forward_func = self._forward_with_dataloader

    def _forward_with_dataloader(
        self,
        perturbed_feature_indices,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Tuple[int],
        baselines: Tuple[Union[int, float, Tensor], ...],
        feature_mask: Tuple[Tensor, ...],
        reduce: Callable,
        to_metric: Optional[Callable],
        perturbation_per_pass: int,
        show_progress: bool,
        feature_idx_to_mask_idx: Dict[int, List[int]],
    ):
        # the set of input/mask indices that need perturbation
        perturbation_mask_indices = set()
        for i, v in enumerate(perturbed_feature_indices[0].tolist()):
            # value 0 means the feature is perturbed (ablated) in this pass
            if not v:
                perturbation_mask_indices |= set(feature_idx_to_mask_idx[i])

        # create binary masks for the inputs; set an entry to None if the
        # corresponding input needs no perturbation
        perturbation_mask = tuple(
            perturbed_feature_indices[0][mask_elem]
            if i in perturbation_mask_indices
            else None
            for i, mask_elem in enumerate(feature_mask)
        )
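
        # e.g. (illustrative) if feature id 2 is ablated in this pass, every
        # position of an input whose mask equals 2 receives flag 0 above and
        # is replaced by its baseline in the loop below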

        accum = None
        for inputs in dataloader:
            perturbed_inputs = []
            attr_inp_count = 0

            for inp, role in zip(inputs, input_roles):
                if role != InputRole.need_attr:
                    perturbed_inputs.append(inp)
                    continue

                pert_mask = perturbation_mask[attr_inp_count]

                # no perturbation is needed for this input
                if pert_mask is None:
                    perturbed_inputs.append(inp)
                else:
                    baseline = baselines[attr_inp_count]

                    perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
                    perturbed_inputs.append(perturbed_inp)

                attr_inp_count += 1

            perturbed_inputs = tuple(perturbed_inputs)

            # because the roles are explicitly defined, we can keep the inputs
            # in their original order regardless of whether they need
            # attribution, instead of using additional_forward_args to always
            # append them at the end
            forward_inputs = tuple(
                _
                for _, role in zip(perturbed_inputs, input_roles)
                if role != InputRole.no_forward
            )

            output = _run_forward(
                self.forward_func,
                forward_inputs,
            )

            accum = reduce(accum, output, perturbed_inputs)

        if to_metric is not None:
            return to_metric(accum)

        return accum

    def attribute(
        self,
        dataloader: torch.utils.data.DataLoader,
        input_roles: Optional[Tuple[int, ...]] = None,
        baselines: BaselineType = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        reduce: Optional[Callable] = None,
        to_metric: Optional[Callable] = None,
        perturbation_per_pass: int = -1,
        show_progress: bool = False,
        return_input_shape: bool = True,
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        r"""
        Args:

            dataloader (torch.utils.data.DataLoader): the dataloader to
                        attribute, which should return a tuple of consistent
                        size for every iteration
            input_roles (tuple[int, ...], optional): a tuple of integers to
                        define the role of each element returned from the
                        dataloader. It should have the same size as the return
                        of the dataloader. The available roles are:

                        0: the element is passed to forward_func and needs
                        attribution. It must be a tensor.
                        1: the element is passed to forward_func but does not
                        need attribution, like additional_forward_args
                        2: the element is excluded from forward_func. A
                        typical example is the label.

            baselines (Union[Tensor, tuple[Tensor, ...]], optional): same as
                        the baselines in attribute. The same baselines will be
                        applied to the entire dataloader. The first dimension
                        is assumed to be batch size and it must be 1.
                        Baselines should only be specified for the
                        dataloader's returns that need attribution (role = 0)

            feature_mask (Union[Tensor, tuple[Tensor, ...]], optional): same
                        as the feature_mask in attribute. The same
                        feature_mask will be applied to the entire dataloader.
                        The first dimension is assumed to be batch size and it
                        must be 1. Masks should only be specified for the
                        dataloader's returns that need attribution (role = 0)
            reduce (Callable, optional): a function to accumulate the forward
                        output of each iteration of the dataloader. The
                        function signature is:
                        ``reduce(accum, current_output, current_inputs) -> accum``,
                        where:

                        accum (Any): accumulated state, can be any type
                        current_output (Tensor): current output tensor from forward_func
                        current_inputs (tuple[Any, ...]): current inputs from dataloader

                        If None, the outputs are concatenated along the batch
                        dimension
            to_metric (Callable, optional): an optional function to convert
                        the results accumulated through "reduce", after
                        traversing the whole dataloader, into a single tensor
                        of metrics to calculate attribution against. The
                        function signature is:
                        ``to_metric(accum) -> metric``, where:

                        accum (Any): accumulated state from the reduce function
                        metric (Tensor): final result to be attributed, must be a Tensor

                        If None, will directly attribute w.r.t. the reduced ``accum``
            perturbation_per_pass (int, optional): the number of perturbations
                        to execute concurrently in each traverse of the
                        dataloader. The number of traverses is
                        ceil(n_perturbations / perturbation_per_pass).
                        This parameter offers control of the trade-off between
                        memory and efficiency. If the dataloader involves slow
                        operations like remote requests or file I/O, multiple
                        traversals can be inefficient. On the other hand, each
                        perturbation needs to store its accumulated outputs of
                        the reduce function until the end of the data
                        traverse. If the value is -1, all perturbations are
                        run concurrently in a single traverse.
            return_input_shape (bool, optional): if True, returns the
                        attribution following the input shapes given by the
                        dataloader. Otherwise, returns a single tensor for the
                        attributions of all the features, where the last
                        dimension is the number of features.

        Returns:
            **attributions** :
            - **attributions** (*Tensor* or *tuple[Tensor, ...]*):
                        Attribution with respect to each input feature.
                        If return_input_shape is True, attributions will be
                        the same size as the given dataloader's returns that
                        need attribution (role = 0), with each value providing
                        the attribution of the corresponding input index.
                        If a single tensor is provided as inputs, a single
                        tensor is returned. If a tuple is provided for inputs,
                        a tuple of corresponding sized tensors is returned.
                        If return_input_shape is False, a single tensor is
                        returned where each index of the last dimension
                        represents a feature.
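
        Example (an illustrative sketch; assumes ``forward_func`` returns a
        per-sample loss tensor and the dataloader yields ``(x, y)`` batches)::

            >>> def reduce_loss(accum, cur_output, cur_inputs):
            ...     # running sum of the loss over the whole dataloader
            ...     total = cur_output.sum().unsqueeze(0)
            ...     return total if accum is None else accum + total
            >>> dl_attr = DataloaderAttribution(FeatureAblation(forward_func))
            >>> attr = dl_attr.attribute(dataloader, reduce=reduce_loss)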
        """
        inputs = next(iter(dataloader))
        is_inputs_tuple = True

        if type(inputs) is list:
            # support list as it is a common return type for dataloaders in torch
            inputs = tuple(inputs)
        elif type(inputs) is not tuple:
            is_inputs_tuple = False
            inputs = _format_tensor_into_tuples(inputs)

        if input_roles:
            assert len(input_roles) == len(inputs), (
                "input_roles must have the same size as the return of the dataloader, "
                f"length of input_roles is {len(input_roles)} "
                f"whereas the length of dataloader return is {len(inputs)}"
            )

            assert any(role == InputRole.need_attr for role in input_roles), (
                "input_roles must contain at least one element that needs "
                f"attribution ({InputRole.need_attr}), "
                f"received input_roles: {input_roles}"
            )
        else:
            # by default, assume every element in the dataloader needs attribution
            input_roles = tuple(InputRole.need_attr for _ in inputs)

        attr_inputs = tuple(
            inp for role, inp in zip(input_roles, inputs) if role == InputRole.need_attr
        )

        baselines = _format_baseline(baselines, attr_inputs)

        assert len(attr_inputs) == len(baselines), (
            "Baselines must have the same size as the returns of the dataloader "
            "that need attribution, "
            f"length of baselines is {len(baselines)} "
            f'whereas the length of dataloader returns with role "0" is '
            f"{len(attr_inputs)}"
        )

        for i, baseline in enumerate(baselines):
            if isinstance(baseline, Tensor):
                assert baseline.size(0) == 1, (
                    "If the baseline is a tensor, "
                    "its first dimension must be 1 so it can be broadcast to "
                    "any batch of the dataloader: "
                    f"baselines[{i}].shape = {baseline.shape}"
                )

        feature_mask = _format_feature_mask(feature_mask, attr_inputs)

        assert len(attr_inputs) == len(feature_mask), (
            "Feature mask must have the same size as the returns of the dataloader "
            "that need attribution, "
            f"length of feature_mask is {len(feature_mask)} "
            f'whereas the length of dataloader returns with role "0" is '
            f"{len(attr_inputs)}"
        )

        for i, each_mask in enumerate(feature_mask):
            assert each_mask.size(0) == 1, (
                "The first dimension of feature_mask must be 1 so it can be "
                "broadcast to any batch of the dataloader: "
                f"feature_mask[{i}].shape = {each_mask.shape}"
            )

        # map each feature index to the indices of the masks that contain it
        feature_idx_to_mask_idx = defaultdict(list)
        for i, mask in enumerate(feature_mask):
            unique_feature_indices = torch.unique(mask).tolist()
            for feature_idx in unique_feature_indices:
                feature_idx_to_mask_idx[feature_idx].append(i)

        max_feature_idx = _get_max_feature_index(feature_mask)
        n_features = max_feature_idx + 1

        if reduce is None:
            reduce = _concat_tensors

        # an all-ones tensor of feature flags; the wrapped FeatureAblation will
        # zero out entries (using its default baseline 0) to mark which
        # features are perturbed in each pass
        feature_indices = torch.ones((1, n_features), device=attr_inputs[0].device)

        # unique_attr in shape (*output_dims, n_features)
        unique_attr = self.attr_method.attribute(
            feature_indices,
            additional_forward_args=(
                dataloader,
                input_roles,
                baselines,
                feature_mask,
                reduce,
                to_metric,
                perturbation_per_pass,
                show_progress,
                feature_idx_to_mask_idx,
            ),
        )

        if not return_input_shape:
            return unique_attr
        else:
            attr = _convert_output_shape(
                unique_attr,
                attr_inputs,
                feature_mask,
            )

            return _format_output(is_inputs_tuple, attr)