
Commit 508fdc8

Sharon Tan authored and facebook-github-bot committed
Add support for list of tensors output in Layer Gradient Attributor (#1629)
Summary:
Pull Request resolved: #1629

As part of increasing the percentage of layers supported, we want to extend Layer Gradient Attributor to support list-of-Tensor outputs with grads. This is fairly safe, since we already support tuple-of-Tensor outputs with grads, and iteration across both types behaves the same.

Reviewed By: craymichael

Differential Revision: D79134965
1 parent 3ec6da4 commit 508fdc8
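
To make the intent concrete, here is a hedged sketch of the kind of layer this commit targets: a module whose forward returns a list of Tensors rather than a tuple or a single Tensor. The model, the SplitHeads layer, and the expectation about attribution shapes are illustrative assumptions, not code from this commit; only LayerGradientXActivation is an actual Captum API.

# Hypothetical illustration of the behavior this commit enables: attributing
# to a layer whose forward returns a *list* of Tensors (previously only
# tuple / single-Tensor outputs were handled by the gradient attributors).
import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation


class SplitHeads(nn.Module):
    """Toy layer that returns its two projections as a list of Tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.head_a = nn.Linear(4, 2)
        self.head_b = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> list:
        return [self.head_a(x), self.head_b(x)]  # a list, not a tuple


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(3, 4)
        self.heads = SplitHeads()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat(self.heads(self.embed(x)), dim=1).sum(dim=1)


model = Model()
attr = LayerGradientXActivation(model, model.heads)
# With this commit, a list-output layer should mirror the existing
# tuple-output path: one attribution Tensor per list element.
attributions = attr.attribute(torch.randn(8, 3))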

File tree

3 files changed: +24 −15 lines changed


captum/_utils/common.py

Lines changed: 5 additions & 5 deletions
@@ -759,10 +759,10 @@ def _reduce_list(
     Applies reduction function to given list. If each element in the list is
     a Tensor, applies reduction function to all elements of the list, and returns
     the output Tensor / value. If each element is a boolean, apply any method (or).
-    If each element is a tuple, applies reduction
-    function to corresponding elements of each tuple in the list, and returns
+    If each element is a tuple/list, applies reduction
+    function to corresponding elements of each tuple/list in the list, and returns
     tuple of reduction function outputs with length matching the length of tuple
-    val_list[0]. It is assumed that all tuples in the list have the same length
+    val_list[0]. It is assumed that all tuples/lists in the list have the same length
     and red_func can be applied to all elements in each corresponding position.
     """
     assert len(val_list) > 0, "Cannot reduce empty list!"
@@ -774,7 +774,7 @@ def _reduce_list(
     elif isinstance(val_list[0], bool):
         # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
         return any(val_list)
-    elif isinstance(val_list[0], tuple):
+    elif isinstance(val_list[0], (tuple, list)):
         final_out = []
         # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
         #  but got `TupleOrTensorOrBoolGeneric`.
@@ -786,7 +786,7 @@ def _reduce_list(
     else:
         raise AssertionError(
             "Elements to be reduced can only be"
-            "either Tensors or tuples containing Tensors."
+            "either Tensors or tuples/lists containing Tensors."
         )
     # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
     return tuple(final_out)
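
A minimal sketch of the extended `_reduce_list` contract, assuming the helper is called with a list of same-structured containers as its docstring describes (it is a private helper, so the exact signature may change). Note from the diff that even for list inputs the result is still a tuple, since the function ends with `return tuple(final_out)`.

import torch
from captum._utils.common import _reduce_list

a = [torch.ones(2, 3), torch.zeros(2, 3)]   # element 1: a list of two Tensors
b = [torch.ones(5, 3), torch.zeros(5, 3)]   # element 2: same structure

# red_func is applied position-wise across the containers:
# position 0 -> cat(ones(2,3), ones(5,3)); position 1 -> cat of the zeros.
reduced = _reduce_list([a, b], red_func=torch.cat)
assert isinstance(reduced, tuple) and reduced[0].shape == (7, 3)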

captum/_utils/gradient.py

Lines changed: 14 additions & 9 deletions
@@ -46,8 +46,8 @@ def apply_gradient_requirements(
     a tensor originally required grad is returned.
     """
     assert isinstance(
-        inputs, tuple
-    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
+        inputs, (tuple, list)
+    ), "Inputs should be wrapped in a tuple or list prior to preparing for gradients"
     grad_required = []
     for index, input in enumerate(inputs):
         assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
@@ -298,9 +298,9 @@ def hook_wrapper(original_module):
         # pyre-fixme[2]: Parameter must be annotated.
         def forward_hook(module, inp, out=None):
             eval_tsrs = inp if attribute_to_layer_input else out
-            is_eval_tuple = isinstance(eval_tsrs, tuple)
+            is_eval_tuple_or_list = isinstance(eval_tsrs, (tuple, list))

-            if not is_eval_tuple:
+            if not is_eval_tuple_or_list:
                 eval_tsrs = (eval_tsrs,)
             if require_layer_grads:
                 apply_gradient_requirements(eval_tsrs, warn=False)
@@ -310,11 +310,16 @@ def forward_hook(module, inp, out=None):
             # otherwise `backward()` on the last output layer won't execute.
             if forward_hook_with_return:
                 saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
-                eval_tsrs_to_return = tuple(
-                    eval_tsr.clone() for eval_tsr in eval_tsrs
-                )
-                if not is_eval_tuple:
-                    eval_tsrs_to_return = eval_tsrs_to_return[0]
+                if not is_eval_tuple_or_list:
+                    eval_tsrs_to_return = eval_tsrs[0].clone()
+                elif isinstance(eval_tsrs, list):
+                    eval_tsrs_to_return = [
+                        eval_tsr.clone() for eval_tsr in eval_tsrs
+                    ]
+                else:
+                    eval_tsrs_to_return = tuple(
+                        eval_tsr.clone() for eval_tsr in eval_tsrs
+                    )
                 return eval_tsrs_to_return
             else:
                 saved_layer[original_module][eval_tsrs[0].device] = tuple(
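
The branching added above exists because a forward hook that returns a value replaces the layer's output, so the returned clone must keep the original container type (single Tensor, list, or tuple) or downstream modules would receive a different type than they produced. A standalone sketch of that logic, with illustrative names that are not Captum's API:

from typing import List, Tuple, Union

import torch

TensorOrContainer = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]


def clone_preserving_type(out: TensorOrContainer) -> TensorOrContainer:
    if not isinstance(out, (tuple, list)):
        return out.clone()                  # plain Tensor in, plain Tensor out
    if isinstance(out, list):
        return [t.clone() for t in out]     # list in, list out
    return tuple(t.clone() for t in out)    # tuple in, tuple out


assert isinstance(clone_preserving_type([torch.ones(2)]), list)
assert isinstance(clone_preserving_type((torch.ones(2),)), tuple)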

captum/testing/helpers/basic_models.py

Lines changed: 5 additions & 1 deletion
@@ -432,6 +432,8 @@ def __init__(
         self.linear3.weight = nn.Parameter(torch.ones(2, 4))
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))

+        self.list_output_layer = PassThroughLayerOutput()
+
         self.int_layer = PassThroughLayerOutput()  # sample layer with an int output

     @no_type_check
@@ -452,11 +454,13 @@ def forward(

         relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)
+        list_out = self.list_output_layer([nn.Linear(2, 2)(lin2_out) for _ in range(2)])
+        resized_list_out = torch.cat(list_out, dim=1)

         lin3_out = self.linear3(lin1_out_alt)
         int_output = self.int_layer(lin3_out.to(torch.int64))

-        output_tensors = torch.cat((lin2_out, int_output), dim=1)
+        output_tensors = torch.cat((resized_list_out, int_output), dim=1)

         return (
             output_tensors
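
For context, the test model above gains a genuine list-of-Tensors layer output by routing a list through a pass-through layer, then concatenating it back into a single Tensor so the rest of the forward pass is unchanged. The commit does not show `PassThroughLayerOutput` itself; the identity-module definition below is a presumed sketch based on its name and usage.

import torch
import torch.nn as nn


class PassThroughLayerOutput(nn.Module):
    """Presumed identity layer: returns whatever container it receives."""

    def forward(self, x):
        return x


layer = PassThroughLayerOutput()
list_out = layer([torch.ones(1, 2), torch.zeros(1, 2)])  # list of Tensors
resized = torch.cat(list_out, dim=1)                     # shape (1, 4)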
