Skip to content

Commit a1007a8

Browse files
jgonik authored and facebook-github-bot committed
Mask for Adversarial Attacks (#1043)
Summary: Adds an optional mask argument to the FGSM and PGD adversarial attacks. This mask determines which pixels are affected by the adversarial perturbations. If no mask is specified, then all pixels are affected. This PR resolves #941.
Pull Request resolved: #1043
Reviewed By: NarineK
Differential Revision: D40442870
Pulled By: vivekmig
fbshipit-source-id: f9f0688519f61f1520e949e3d7a8cb1cfc0342f3
1 parent 7c228ac commit a1007a8

File tree

4 files changed

+203
-12
lines changed

4 files changed

+203
-12
lines changed

captum/robust/_core/fgsm.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Any, Callable, Optional, Tuple
2+
from typing import Any, Callable, Optional, Tuple, Union
33

44
import torch
55
from captum._utils.common import (
@@ -82,6 +82,7 @@ def perturb(
8282
target: Any,
8383
additional_forward_args: Any = None,
8484
targeted: bool = False,
85+
mask: Optional[TensorOrTupleOfTensorsGeneric] = None,
8586
) -> TensorOrTupleOfTensorsGeneric:
8687
r"""
8788
This method computes and returns the perturbed input for each input tensor.
@@ -130,6 +131,12 @@ def perturb(
130131
Default: None.
131132
targeted (bool, optional): If attack should be targeted.
132133
Default: False.
134+
mask (Tensor or tuple[Tensor, ...], optional): mask of zeroes and ones
135+
that defines which elements within the input tensor(s) are
136+
perturbed. This mask must have the same shape and
137+
dimensionality as the inputs. If this argument is not
138+
provided, all elements will be perturbed.
139+
Default: None.
133140
134141
135142
Returns:
@@ -144,6 +151,11 @@ def perturb(
144151
"""
145152
is_inputs_tuple = _is_tuple(inputs)
146153
inputs: Tuple[Tensor, ...] = _format_tensor_into_tuples(inputs)
154+
masks: Union[Tuple[int, ...], Tuple[Tensor, ...]] = (
155+
_format_tensor_into_tuples(mask)
156+
if (mask is not None)
157+
else (1,) * len(inputs)
158+
)
147159
gradient_mask = apply_gradient_requirements(inputs)
148160

149161
def _forward_with_loss() -> Tensor:
@@ -161,7 +173,7 @@ def _forward_with_loss() -> Tensor:
161173

162174
grads = compute_gradients(_forward_with_loss, inputs)
163175
undo_gradient_requirements(inputs, gradient_mask)
164-
perturbed_inputs = self._perturb(inputs, grads, epsilon, targeted)
176+
perturbed_inputs = self._perturb(inputs, grads, epsilon, targeted, masks)
165177
perturbed_inputs = tuple(
166178
self.bound(perturbed_inputs[i]) for i in range(len(perturbed_inputs))
167179
)
@@ -173,6 +185,7 @@ def _perturb(
173185
grads: Tuple,
174186
epsilon: float,
175187
targeted: bool,
188+
masks: Tuple,
176189
) -> Tuple:
177190
r"""
178191
A helper function to calculate the perturbed inputs given original
@@ -183,9 +196,9 @@ def _perturb(
183196
inputs = tuple(
184197
torch.where(
185198
torch.abs(grad) > self.zero_thresh,
186-
inp + multiplier * epsilon * torch.sign(grad),
199+
inp + multiplier * epsilon * torch.sign(grad) * mask,
187200
inp,
188201
)
189-
for grad, inp in zip(grads, inputs)
202+
for grad, inp, mask in zip(grads, inputs, masks)
190203
)
191204
return inputs

captum/robust/_core/pgd.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Any, Callable
2+
from typing import Any, Callable, Optional, Tuple, Union
33

44
import torch
55
import torch.nn.functional as F
@@ -78,6 +78,7 @@ def perturb(
7878
targeted: bool = False,
7979
random_start: bool = False,
8080
norm: str = "Linf",
81+
mask: Optional[TensorOrTupleOfTensorsGeneric] = None,
8182
) -> TensorOrTupleOfTensorsGeneric:
8283
r"""
8384
This method computes and returns the perturbed input for each input tensor.
@@ -134,6 +135,12 @@ def perturb(
134135
norm (str, optional): Specifies the norm to calculate distance from
135136
original inputs: ``Linf`` | ``L2``.
136137
Default: ``Linf``
138+
mask (Tensor or tuple[Tensor, ...], optional): mask of zeroes and ones
139+
that defines which elements within the input tensor(s) are
140+
perturbed. This mask must have the same shape and
141+
dimensionality as the inputs. If this argument is not
142+
provided, all elements are perturbed.
143+
Default: None.
137144
138145
Returns:
139146
@@ -157,15 +164,29 @@ def _clip(inputs: Tensor, outputs: Tensor) -> Tensor:
157164

158165
is_inputs_tuple = _is_tuple(inputs)
159166
formatted_inputs = _format_tensor_into_tuples(inputs)
167+
formatted_masks: Union[Tuple[int, ...], Tuple[Tensor, ...]] = (
168+
_format_tensor_into_tuples(mask)
169+
if (mask is not None)
170+
else (1,) * len(formatted_inputs)
171+
)
160172
perturbed_inputs = formatted_inputs
161173
if random_start:
162174
perturbed_inputs = tuple(
163-
self.bound(self._random_point(formatted_inputs[i], radius, norm))
175+
self.bound(
176+
self._random_point(
177+
formatted_inputs[i], radius, norm, formatted_masks[i]
178+
)
179+
)
164180
for i in range(len(formatted_inputs))
165181
)
166182
for _i in range(step_num):
167183
perturbed_inputs = self.fgsm.perturb(
168-
perturbed_inputs, step_size, target, additional_forward_args, targeted
184+
perturbed_inputs,
185+
step_size,
186+
target,
187+
additional_forward_args,
188+
targeted,
189+
formatted_masks,
169190
)
170191
perturbed_inputs = tuple(
171192
_clip(formatted_inputs[j], perturbed_inputs[j])
@@ -178,7 +199,9 @@ def _clip(inputs: Tensor, outputs: Tensor) -> Tensor:
178199
)
179200
return _format_output(is_inputs_tuple, perturbed_inputs)
180201

181-
def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
202+
def _random_point(
203+
self, center: Tensor, radius: float, norm: str, mask: Union[Tensor, int]
204+
) -> Tensor:
182205
r"""
183206
A helper function that returns a uniform random point within the ball
184207
with the given center and radius. Norm should be either L2 or Linf.
@@ -190,9 +213,9 @@ def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor:
190213
r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius
191214
r = r[(...,) + (None,) * (r.dim() - 1)]
192215
x = r * unit_u
193-
return center + x
216+
return center + (x * mask)
194217
elif norm == "Linf":
195218
x = torch.rand_like(center) * radius * 2 - radius
196-
return center + x
219+
return center + (x * mask)
197220
else:
198221
raise AssertionError("Norm constraint must be L2 or Linf.")

tests/robust/test_FGSM.py

Lines changed: 57 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import Any, Callable, List, Tuple, Union
2+
from typing import Any, Callable, List, Optional, Tuple, Union
33

44
import torch
55
from captum._utils.typing import TensorLikeList, TensorOrTupleOfTensorsGeneric
@@ -128,6 +128,60 @@ def test_attack_bound(self) -> None:
128128
upper_bound=5.0,
129129
)
130130

131+
def test_attack_masked_tensor(self) -> None:
132+
model = BasicModel()
133+
input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]], requires_grad=True)
134+
mask = torch.tensor([[1, 0, 0, 1, 1]])
135+
self._FGSM_assert(
136+
model, input, 1, 0.1, [[2.0, -9.0, 9.0, 1.0, -3.0]], mask=mask
137+
)
138+
139+
def test_attack_masked_multiinput(self) -> None:
140+
model = BasicModel2()
141+
input1 = torch.tensor([[4.0, -1.0], [3.0, 10.0]], requires_grad=True)
142+
input2 = torch.tensor([[2.0, -5.0], [-2.0, 1.0]], requires_grad=True)
143+
mask1 = torch.tensor([[1, 0], [1, 0]])
144+
mask2 = torch.tensor([[0, 0], [0, 0]])
145+
self._FGSM_assert(
146+
model,
147+
(input1, input2),
148+
0,
149+
0.25,
150+
([[3.75, -1.0], [2.75, 10.0]], [[2.0, -5.0], [-2.0, 1.0]]),
151+
mask=(mask1, mask2),
152+
)
153+
154+
def test_attack_masked_loss_defined(self) -> None:
155+
model = BasicModel_MultiLayer()
156+
add_input = torch.tensor([[-1.0, 2.0, 2.0]])
157+
input = torch.tensor([[1.0, 6.0, -3.0]])
158+
labels = torch.tensor([0])
159+
mask = torch.tensor([[0, 0, 1]])
160+
loss_func = CrossEntropyLoss(reduction="none")
161+
adv = FGSM(model, loss_func)
162+
perturbed_input = adv.perturb(
163+
input, 0.2, labels, additional_forward_args=(add_input,), mask=mask
164+
)
165+
assertTensorAlmostEqual(
166+
self, perturbed_input, [[1.0, 6.0, -3.0]], delta=0.01, mode="max"
167+
)
168+
169+
def test_attack_masked_bound(self) -> None:
170+
model = BasicModel()
171+
input = torch.tensor([[9.0, 10.0, -6.0, -1.0]])
172+
mask = torch.tensor([[1, 0, 1, 0]])
173+
self._FGSM_assert(
174+
model,
175+
input,
176+
3,
177+
0.2,
178+
[[5.0, 5.0, -5.0, -1.0]],
179+
targeted=True,
180+
lower_bound=-5.0,
181+
upper_bound=5.0,
182+
mask=mask,
183+
)
184+
131185
def _FGSM_assert(
132186
self,
133187
model: Callable,
@@ -139,10 +193,11 @@ def _FGSM_assert(
139193
additional_inputs: Any = None,
140194
lower_bound: float = float("-inf"),
141195
upper_bound: float = float("inf"),
196+
mask: Optional[TensorOrTupleOfTensorsGeneric] = None,
142197
) -> None:
143198
adv = FGSM(model, lower_bound=lower_bound, upper_bound=upper_bound)
144199
perturbed_input = adv.perturb(
145-
inputs, epsilon, target, additional_inputs, targeted
200+
inputs, epsilon, target, additional_inputs, targeted, mask
146201
)
147202
if isinstance(perturbed_input, Tensor):
148203
assertTensorAlmostEqual(

tests/robust/test_PGD.py

Lines changed: 100 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,3 +108,103 @@ def test_attack_random_start(self) -> None:
108108
)
109109
norm = torch.norm((perturbed_input - input).squeeze()).numpy()
110110
self.assertLessEqual(norm, 0.25)
111+
112+
def test_attack_masked_nontargeted(self) -> None:
113+
model = BasicModel()
114+
input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]])
115+
mask = torch.tensor([[1, 1, 0, 0, 0]])
116+
adv = PGD(model)
117+
perturbed_input = adv.perturb(input, 0.25, 0.1, 2, 4, mask=mask)
118+
assertTensorAlmostEqual(
119+
self,
120+
perturbed_input,
121+
[[2.0, -9.0, 9.0, 1.0, -3.0]],
122+
delta=0.01,
123+
mode="max",
124+
)
125+
126+
def test_attack_masked_targeted(self) -> None:
127+
model = BasicModel()
128+
input = torch.tensor([[9.0, 10.0, -6.0, -1.0]], requires_grad=True)
129+
mask = torch.tensor([[1, 1, 1, 0]])
130+
adv = PGD(model)
131+
perturbed_input = adv.perturb(input, 0.2, 0.1, 3, 3, targeted=True, mask=mask)
132+
assertTensorAlmostEqual(
133+
self,
134+
perturbed_input,
135+
[[9.0, 10.0, -6.0, -1.0]],
136+
delta=0.01,
137+
mode="max",
138+
)
139+
140+
def test_attack_masked_multiinput(self) -> None:
141+
model = BasicModel2()
142+
input1 = torch.tensor([[4.0, -1.0], [3.0, 10.0]], requires_grad=True)
143+
input2 = torch.tensor([[2.0, -5.0], [-2.0, 1.0]], requires_grad=True)
144+
mask1 = torch.tensor([[1, 1], [0, 0]])
145+
mask2 = torch.tensor([[0, 1], [0, 1]])
146+
adv = PGD(model)
147+
perturbed_input = adv.perturb(
148+
(input1, input2), 0.25, 0.1, 3, 0, norm="L2", mask=(mask1, mask2)
149+
)
150+
answer = ([[3.75, -1.0], [3.0, 10.0]], [[2.0, -5.0], [-2.0, 1.0]])
151+
for i in range(len(perturbed_input)):
152+
assertTensorAlmostEqual(
153+
self,
154+
perturbed_input[i],
155+
answer[i],
156+
delta=0.01,
157+
mode="max",
158+
)
159+
160+
def test_attack_masked_random_start(self) -> None:
161+
model = BasicModel()
162+
input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]])
163+
mask = torch.tensor([[1, 0, 1, 0, 1]])
164+
adv = PGD(model)
165+
perturbed_input = adv.perturb(
166+
input, 0.25, 0.1, 0, 4, random_start=True, mask=mask
167+
)
168+
assertTensorAlmostEqual(
169+
self,
170+
perturbed_input,
171+
[[2.0, -9.0, 9.0, 1.0, -3.0]],
172+
delta=0.25,
173+
mode="max",
174+
)
175+
perturbed_input = adv.perturb(
176+
input, 0.25, 0.1, 0, 4, norm="L2", random_start=True, mask=mask
177+
)
178+
norm = torch.norm((perturbed_input - input).squeeze()).numpy()
179+
self.assertLessEqual(norm, 0.25)
180+
181+
def test_attack_masked_3dimensional_input(self) -> None:
182+
model = BasicModel()
183+
input = torch.tensor(
184+
[[[4.0, 2.0], [-1.0, -2.0]], [[3.0, -4.0], [10.0, 5.0]]], requires_grad=True
185+
)
186+
mask = torch.tensor([[[1, 0], [0, 1]], [[1, 0], [1, 1]]])
187+
adv = PGD(model)
188+
perturbed_input = adv.perturb(input, 0.25, 0.1, 3, (0, 1), mask=mask)
189+
assertTensorAlmostEqual(
190+
self,
191+
perturbed_input,
192+
[[[4.0, 2.0], [-1.0, -2.0]], [[3.0, -4.0], [10.0, 5.0]]],
193+
delta=0.01,
194+
mode="max",
195+
)
196+
197+
def test_attack_masked_loss_defined(self) -> None:
198+
model = BasicModel_MultiLayer()
199+
add_input = torch.tensor([[-1.0, 2.0, 2.0]])
200+
input = torch.tensor([[1.0, 6.0, -3.0]])
201+
mask = torch.tensor([[0, 1, 0]])
202+
labels = torch.tensor([0])
203+
loss_func = CrossEntropyLoss(reduction="none")
204+
adv = PGD(model, loss_func)
205+
perturbed_input = adv.perturb(
206+
input, 0.25, 0.1, 3, labels, additional_forward_args=(add_input,), mask=mask
207+
)
208+
assertTensorAlmostEqual(
209+
self, perturbed_input, [[1.0, 6.0, -3.0]], delta=0.01, mode="max"
210+
)

0 commit comments

Comments
 (0)