25 changes: 13 additions & 12 deletions README.md
@@ -168,20 +168,21 @@ The distance measure in parentheses.
| **SparseFool**<br />(L0) | SparseFool: a few pixels make a big difference ([Modas et al., 2019](https://arxiv.org/abs/1811.02248)) | |
| **DIFGSM**<br />(Linf) | Improving Transferability of Adversarial Examples with Input Diversity ([Xie et al., 2019](https://arxiv.org/abs/1803.06978)) | :heart_eyes: Contributor [taobai](https://github.com/tao-bai) |
| **TIFGSM**<br />(Linf) | Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks ([Dong et al., 2019](https://arxiv.org/abs/1904.02884)) | :heart_eyes: Contributor [taobai](https://github.com/tao-bai) |
| **NIFGSM**<br />(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) |
| **SINIFGSM**<br />(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) |
| **VMIFGSM**<br />(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) |
| **VNIFGSM**<br />(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) |
| **Jitter**<br />(Linf) | Exploring Misclassifications of Robust Neural Networks to Enhance Adversarial Attacks ([Schwinn, Leo, et al., 2021](https://arxiv.org/abs/2105.10304)) | |
| **Pixle**<br />(L0) | Pixle: a fast and effective black-box attack based on rearranging pixels ([Pomponi, Jary, et al., 2022](https://arxiv.org/abs/2202.02236)) | |
| **LGV**<br />(Linf, L2, L1, L0) | LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity ([Gubri, et al., 2022](https://arxiv.org/abs/2207.13129)) | :heart_eyes: Contributor [Martin Gubri](https://github.com/Framartin) |
| **SPSA**<br />(Linf) | Adversarial Risk and the Dangers of Evaluating Against Weak Attacks ([Uesato, Jonathan, et al., 2018](https://arxiv.org/abs/1802.05666)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **JSMA**<br />(L0) | The Limitations of Deep Learning in Adversarial Settings ([Papernot, Nicolas, et al., 2016](https://arxiv.org/abs/1511.07528v1)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **EADL1**<br />(L1) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **EADEN**<br />(L1, L2) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **PIFGSM (PIM)**<br />(Linf) | Patch-wise Attack for Fooling Deep Neural Network ([Gao, Lianli, et al., 2020](https://arxiv.org/abs/2007.06765)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **PIFGSM++ (PIM++)**<br />(Linf) | Patch-wise++ Perturbation for Adversarial Targeted Attacks ([Gao, Lianli, et al., 2021](https://arxiv.org/abs/2012.15503)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) |
| **ZeroGrad**<br />(Linf) | ZeroGrad: Mitigating and Explaining Catastrophic Overfitting in FGSM Adversarial Training ([Golgooni, et al., 2021](https://arxiv.org/abs/2103.15476)) | :heart_eyes: Contributor [Ramtin Moslemi](https://github.com/ramtinmoslemi) |
| **NFGSM**<br />(Linf) | Make Some Noise: Reliable and Efficient Single-Step Adversarial Training ([Jorge, et al., 2022](https://arxiv.org/abs/2202.01181)) | :heart_eyes: Contributor [Ramtin Moslemi](https://github.com/ramtinmoslemi) |


## :bar_chart: Performance Comparison
4 changes: 4 additions & 0 deletions torchattacks/__init__.py
@@ -9,6 +9,8 @@
from .attacks.pgd import PGD
from .attacks.eotpgd import EOTPGD
from .attacks.ffgsm import FFGSM
from .attacks.zerograd import ZeroGrad
from .attacks.nfgsm import NFGSM
from .attacks.tpgd import TPGD
from .attacks.mifgsm import MIFGSM
from .attacks.upgd import UPGD
@@ -61,6 +63,8 @@
"PGD",
"EOTPGD",
"FFGSM",
"ZeroGrad",
"NFGSM",
"TPGD",
"MIFGSM",
"UPGD",
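
With the new classes exported at the package root, they can be constructed like any other torchattacks attack. A minimal usage sketch, assuming a trained classifier `model` and `images`/`labels` tensors in the `[0, 1]` range (these names are not part of the PR):

```python
import torchattacks

# ZeroGrad: FGSM-style step that first zeroes the smallest-magnitude gradient entries
atk_zerograd = torchattacks.ZeroGrad(model, eps=8/255, alpha=16/255, qval=0.35, steps=1)
adv_images = atk_zerograd(images, labels)

# NFGSM: a single FGSM step from a uniformly noisy start, without clipping back to the eps-ball
atk_nfgsm = torchattacks.NFGSM(model, eps=8/255, k=16/255)
adv_images = atk_nfgsm(images, labels)
```
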
70 changes: 70 additions & 0 deletions torchattacks/attacks/nfgsm.py
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn

from ..attack import Attack


class NFGSM(Attack):
r"""
Unclipped FGSM with noise proposed in 'Make Some Noise: Reliable and Efficient Single-Step Adversarial Training'
[https://arxiv.org/abs/2202.01181]

Distance Measure : Linf

Arguments:
model (nn.Module): model to attack.
eps (float): maximum perturbation. (Default: 8/255)
k (float): magnitude of the uniform noise. (Default: 16/255)

Shape:
- images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
- labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
- output: :math:`(N, C, H, W)`.

Examples::
>>> attack = torchattacks.NFGSM(model, eps=8/255, k=16/255)
>>> adv_images = attack(images, labels)
"""

def __init__(self, model, eps=8 / 255, k=16 / 255):
super().__init__("NFGSM", model)
self.eps = eps
self.k = k
self.supported_mode = ["default", "targeted"]

def forward(self, images, labels):
r"""
Overridden.
"""

images = images.clone().detach().to(self.device)
labels = labels.clone().detach().to(self.device)

if self.targeted:
target_labels = self.get_target_label(images, labels)

loss = nn.CrossEntropyLoss()

        adv_images = images + torch.empty_like(images).uniform_(
-self.k, self.k
) # nopep8
adv_images = torch.clamp(adv_images, min=0, max=1).detach()
adv_images.requires_grad = True

outputs = self.get_logits(adv_images)

# Calculate loss
if self.targeted:
cost = -loss(outputs, target_labels)
else:
cost = loss(outputs, labels)

# Update adversarial images
grad = torch.autograd.grad(
cost, adv_images, retain_graph=False, create_graph=False
)[0]

adv_images = adv_images + self.eps * grad.sign()
adv_images = torch.clamp(adv_images, min=0, max=1).detach()

return adv_images
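
N-FGSM is intended as a cheap single-step attack for adversarial training, so a natural follow-up is dropping it into a training loop. A minimal sketch under that assumption (`model`, `optimizer`, and `loader` are placeholders for an existing setup, not part of this PR):

```python
import torch.nn.functional as F
from torchattacks import NFGSM

# k = 2 * eps matches the class defaults above.
atk = NFGSM(model, eps=8/255, k=16/255)
device = next(model.parameters()).device

for images, labels in loader:
    images, labels = images.to(device), labels.to(device)
    adv_images = atk(images, labels)  # one noisy FGSM step, no clipping to the eps-ball
    optimizer.zero_grad()
    loss = F.cross_entropy(model(adv_images), labels)
    loss.backward()
    optimizer.step()
```
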
86 changes: 86 additions & 0 deletions torchattacks/attacks/zerograd.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn

from ..attack import Attack


class ZeroGrad(Attack):
r"""
    ZeroGrad in the paper 'ZeroGrad: Mitigating and Explaining Catastrophic Overfitting in FGSM Adversarial Training'
[https://arxiv.org/abs/2103.15476]

Distance Measure : Linf

Arguments:
model (nn.Module): model to attack.
eps (float): maximum perturbation. (Default: 8/255)
alpha (float): step size. (Default: 16/255)
        qval (float): quantile below which gradient entries (by absolute value) are zeroed. (Default: 0.35)
steps (int): number of zerograd iterations. (Default: 1)
random_start (bool): using random initialization of delta. (Default: True)

Shape:
- images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1].
- labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`.
- output: :math:`(N, C, H, W)`.

Examples::
>>> attack = torchattacks.ZeroGrad(model, eps=8/255, alpha=16/255, qval=0.35, steps=1, random_start=True)
>>> adv_images = attack(images, labels)

"""

def __init__(self, model, eps=8 / 255, alpha=16 / 255, qval=0.35, steps=1, random_start=True):
super().__init__("ZeroGrad", model)
self.eps = eps
self.alpha = alpha
self.steps = steps
self.random_start = random_start
self.qval = qval
self.supported_mode = ["default", "targeted"]

def forward(self, images, labels):
r"""
Overridden.
"""

images = images.clone().detach().to(self.device)
labels = labels.clone().detach().to(self.device)

if self.targeted:
target_labels = self.get_target_label(images, labels)

loss = nn.CrossEntropyLoss()
adv_images = images.clone().detach()

if self.random_start:
# Starting at a uniformly random point
adv_images = adv_images + torch.empty_like(adv_images).uniform_(
-self.eps, self.eps
)
adv_images = torch.clamp(adv_images, min=0, max=1).detach()

for _ in range(self.steps):
adv_images.requires_grad = True
outputs = self.get_logits(adv_images)

# Calculate loss
if self.targeted:
cost = -loss(outputs, target_labels)
else:
cost = loss(outputs, labels)

# Zero out the small values in the gradient
grad = torch.autograd.grad(
cost, adv_images, retain_graph=False, create_graph=False
)[0]
            q_grad = torch.quantile(torch.abs(grad).view(grad.size(0), -1), self.qval, dim=1)
grad[torch.abs(grad) < q_grad.view(grad.size(0), 1, 1, 1)] = 0

adv_images = adv_images.detach() + self.alpha * grad.sign()
delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps)
adv_images = torch.clamp(images + delta, min=0, max=1).detach()

return adv_images
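
The distinguishing step of ZeroGrad is the per-sample quantile mask applied to the gradient before the sign update: entries whose absolute value falls below the `qval`-quantile are zeroed, which the paper ties to mitigating catastrophic overfitting in FGSM training. A standalone toy illustration of just that masking step, using random tensors rather than anything from this PR:

```python
import torch

grad = torch.randn(4, 3, 32, 32)  # stand-in for per-pixel gradients of a batch of 4 images
qval = 0.35

# Per-sample qval-quantile of the absolute gradient, shape (4,)
q = torch.quantile(grad.abs().view(grad.size(0), -1), qval, dim=1)

# Zero out roughly the smallest 35% of entries in each sample, as in the loop above
masked = grad.clone()
masked[grad.abs() < q.view(-1, 1, 1, 1)] = 0

print((masked == 0).float().view(grad.size(0), -1).mean(dim=1))  # ~0.35 per sample
```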