Add support for list of tensors output in Layer Gradient Attributor #1629

Status: Open. Wants to merge 1 commit into base: master.
10 changes: 5 additions & 5 deletions captum/_utils/common.py
@@ -759,10 +759,10 @@ def _reduce_list(
     Applies reduction function to given list. If each element in the list is
     a Tensor, applies reduction function to all elements of the list, and returns
     the output Tensor / value. If each element is a boolean, apply any method (or).
-    If each element is a tuple, applies reduction
-    function to corresponding elements of each tuple in the list, and returns
+    If each element is a tuple/list, applies reduction
+    function to corresponding elements of each tuple/list in the list, and returns
     tuple of reduction function outputs with length matching the length of tuple
-    val_list[0]. It is assumed that all tuples in the list have the same length
+    val_list[0]. It is assumed that all tuples/lists in the list have the same length
     and red_func can be applied to all elements in each corresponding position.
     """
     assert len(val_list) > 0, "Cannot reduce empty list!"
@@ -774,7 +774,7 @@ def _reduce_list(
     elif isinstance(val_list[0], bool):
         # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `bool`.
         return any(val_list)
-    elif isinstance(val_list[0], tuple):
+    elif isinstance(val_list[0], (tuple, list)):
         final_out = []
         # pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
         # but got `TupleOrTensorOrBoolGeneric`.
@@ -786,7 +786,7 @@ def _reduce_list(
     else:
         raise AssertionError(
             "Elements to be reduced can only be"
-            "either Tensors or tuples containing Tensors."
+            "either Tensors or tuples/lists containing Tensors."
         )
     # pyre-fixme[7]: Expected `TupleOrTensorOrBoolGeneric` but got `Tuple[Any, ...]`.
     return tuple(final_out)
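For illustration, a minimal usage sketch of the extended _reduce_list (not part of the diff; the chunk tensors below are made up): each element of val_list is a list of tensors, and red_func (torch.cat by default) is applied position-wise, so the reduced outputs come back as a tuple of concatenated tensors.

import torch
from captum._utils.common import _reduce_list

# Two chunks of layer outputs; each chunk is a list of tensors with matching positions.
chunk_a = [torch.ones(2, 3), torch.zeros(2, 4)]
chunk_b = [torch.ones(5, 3), torch.zeros(5, 4)]

# With the tuple/list branch above, torch.cat is applied per position.
reduced = _reduce_list([chunk_a, chunk_b])
print([t.shape for t in reduced])  # [torch.Size([7, 3]), torch.Size([7, 4])]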
23 changes: 14 additions & 9 deletions captum/_utils/gradient.py
@@ -46,8 +46,8 @@ def apply_gradient_requirements(
     a tensor originally required grad is returned.
     """
     assert isinstance(
-        inputs, tuple
-    ), "Inputs should be wrapped in a tuple prior to preparing for gradients"
+        inputs, (tuple, list)
+    ), "Inputs should be wrapped in a tuple or list prior to preparing for gradients"
     grad_required = []
     for index, input in enumerate(inputs):
         assert isinstance(input, torch.Tensor), "Given input is not a torch.Tensor"
@@ -298,9 +298,9 @@ def hook_wrapper(original_module):
         # pyre-fixme[2]: Parameter must be annotated.
         def forward_hook(module, inp, out=None):
             eval_tsrs = inp if attribute_to_layer_input else out
-            is_eval_tuple = isinstance(eval_tsrs, tuple)
+            is_eval_tuple_or_list = isinstance(eval_tsrs, (tuple, list))

-            if not is_eval_tuple:
+            if not is_eval_tuple_or_list:
                 eval_tsrs = (eval_tsrs,)
             if require_layer_grads:
                 apply_gradient_requirements(eval_tsrs, warn=False)
@@ -310,11 +310,16 @@ def forward_hook(module, inp, out=None):
             # otherwise `backward()` on the last output layer won't execute.
             if forward_hook_with_return:
                 saved_layer[original_module][eval_tsrs[0].device] = eval_tsrs
-                eval_tsrs_to_return = tuple(
-                    eval_tsr.clone() for eval_tsr in eval_tsrs
-                )
-                if not is_eval_tuple:
-                    eval_tsrs_to_return = eval_tsrs_to_return[0]
+                if not is_eval_tuple_or_list:
+                    eval_tsrs_to_return = eval_tsrs[0].clone()
+                elif isinstance(eval_tsrs, list):
+                    eval_tsrs_to_return = [
+                        eval_tsr.clone() for eval_tsr in eval_tsrs
+                    ]
+                else:
+                    eval_tsrs_to_return = tuple(
+                        eval_tsr.clone() for eval_tsr in eval_tsrs
+                    )
                 return eval_tsrs_to_return
             else:
                 saved_layer[original_module][eval_tsrs[0].device] = tuple(
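As a standalone sketch of the return-value handling above (clone_preserving_container is a hypothetical helper, not Captum API): the hook clones every saved tensor but hands back the same container type the module produced, whether that is a single tensor, a tuple, or now a list.

import torch

def clone_preserving_container(eval_tsrs):
    # Mirrors the forward_hook logic: single tensor, list, or tuple in,
    # cloned tensors out in the same container type.
    if not isinstance(eval_tsrs, (tuple, list)):
        return eval_tsrs.clone()
    if isinstance(eval_tsrs, list):
        return [t.clone() for t in eval_tsrs]
    return tuple(t.clone() for t in eval_tsrs)

out = clone_preserving_container([torch.rand(2, 2), torch.rand(2, 3)])
print(type(out).__name__, [t.shape for t in out])  # list [torch.Size([2, 2]), torch.Size([2, 3])]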
17 changes: 16 additions & 1 deletion captum/testing/helpers/basic_models.py
@@ -432,6 +432,8 @@ def __init__(
         self.linear3.weight = nn.Parameter(torch.ones(2, 4))
         self.linear3.bias = nn.Parameter(torch.tensor([-1.0, 1.0]))

+        self.list_output_layer = ListOutputLayer(2, [2, 2])
+
         self.int_layer = PassThroughLayerOutput()  # sample layer with an int output

     @no_type_check
@@ -452,11 +454,13 @@ def forward(

         relu_out = self.relu(lin1_out)
         lin2_out = self.linear2(relu_out)
+        list_out = self.list_output_layer(lin2_out)
+        resized_list_out = torch.cat(list_out, dim=1)

         lin3_out = self.linear3(lin1_out_alt)
         int_output = self.int_layer(lin3_out.to(torch.int64))

-        output_tensors = torch.cat((lin2_out, int_output), dim=1)
+        output_tensors = torch.cat((resized_list_out, int_output), dim=1)

         return (
             output_tensors
@@ -470,6 +474,17 @@
     # where an output accessor is required


+class ListOutputLayer(nn.Module):
+    def __init__(self, input_size: int, output_sizes: List[int]) -> None:
+        super().__init__()
+        self.linears = nn.ModuleList(
+            [nn.Linear(input_size, size) for size in output_sizes]
+        )
+
+    def forward(self, x: Tensor) -> List[Tensor]:
+        return [linear(x) for linear in self.linears]
+
+
 class MultiRelu(nn.Module):
     def __init__(self, inplace: bool = False) -> None:
         super().__init__()
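A hypothetical end-to-end sketch of what the new test layer is meant to exercise (TinyListModel and the attribute call are illustrative only, assuming the list-output path added in this PR is supported by the layer gradient attributors):

import torch
import torch.nn as nn
from captum.attr import LayerGradientXActivation
from captum.testing.helpers.basic_models import ListOutputLayer

class TinyListModel(nn.Module):
    # Illustrative wrapper: a layer that returns a list of tensors, concatenated
    # into one output tensor so a scalar target can be selected as usual.
    def __init__(self) -> None:
        super().__init__()
        self.list_layer = ListOutputLayer(4, [2, 2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat(self.list_layer(x), dim=1)

model = TinyListModel()
lga = LayerGradientXActivation(model, model.list_layer)
# Expected (assumed) behaviour: one attribution per element of the layer's list output.
attributions = lga.attribute(torch.rand(3, 4), target=0)
print([a.shape for a in attributions])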