
Commit 82ab88d

Optim-wip: Add main Activation Atlas tutorial & functions (#782)
* Add main Activation Atlas classes & functions
* Add main Activation Atlas tutorial notebook
* Remove unused import
* Changes based on feedback
* Underscore rotation class functions
* Fix & improve documentation
* Revert weight heatmap variable name changes
* Correct rotation transform bug
* Revert ufmt change as it causes isort to fail
* Improve documentation of `AngledNeuronDirection` & atlas functions
* Improve atlas-related documentation
* Move atlas.py to _utils/image & improve atlas docs
* Add sum_loss_list() function & correct target type hints
  * Add `sum_loss_list()` to `optim/_core/loss.py` with tests.
  * Replace `sum()` with `opt.loss.sum_loss_list()` in the ActivationAtlas tutorial notebook.
  * Correct the `target` type hints for `Loss`, `BaseLoss`, and `CompositeLoss`. The correct type hint is `Union[nn.Module, List[nn.Module]]`, as that is what the code supports and uses.
* RandomRotation JIT support & other improvements
  * Better way to handle the torch version check with JIT: `torch.jit.is_scripting()` isn't supported by earlier versions of PyTorch, so this uses a solution that still supports them.
  * Exposed the `align_corners` parameter in `RandomRotation` class initialization.
  * Improved `RandomRotation` class documentation.
* Use a better scale type hint in the RandomRotation init function
* Add torch.distributions support to RandomRotation
  * Ludwig wanted this according to his original PR.
  * Also fix a mypy bug.
* Add assert & more tests for RandomRotation
  * Added a `torch.distribution` assert check and more extensive testing for the RandomRotation transform.
* Add SkipTest to RandomRotation reflection test
* Fix formatting error
* Changes to the main atlas tutorial notebook based on feedback
  * `extract_grid_vectors` -> `compute_avg_cell_samples`
  * Improved documentation and added better descriptions to the main atlas tutorial notebook.
* Remove unused type hint
* Improve whitening description in the main activation atlas tutorial
* Spelling & grammar fixes
* Improve `calc_grid_indices` documentation
1 parent 6e7f0bd commit 82ab88d

File tree

11 files changed (+2671, -45 lines)


captum/optim/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 from captum.optim._param.image import images, transforms  # noqa: F401
 from captum.optim._param.image.images import ImageTensor  # noqa: F401
 from captum.optim._utils import circuits, reducer  # noqa: F401
+from captum.optim._utils.image import atlas  # noqa: F401
 from captum.optim._utils.image.common import (  # noqa: F401
     nchannels_to_rgb,
     save_tensor_as_image,
@@ -23,6 +24,7 @@
     "circuits",
     "models",
     "reducer",
+    "atlas",
     "nchannels_to_rgb",
     "save_tensor_as_image",
     "show",

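Note: with this change the atlas utilities become reachable from the top-level `captum.optim` namespace. A minimal sketch of the resulting import path (the `opt` alias mirrors the one used in the commit message; nothing below is part of the diff itself):

import captum.optim as opt
from captum.optim._utils.image import atlas

# The re-exported top-level name and the underlying module are the same object.
assert opt.atlas is atlas
assert "atlas" in opt.__all__
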
captum/optim/_core/loss.py

Lines changed: 141 additions & 5 deletions
@@ -1,7 +1,7 @@
 import functools
 import operator
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -27,7 +27,7 @@ def __init__(self) -> None:
         super(Loss, self).__init__()

     @abstractproperty
-    def target(self) -> nn.Module:
+    def target(self) -> Union[nn.Module, List[nn.Module]]:
         pass

     @abstractmethod
@@ -140,7 +140,9 @@ def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:

 class BaseLoss(Loss):
     def __init__(
-        self, target: nn.Module = [], batch_index: Optional[int] = None
+        self,
+        target: Union[nn.Module, List[nn.Module]] = [],
+        batch_index: Optional[int] = None,
     ) -> None:
         super(BaseLoss, self).__init__()
         self._target = target
@@ -150,7 +152,7 @@ def __init__(
             self._batch_index = (batch_index, batch_index + 1)

     @property
-    def target(self) -> nn.Module:
+    def target(self) -> Union[nn.Module, List[nn.Module]]:
         return self._target

     @property
@@ -160,7 +162,10 @@ def batch_index(self) -> Tuple:

 class CompositeLoss(BaseLoss):
     def __init__(
-        self, loss_fn: Callable, name: str = "", target: nn.Module = []
+        self,
+        loss_fn: Callable,
+        name: str = "",
+        target: Union[nn.Module, List[nn.Module]] = [],
     ) -> None:
         super(CompositeLoss, self).__init__(target)
         self.__name__ = name
@@ -499,6 +504,94 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return _dot_cossim(self.direction, activations, cossim_pow=self.cossim_pow)


+@loss_wrapper
+class AngledNeuronDirection(BaseLoss):
+    """
+    Visualize a direction vector with an optional whitened activation vector to
+    unstretch the activation space. Compared to the traditional Direction
+    objectives, this objective places more emphasis on angle by optionally
+    multiplying the dot product by the cosine similarity.
+
+    When cossim_pow is equal to 0, this objective works as a Euclidean neuron
+    objective. When cossim_pow is greater than 0, this objective works as a
+    cosine similarity objective. An additional whitened neuron direction vector
+    can optionally be supplied to improve visualization quality for some models.
+
+    More information on the algorithm this objective uses can be found here:
+    https://github.com/tensorflow/lucid/issues/116
+
+    The Lucid equivalents of this loss function can be found here:
+    https://github.com/tensorflow/lucid/blob/master/notebooks/
+    activation-atlas/activation-atlas-simple.ipynb
+    https://github.com/tensorflow/lucid/blob/master/notebooks/
+    activation-atlas/class-activation-atlas.ipynb
+
+    Like the Lucid equivalents, our implementation differs slightly from the
+    associated research paper.
+
+    Carter, et al., "Activation Atlas", Distill, 2019.
+    https://distill.pub/2019/activation-atlas/
+    """
+
+    def __init__(
+        self,
+        target: torch.nn.Module,
+        vec: torch.Tensor,
+        vec_whitened: Optional[torch.Tensor] = None,
+        cossim_pow: float = 4.0,
+        x: Optional[int] = None,
+        y: Optional[int] = None,
+        eps: float = 1.0e-4,
+        batch_index: Optional[int] = None,
+    ) -> None:
+        """
+        Args:
+            target (nn.Module): A target layer instance.
+            vec (torch.Tensor): A neuron direction vector to use.
+            vec_whitened (torch.Tensor, optional): A whitened neuron direction vector.
+            cossim_pow (float, optional): The desired cosine similarity power to use.
+            x (int, optional): Optionally provide a specific x position for the target
+                neuron.
+            y (int, optional): Optionally provide a specific y position for the target
+                neuron.
+            eps (float, optional): If cossim_pow is greater than zero, the desired
+                epsilon value to use for cosine similarity calculations.
+        """
+        BaseLoss.__init__(self, target, batch_index)
+        self.vec = vec.unsqueeze(0) if vec.dim() == 1 else vec
+        self.vec_whitened = vec_whitened
+        self.cossim_pow = cossim_pow
+        self.eps = eps
+        self.x = x
+        self.y = y
+        if self.vec_whitened is not None:
+            assert self.vec_whitened.dim() == 2
+        assert self.vec.dim() == 2
+
+    def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
+        activations = targets_to_values[self.target]
+        activations = activations[self.batch_index[0] : self.batch_index[1]]
+        assert activations.dim() == 4 or activations.dim() == 2
+        assert activations.shape[1] == self.vec.shape[1]
+        if activations.dim() == 4:
+            _x, _y = get_neuron_pos(
+                activations.size(2), activations.size(3), self.x, self.y
+            )
+            activations = activations[..., _x, _y]
+
+        vec = (
+            torch.matmul(self.vec, self.vec_whitened)[0]
+            if self.vec_whitened is not None
+            else self.vec
+        )
+        if self.cossim_pow == 0:
+            return activations * vec
+
+        dot = torch.mean(activations * vec)
+        cossims = dot / (self.eps + torch.sqrt(torch.sum(activations ** 2)))
+        return dot * torch.clamp(cossims, min=0.1) ** self.cossim_pow
+
+
 @loss_wrapper
 class TensorDirection(BaseLoss):
     """
@@ -590,6 +683,47 @@ def __call__(self, targets_to_values: ModuleOutputMapping) -> torch.Tensor:
         return activations


+def sum_loss_list(
+    loss_list: List,
+    to_scalar_fn: Callable[[torch.Tensor], torch.Tensor] = torch.mean,
+) -> CompositeLoss:
+    """
+    Summarize a large number of losses without recursion errors. By default, using
+    300+ loss functions for a single optimization task will result in exceeding
+    Python's default maximum recursion depth limit. This function can be used to
+    avoid that limit for tasks such as summarizing a large list of loss functions
+    with the built-in sum() function.
+
+    This function works similarly to Lucid's optvis.objectives.Objective.sum()
+    function.
+
+    Args:
+
+        loss_list (list): A list of loss function objectives.
+        to_scalar_fn (Callable): A function for converting loss function outputs to
+            scalar values, in order to prevent size mismatches.
+            Default: torch.mean
+
+    Returns:
+        loss_fn (CompositeLoss): A composite loss function containing all the loss
+            functions from `loss_list`.
+    """

+    def loss_fn(module: ModuleOutputMapping) -> torch.Tensor:
+        return sum([to_scalar_fn(loss(module)) for loss in loss_list])
+
+    name = "Sum(" + ", ".join([loss.__name__ for loss in loss_list]) + ")"
+    # Collect targets from losses
+    target = [
+        target
+        for targets in [
+            [loss.target] if not hasattr(loss.target, "__iter__") else loss.target
+            for loss in loss_list
+        ]
+        for target in targets
+    ]
+    return CompositeLoss(loss_fn, name=name, target=target)
+
+
 def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
     """
     Helper function to summarize tensor outputs from loss functions.
@@ -617,7 +751,9 @@ def default_loss_summarize(loss_value: torch.Tensor) -> torch.Tensor:
     "Alignment",
     "Direction",
     "NeuronDirection",
+    "AngledNeuronDirection",
     "TensorDirection",
     "ActivationWeights",
+    "sum_loss_list",
     "default_loss_summarize",
 ]
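
Note: taken together, the two additions above support the workflow described in the commit message: build one `AngledNeuronDirection` objective per atlas cell and combine them with `sum_loss_list()` instead of the built-in `sum()`, which would exceed Python's default recursion limit for a few hundred objectives. A minimal usage sketch; the stand-in layer, tensor sizes, and per-cell batch indices below are illustrative assumptions rather than values from this diff, and the `opt.loss` path follows the commit message:

import torch
import torch.nn as nn
import captum.optim as opt

# Stand-in target layer and 100 per-cell direction vectors whose channel width
# matches the layer's output channels (in the tutorial these vectors come from
# averaged, whitened activation samples).
target_layer = nn.Conv2d(3, 512, kernel_size=3)
cell_vecs = torch.randn(100, 512)

# One AngledNeuronDirection objective per atlas cell, each tied to its own
# batch index. cossim_pow > 0 emphasizes the angle between the activation and
# the direction vector; cossim_pow = 0 reduces it to a plain dot product.
cell_losses = [
    opt.loss.AngledNeuronDirection(
        target=target_layer, vec=cell_vecs[i], cossim_pow=4.0, batch_index=i
    )
    for i in range(cell_vecs.shape[0])
]

# Combine all objectives into a single CompositeLoss without recursion errors.
loss_fn = opt.loss.sum_loss_list(cell_losses)

The returned `CompositeLoss` carries the collected targets of every objective in the list, so it can be used wherever a single loss objective is expected.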

captum/optim/_param/image/transforms.py

Lines changed: 147 additions & 0 deletions
@@ -384,6 +384,152 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return self.translate_tensor(input, insets)


+class RandomRotation(nn.Module):
+    """
+    Apply random rotation transforms on an NCHW tensor, using a sequence of
+    degrees or a torch.distributions instance.
+    """
+
+    __constants__ = [
+        "degrees",
+        "mode",
+        "padding_mode",
+        "align_corners",
+        "_has_align_corners",
+        "_is_distribution",
+    ]
+
+    def __init__(
+        self,
+        degrees: NumSeqOrTensorType,
+        mode: str = "bilinear",
+        padding_mode: str = "zeros",
+        align_corners: bool = False,
+    ) -> None:
+        """
+        Args:
+
+            degrees (float, sequence, or torch.distribution): Sequence of degree
+                values to randomly select from, or a torch.distributions instance.
+            mode (str, optional): Interpolation mode to use. See the documentation
+                of F.grid_sample for more details. One of: "bilinear", "nearest",
+                or "bicubic".
+                Default: "bilinear"
+            padding_mode (str, optional): Padding mode for values that fall outside
+                of the grid. See the documentation of F.grid_sample for more
+                details. One of: "zeros", "border", or "reflection".
+                Default: "zeros"
+            align_corners (bool, optional): Whether or not to align corners. See the
+                documentation of F.affine_grid & F.grid_sample for more details.
+                Default: False
+        """
+        super().__init__()
+        if isinstance(degrees, torch.distributions.distribution.Distribution):
+            # Distributions are not supported by TorchScript / JIT yet
+            assert degrees.batch_shape == torch.Size([])
+            self.degrees_distribution = degrees
+            self._is_distribution = True
+            self.degrees = []
+        else:
+            assert hasattr(degrees, "__iter__")
+            if torch.is_tensor(degrees):
+                assert cast(torch.Tensor, degrees).dim() == 1
+                degrees = degrees.tolist()
+            assert len(degrees) > 0
+            self.degrees = [float(d) for d in degrees]
+            self._is_distribution = False
+
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self._has_align_corners = torch.__version__ >= "1.3.0"
+
+    def _get_rot_mat(
+        self,
+        theta: float,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        """
+        Create a rotation matrix tensor.
+
+        Args:
+
+            theta (float): The rotation value in degrees.
+
+        Returns:
+            **rot_mat** (torch.Tensor): A rotation matrix.
+        """
+        theta = theta * math.pi / 180.0
+        rot_mat = torch.tensor(
+            [
+                [math.cos(theta), -math.sin(theta), 0.0],
+                [math.sin(theta), math.cos(theta), 0.0],
+            ],
+            device=device,
+            dtype=dtype,
+        )
+        return rot_mat
+
+    def _rotate_tensor(self, x: torch.Tensor, theta: float) -> torch.Tensor:
+        """
+        Rotate an NCHW image tensor based on a specified degree value.
+
+        Args:
+
+            x (torch.Tensor): The NCHW image tensor to rotate.
+            theta (float): The amount to rotate the NCHW image, in degrees.
+
+        Returns:
+            **x** (torch.Tensor): A rotated NCHW image tensor.
+        """
+        rot_matrix = self._get_rot_mat(theta, x.device, x.dtype)[None, ...].repeat(
+            x.shape[0], 1, 1
+        )
+        if self._has_align_corners:
+            # Pass align_corners explicitly for torch >= 1.3.0
+            grid = F.affine_grid(rot_matrix, x.size(), align_corners=self.align_corners)
+            x = F.grid_sample(
+                x,
+                grid,
+                mode=self.mode,
+                padding_mode=self.padding_mode,
+                align_corners=self.align_corners,
+            )
+        else:
+            grid = F.affine_grid(rot_matrix, x.size())
+            x = F.grid_sample(x, grid, mode=self.mode, padding_mode=self.padding_mode)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Randomly rotate an NCHW image tensor.
+
+        Args:
+
+            x (torch.Tensor): NCHW image tensor to randomly rotate.
+
+        Returns:
+            **x** (torch.Tensor): A randomly rotated NCHW image tensor.
+        """
+        assert x.dim() == 4
+        if self._is_distribution:
+            rotate_angle = float(self.degrees_distribution.sample().item())
+        else:
+            n = int(
+                torch.randint(
+                    low=0,
+                    high=len(self.degrees),
+                    size=[1],
+                    dtype=torch.int64,
+                    layout=torch.strided,
+                    device=x.device,
+                ).item()
+            )
+            rotate_angle = self.degrees[n]
+        return self._rotate_tensor(x, rotate_angle)
+
+
 class ScaleInputRange(nn.Module):
     """
     Multiplies the input by a specified multiplier for models with input ranges other
@@ -673,6 +819,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     "center_crop",
     "RandomScale",
     "RandomSpatialJitter",
+    "RandomRotation",
     "ScaleInputRange",
     "RGBToBGR",
     "GaussianSmoothing",

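Note: a brief usage sketch of the new transform, covering both accepted `degrees` forms. The tensor shape and angle ranges are arbitrary examples, not values from this commit:

import torch
from captum.optim._param.image.transforms import RandomRotation

x = torch.randn(1, 3, 224, 224)  # NCHW input

# A sequence of candidate angles: one is drawn uniformly at random per forward call.
rotate_seq = RandomRotation(list(range(-10, 11)), padding_mode="reflection")
out_seq = rotate_seq(x)

# A torch.distributions instance with an empty batch_shape is also accepted
# (this path samples eagerly and is not TorchScript-compatible).
rotate_dist = RandomRotation(torch.distributions.Uniform(-10.0, 10.0))
out_dist = rotate_dist(x)

assert out_seq.shape == out_dist.shape == x.shape

Sequence mode stores the candidate angles as a plain list of floats, which is why it remains compatible with TorchScript while the distribution branch does not.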