From d20505a3a12619066d8a52ca38a7715a29ba488d Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:11:03 +0330 Subject: [PATCH 1/7] Add the N-FGSM attack --- torchattacks/attacks/nfgsm.py | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 torchattacks/attacks/nfgsm.py diff --git a/torchattacks/attacks/nfgsm.py b/torchattacks/attacks/nfgsm.py new file mode 100644 index 00000000..25204a6e --- /dev/null +++ b/torchattacks/attacks/nfgsm.py @@ -0,0 +1,70 @@ +import torch +import torch.nn as nn + +from ..attack import Attack + + +class NFGSM(Attack): + r""" + Unclipped FGSM with uniform noise (N-FGSM) proposed in 'Make Some Noise: Reliable and Efficient Single-Step Adversarial Training' + [https://arxiv.org/abs/2202.01181] + + Distance Measure : Linf + + Arguments: + model (nn.Module): model to attack. + eps (float): step size; unlike FGSM, the resulting perturbation is not clipped back to an eps-ball. (Default: 8/255) + k (float): magnitude of the uniform noise. (Default: 16/255) + + Shape: + - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1]. + - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`. + - output: :math:`(N, C, H, W)`. + + Examples:: + >>> attack = torchattacks.NFGSM(model, eps=8/255, k=16/255) + >>> adv_images = attack(images, labels) + """ + + def __init__(self, model, eps=8 / 255, k=16 / 255): + super().__init__("NFGSM", model) + self.eps = eps + self.k = k + self.supported_mode = ["default", "targeted"] + + def forward(self, images, labels): + r""" + Overridden. + """ + + images = images.clone().detach().to(self.device) + labels = labels.clone().detach().to(self.device) + + if self.targeted: + target_labels = self.get_target_label(images, labels) + + loss = nn.CrossEntropyLoss() + + adv_images = images + torch.empty_like(images).uniform_( + -self.k, self.k + ) # noise init: x + U(-k, k) + adv_images = torch.clamp(adv_images, min=0, max=1).detach() + adv_images.requires_grad = True + + outputs = self.get_logits(adv_images) + + # Calculate loss + if self.targeted: + cost = -loss(outputs, target_labels) + else: + cost = loss(outputs, labels) + + # Update adversarial images + grad = torch.autograd.grad( + cost, adv_images, retain_graph=False, create_graph=False + )[0] + + adv_images = adv_images + self.eps * grad.sign() + adv_images = torch.clamp(adv_images, min=0, max=1).detach() + + return adv_images From 2e46eb82c4c81084cdcb25ee1c2961b22e3bbee9 Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:14:50 +0330 Subject: [PATCH 2/7] Add the N-FGSM attack --- torchattacks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchattacks/__init__.py b/torchattacks/__init__.py index 48f18a74..aa222bd5 100644 --- a/torchattacks/__init__.py +++ b/torchattacks/__init__.py @@ -9,6 +9,7 @@ from .attacks.pgd import PGD from .attacks.eotpgd import EOTPGD from .attacks.ffgsm import FFGSM +from .attacks.nfgsm import NFGSM from .attacks.tpgd import TPGD from .attacks.mifgsm import MIFGSM from .attacks.upgd import UPGD From 674c40e483d9b39f9e114419eb6f5682604cb6f7 Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:15:56 +0330 Subject: [PATCH 3/7] Add the N-FGSM attack --- torchattacks/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchattacks/__init__.py b/torchattacks/__init__.py index 
aa222bd5..f04fe61e 100644 --- a/torchattacks/__init__.py +++ b/torchattacks/__init__.py @@ -62,6 +62,7 @@ "PGD", "EOTPGD", "FFGSM", + "NFGSM", "TPGD", "MIFGSM", "UPGD", From 9ec7f8c7c73a69a392d9e4155d505a8a55078000 Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:30:40 +0330 Subject: [PATCH 4/7] Add the N-FGSM attack --- README.md | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 79528a86..2f1fffc8 100644 --- a/README.md +++ b/README.md @@ -168,19 +168,20 @@ The distance measure in parentheses. | **SparseFool**<br>
(L0) | SparseFool: a few pixels make a big difference ([Modas et al., 2019](https://arxiv.org/abs/1811.02248)) | | | **DIFGSM**
(Linf) | Improving Transferability of Adversarial Examples with Input Diversity ([Xie et al., 2019](https://arxiv.org/abs/1803.06978)) | :heart_eyes: Contributor [taobai](https://github.com/tao-bai) | | **TIFGSM**
(Linf) | Evading Defenses to Transferable Adversarial Examples by Translation-Invariant Attacks ([Dong et al., 2019](https://arxiv.org/abs/1904.02884)) | :heart_eyes: Contributor [taobai](https://github.com/tao-bai) | -| **NIFGSM**
(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | -| **SINIFGSM**
(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | -| **VMIFGSM**
(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | -| **VNIFGSM**
(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | +| **NIFGSM**
(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | +| **SINIFGSM**
(Linf) | Nesterov Accelerated Gradient and Scale Invariance for Adversarial Attacks ([Lin, et al., 2022](https://arxiv.org/abs/1908.06281)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | +| **VMIFGSM**
(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | +| **VNIFGSM**
(Linf) | Enhancing the Transferability of Adversarial Attacks through Variance Tuning ([Wang, et al., 2022](https://arxiv.org/abs/2103.15571)) | :heart_eyes: Contributor [Zhijin-Ge](https://github.com/Zhijin-Ge) | | **Jitter**
(Linf) | Exploring Misclassifications of Robust Neural Networks to Enhance Adversarial Attacks ([Schwinn, Leo, et al., 2021](https://arxiv.org/abs/2105.10304)) | | | **Pixle**
(L0) | Pixle: a fast and effective black-box attack based on rearranging pixels ([Pomponi, Jary, et al., 2022](https://arxiv.org/abs/2202.02236)) | | -| **LGV**
(Linf, L2, L1, L0) | LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity ([Gubri, et al., 2022](https://arxiv.org/abs/2207.13129)) | :heart_eyes: Contributor [Martin Gubri](https://github.com/Framartin) | -| **SPSA**
(Linf) | Adversarial Risk and the Dangers of Evaluating Against Weak Attacks ([Uesato, Jonathan, et al., 2018](https://arxiv.org/abs/1802.05666)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | -| **JSMA**
(L0) | The Limitations of Deep Learning in Adversarial Settings ([Papernot, Nicolas, et al., 2016](https://arxiv.org/abs/1511.07528v1)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | -| **EADL1**
(L1) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | -| **EADEN**
(L1, L2) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | -| **PIFGSM (PIM)**
(Linf) | Patch-wise Attack for Fooling Deep Neural Network ([Gao, Lianli, et al., 2020](https://arxiv.org/abs/2007.06765)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | -| **PIFGSM++ (PIM++)**
(Linf) | Patch-wise++ Perturbation for Adversarial Targeted Attacks ([Gao, Lianli, et al., 2021](https://arxiv.org/abs/2012.15503)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **LGV**
(Linf, L2, L1, L0) | LGV: Boosting Adversarial Example Transferability from Large Geometric Vicinity ([Gubri, et al., 2022](https://arxiv.org/abs/2207.13129)) | :heart_eyes: Contributor [Martin Gubri](https://github.com/Framartin) | +| **SPSA**
(Linf) | Adversarial Risk and the Dangers of Evaluating Against Weak Attacks ([Uesato, Jonathan, et al., 2018](https://arxiv.org/abs/1802.05666)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **JSMA**
(L0) | The Limitations of Deep Learning in Adversarial Settings ([Papernot, Nicolas, et al., 2016](https://arxiv.org/abs/1511.07528v1)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **EADL1**
(L1) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **EADEN**
(L1, L2) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **PIFGSM (PIM)**
(Linf) | Patch-wise Attack for Fooling Deep Neural Network ([Gao, Lianli, et al., 2020](https://arxiv.org/abs/2007.06765)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **PIFGSM++ (PIM++)**
(Linf) | Patch-wise++ Perturbation for Adversarial Targeted Attacks ([Gao, Lianli, et al., 2021](https://arxiv.org/abs/2012.15503)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **NFGSM**<br>
(Linf) | Make Some Noise: Reliable and Efficient Single-Step Adversarial Training ([de Jorge, et al., 2022](https://arxiv.org/abs/2202.01181)) | :heart_eyes: Contributor [Ramtin Moslemi](https://github.com/ramtinmoslemi) | From 6af7b3a20287c854ce70605ce312949951a6028e Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:48:59 +0330 Subject: [PATCH 5/7] Add the ZeroGrad attack --- torchattacks/attacks/zerograd.py | 86 ++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 torchattacks/attacks/zerograd.py diff --git a/torchattacks/attacks/zerograd.py b/torchattacks/attacks/zerograd.py new file mode 100644 index 00000000..291a112b --- /dev/null +++ b/torchattacks/attacks/zerograd.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn + +from ..attack import Attack + + +class ZeroGrad(Attack): + r""" + ZeroGrad in the paper 'ZeroGrad: Mitigating and Explaining Catastrophic Overfitting in FGSM Adversarial Training' + [https://arxiv.org/abs/2103.15476] + + Distance Measure : Linf + + Arguments: + model (nn.Module): model to attack. + eps (float): maximum perturbation. (Default: 8/255) + alpha (float): step size. (Default: 16/255) + qval (float): quantile of gradient magnitudes below which gradient entries are zeroed. (Default: 0.35) + steps (int): number of ZeroGrad iterations. (Default: 1) + random_start (bool): using random initialization of delta. (Default: True) + + Shape: + - images: :math:`(N, C, H, W)` where `N = number of batches`, `C = number of channels`, `H = height` and `W = width`. It must have a range [0, 1]. + - labels: :math:`(N)` where each value :math:`y_i` is :math:`0 \leq y_i \leq` `number of labels`. + - output: :math:`(N, C, H, W)`. + + Examples:: + >>> attack = torchattacks.ZeroGrad(model, eps=8/255, alpha=16/255, qval=0.35, steps=1, random_start=True) + >>> adv_images = attack(images, labels) + + """ + + def __init__(self, model, eps=8 / 255, alpha=16 / 255, qval=0.35, steps=1, random_start=True): + super().__init__("ZeroGrad", model) + self.eps = eps + self.alpha = alpha + self.steps = steps + self.random_start = random_start + self.qval = qval + self.supported_mode = ["default", "targeted"] + + def forward(self, images, labels): + r""" + Overridden. 
+ """ + + images = images.clone().detach().to(self.device) + labels = labels.clone().detach().to(self.device) + + if self.targeted: + target_labels = self.get_target_label(images, labels) + + loss = nn.CrossEntropyLoss() + adv_images = images.clone().detach() + + if self.random_start: + # Starting at a uniformly random point + adv_images = adv_images + torch.empty_like(adv_images).uniform_( + -self.eps, self.eps + ) + adv_images = torch.clamp(adv_images, min=0, max=1).detach() + + for _ in range(self.steps): + adv_images.requires_grad = True + outputs = self.get_logits(adv_images) + + # Calculate loss + if self.targeted: + cost = -loss(outputs, target_labels) + else: + cost = loss(outputs, labels) + + # Zero out the small values in the gradient + grad = torch.autograd.grad( + cost, adv_images, retain_graph=False, create_graph=False + )[0] + q_grad = torch.quantile(torch.abs(grad).view(grad.size(0), -1), self.q_val, dim=1) + grad[torch.abs(grad) < q_grad.view(grad.size(0), 1, 1, 1)] = 0 + + adv_images = adv_images.detach() + self.alpha * grad.sign() + delta = torch.clamp(adv_images - images, min=-self.eps, max=self.eps) + adv_images = torch.clamp(images + delta, min=0, max=1).detach() + + grad = delta.grad.detach() + + return adv_images From fac1672a80a1ef819b5230f0436a0ff4d9e75b25 Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:50:43 +0330 Subject: [PATCH 6/7] Add the ZeroGrad attack --- torchattacks/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchattacks/__init__.py b/torchattacks/__init__.py index f04fe61e..f2846488 100644 --- a/torchattacks/__init__.py +++ b/torchattacks/__init__.py @@ -9,6 +9,7 @@ from .attacks.pgd import PGD from .attacks.eotpgd import EOTPGD from .attacks.ffgsm import FFGSM +from .attacks.zerograd import ZeroGrad from .attacks.nfgsm import NFGSM from .attacks.tpgd import TPGD from .attacks.mifgsm import MIFGSM @@ -62,6 +63,7 @@ "PGD", "EOTPGD", "FFGSM", + "ZeroGrad", "NFGSM", "TPGD", "MIFGSM", From 001b8db43f110dff3eabb2678f2c7e1280316a72 Mon Sep 17 00:00:00 2001 From: Ramtin Moslemi <76493699+RamtinMoslemi@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:55:22 +0330 Subject: [PATCH 7/7] Add the ZeroGrad attack --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2f1fffc8..fc38f102 100644 --- a/README.md +++ b/README.md @@ -181,10 +181,10 @@ The distance measure in parentheses. | **EADEN**
(L1, L2) | EAD: Elastic-Net Attacks to Deep Neural Networks ([Chen, Pin-Yu, et al., 2018](https://arxiv.org/abs/1709.04114)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | | **PIFGSM (PIM)**
(Linf) | Patch-wise Attack for Fooling Deep Neural Network ([Gao, Lianli, et al., 2020](https://arxiv.org/abs/2007.06765)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | | **PIFGSM++ (PIM++)**
(Linf) | Patch-wise++ Perturbation for Adversarial Targeted Attacks ([Gao, Lianli, et al., 2021](https://arxiv.org/abs/2012.15503)) | :heart_eyes: Contributor [Riko Naka](https://github.com/rikonaka) | +| **ZeroGrad**<br>
(Linf) | ZeroGrad: Mitigating and Explaining Catastrophic Overfitting in FGSM Adversarial Training ([Golgooni, et al., 2021](https://arxiv.org/abs/2103.15476)) | :heart_eyes: Contributor [Ramtin Moslemi](https://github.com/ramtinmoslemi) | | **NFGSM**
(Linf) | Make Some Noise: Reliable and Efficient Single-Step Adversarial Training ([de Jorge, et al., 2022](https://arxiv.org/abs/2202.01181)) | :heart_eyes: Contributor [Ramtin Moslemi](https://github.com/ramtinmoslemi) | - ## :bar_chart: Performance Comparison As for the comparison packages, currently updated and the most cited methods were selected:
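-- 

A minimal usage sketch for the two attacks added in this series, assuming a build of this branch is installed as `torchattacks`; the tiny model and random batch below are illustrative placeholders (inputs follow the library's [0, 1] range convention), not part of the patches:

import torch
import torch.nn as nn
import torchattacks

# Toy stand-ins for a real classifier and a data batch (placeholders only).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
images = torch.rand(4, 3, 32, 32)    # values in [0, 1], shape (N, C, H, W)
labels = torch.randint(0, 10, (4,))

# N-FGSM: start from x + U(-k, k) noise, take one unclipped step of size eps.
atk_nfgsm = torchattacks.NFGSM(model, eps=8 / 255, k=16 / 255)
adv_nfgsm = atk_nfgsm(images, labels)

# ZeroGrad: FGSM-style step(s) with the smallest gradient entries zeroed per sample.
atk_zg = torchattacks.ZeroGrad(model, eps=8 / 255, alpha=16 / 255, qval=0.35, steps=1)
adv_zg = atk_zg(images, labels)

assert adv_nfgsm.shape == images.shape and adv_zg.shape == images.shape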
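The gradient-masking step at the core of ZeroGrad can also be inspected in isolation; this standalone sketch mirrors the `torch.quantile` logic from the patch on made-up values:

import torch

grad = torch.randn(2, 3, 4, 4)  # stand-in for a per-sample input gradient
qval = 0.35

# One threshold per sample: the qval-quantile of |grad| over each flattened image.
q_grad = torch.quantile(torch.abs(grad).view(grad.size(0), -1), qval, dim=1)

# Zero every entry whose magnitude falls below its own sample's threshold,
# so roughly a qval fraction of each sample's gradient becomes exactly zero.
grad[torch.abs(grad) < q_grad.view(grad.size(0), 1, 1, 1)] = 0

print((grad.view(grad.size(0), -1) == 0).float().mean(dim=1))  # about qval per sample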