Skip to content

Commit e7835de

Browse files
authored
Optim-wip: Add a ton of missing docs (#571)
* Add a ton of missing docs * Add docs for NumPy helpers * Add image docs * Add more docs * Improve docs * Add missing optional to doc * Update docs to reflect new changes * Some minor changes that I forgot when I pulled from the optim-wip master branch. * Changes based on feedback * Fix Flake8 * Add missing 'optional's to docs
1 parent 46e16e4 commit e7835de

File tree

12 files changed

+601
-108
lines changed

12 files changed

+601
-108
lines changed

captum/optim/__init__.py

100755100644
File mode changed.

captum/optim/_core/optimization.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def __init__(
4646
) -> None:
4747
r"""
4848
Args:
49+
4950
model (nn.Module): The reference to PyTorch model instance.
5051
input_param (nn.Module, optional): A module that generates an input,
5152
consumed by the model.
@@ -71,6 +72,7 @@ def __init__(
7172

7273
def loss(self) -> torch.Tensor:
7374
r"""Compute loss value for current iteration.
75+
7476
Returns:
7577
*tensor* representing **loss**:
7678
- **loss** (*tensor*):
@@ -115,18 +117,26 @@ def optimize(
115117
lr: float = 0.025,
116118
) -> torch.Tensor:
117119
r"""Optimize input based on loss function and objectives.
120+
118121
Args:
122+
119123
stop_criteria (StopCriteria, optional): A function that is called
120124
every iteration and returns a bool that determines whether
121125
to stop the optimization.
122126
See captum.optim.typing.StopCriteria for details.
123127
optimizer (Optimizer, optional): A torch.optim.Optimizer used to
124128
optimize the input based on the loss function.
129+
loss_summarize_fn (Callable, optional): The function to use for summarizing
130+
tensor outputs from loss functions.
131+
Default: default_loss_summarize
132+
lr (float, optional): If no optimizer is given, then lr is used as the
133+
learning rate for the Adam optimizer.
134+
Default: 0.025
135+
125136
Returns:
126-
*list* of *np.arrays* representing the **history**:
127-
- **history** (*list*):
128-
A list of loss values per iteration.
129-
Length of the list corresponds to the number of iterations
137+
history (torch.Tensor): A stack of loss values per iteration. The size
138+
of the dimension on which loss values are stacked corresponds to
139+
the number of iterations.
130140
"""
131141
stop_criteria = stop_criteria or n_steps(512)
132142
optimizer = optimizer or optim.Adam(self.parameters(), lr=lr)
@@ -150,10 +160,12 @@ def optimize(
150160

151161
def n_steps(n: int, show_progress: bool = True) -> StopCriteria:
152162
"""StopCriteria generator that uses number of steps as a stop criteria.
163+
153164
Args:
154165
n (int): Number of steps to run optimization.
155166
show_progress (bool, optional): Whether or not to show progress bar.
156167
Default: True
168+
157169
Returns:
158170
*StopCriteria* callable
159171
"""

captum/optim/_core/output_hook.py

100755100644
Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,37 @@
88
from captum.optim._utils.typing import ModuleOutputMapping, TupleOfTensorsOrTensorType
99

1010

11-
class ModuleReuseException(Exception):
12-
pass
13-
14-
1511
class ModuleOutputsHook:
1612
def __init__(self, target_modules: Iterable[nn.Module]) -> None:
13+
"""
14+
Args:
15+
16+
target_modules (Iterable of nn.Module): A list of nn.Module targets.
17+
"""
1718
self.outputs: ModuleOutputMapping = dict.fromkeys(target_modules, None)
1819
self.hooks = [
1920
module.register_forward_hook(self._forward_hook())
2021
for module in target_modules
2122
]
2223

2324
def _reset_outputs(self) -> None:
25+
"""
26+
Delete captured activations.
27+
"""
2428
self.outputs = dict.fromkeys(self.outputs.keys(), None)
2529

2630
@property
2731
def is_ready(self) -> bool:
2832
return all(value is not None for value in self.outputs.values())
2933

3034
def _forward_hook(self) -> Callable:
35+
"""
36+
Return the forward_hook function.
37+
38+
Returns:
39+
forward_hook (Callable): The forward_hook function.
40+
"""
41+
3142
def forward_hook(
3243
module: nn.Module, input: Tuple[torch.Tensor], output: torch.Tensor
3344
) -> None:
@@ -49,6 +60,12 @@ def forward_hook(
4960
return forward_hook
5061

5162
def consume_outputs(self) -> ModuleOutputMapping:
63+
"""
64+
Collect target activations and return them.
65+
66+
Returns:
67+
outputs (ModuleOutputMapping): The captured outputs.
68+
"""
5269
if not self.is_ready:
5370
warn(
5471
"Consume captured outputs, but not all requested target outputs "
@@ -63,11 +80,16 @@ def targets(self) -> Iterable[nn.Module]:
6380
return self.outputs.keys()
6481

6582
def remove_hooks(self) -> None:
83+
"""
84+
Remove hooks.
85+
"""
6686
for hook in self.hooks:
6787
hook.remove()
6888

6989
def __del__(self) -> None:
70-
# print(f"DEL HOOKS!: {list(self.outputs.keys())}")
90+
"""
91+
Ensure that using 'del' properly deletes hooks.
92+
"""
7193
self.remove_hooks()
7294

7395

@@ -77,16 +99,34 @@ class ActivationFetcher:
7799
"""
78100

79101
def __init__(self, model: nn.Module, targets: Iterable[nn.Module]) -> None:
102+
"""
103+
Args:
104+
105+
model (nn.Module): The reference to PyTorch model instance.
106+
targets (nn.Module or list of nn.Module): The target layers to
107+
collect activations from.
108+
"""
80109
super(ActivationFetcher, self).__init__()
81110
self.model = model
82111
self.layers = ModuleOutputsHook(targets)
83112

84113
def __call__(self, input_t: TupleOfTensorsOrTensorType) -> ModuleOutputMapping:
114+
"""
115+
Args:
116+
117+
input_t (tensor or tuple of tensors): The input to use
118+
with the specified model.
119+
120+
Returns:
121+
activations_dict: A dict containing the collected activations. The keys
122+
for the returned dictionary are the target layers.
123+
"""
124+
85125
try:
86126
with warnings.catch_warnings():
87127
warnings.simplefilter("ignore")
88128
self.model(input_t)
89-
activations = self.layers.consume_outputs()
129+
activations_dict = self.layers.consume_outputs()
90130
finally:
91131
self.layers.remove_hooks()
92-
return activations
132+
return activations_dict

0 commit comments

Comments
 (0)