Skip to content

Commit 045d116

Browse files
🚀 feat(model): Updated SuperSimpleNet to latest version (#3036)
* Fix squeeze on 1dim score Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Add option to train without masks Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Add JIMS separate feat extension Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Update docs for JIMS extension Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Remove unused get_params method Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Add unit tests for SSN Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> * Rename vars and update metrics in readme Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> --------- Signed-off-by: blaz.rolih <blaz.rolih@fri.uni-lj.si> Co-authored-by: Rajesh Gangireddy <rajesh.gangireddy@intel.com>
1 parent 8dca8cc commit 045d116

File tree

7 files changed

+164
-69
lines changed

7 files changed

+164
-69
lines changed
1.11 MB
Loading

docs/source/markdown/guides/reference/models/image/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection
113113
:link: ./supersimplenet
114114
:link-type: doc
115115

116-
SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection
116+
SuperSimpleNet: A Unified Surface Defect Detection Model for all Supervision Regimes
117117
:::
118118

119119
:::{grid-item-card} {material-regular}`model_training;1.5em` U-Flow
Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
# SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection
1+
# SuperSimpleNet
22

3-
This is an implementation of the [SuperSimpleNet](https://arxiv.org/pdf/2408.03143) paper, based on the [official code](https://github.com/blaz-r/SuperSimpleNet).
3+
This is an implementation of the SuperSimpleNet, based on the [official code](https://github.com/blaz-r/SuperSimpleNet).
4+
5+
The model was first presented at ICPR 2024: [SuperSimpleNet : Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection](https://arxiv.org/abs/2408.03143)
6+
7+
An extension was later published in JIMS 2025: [No Label Left Behind: A Unified Surface Defect Detection Model for all Supervision Regimes](https://link.springer.com/article/10.1007/s10845-025-02680-8)
48

59
Model Type: Segmentation
610

@@ -11,7 +15,7 @@ feature extractor with upscaling, feature adaptor, feature-level synthetic anoma
1115
segmentation-detection module.
1216

1317
A ResNet-like feature extractor first extracts features, which are then upscaled and
14-
average-pooled to capture neighboring context. Features are further refined for anomaly detection task in the adaptor module.
18+
average-pooled to capture neighboring context. Features are (optionally) further refined for anomaly detection task in the adaptor module.
1519
During training, synthetic anomalies are generated at the feature level by adding Gaussian noise to regions defined by the
1620
binary Perlin noise mask. The perturbed features are then fed into the segmentation-detection
1721
module, which produces the anomaly map and the anomaly score. During inference, anomaly generation is skipped, and the model
@@ -24,6 +28,9 @@ This implementation supports both unsupervised and supervised setting, but Anoma
2428

2529
![SuperSimpleNet architecture](/docs/source/images/supersimplenet/architecture.png "SuperSimpleNet architecture")
2630

31+
Currently, the difference between ICPR and JIMS code is only the `adapt_cls_features` which controls whether the features used for classification head are adapted or not.
32+
For ICPR this is set to True (i.e. the features for classification head are adapted), and for JIMS version this is False (which is also the default).
33+
2734
## Usage
2835

2936
`anomalib train --model SuperSimpleNet --data MVTecAD --data.category <category>`
@@ -36,29 +43,29 @@ This implementation supports both unsupervised and supervised setting, but Anoma
3643
>
3744
> It is recommended to train the model for 300 epochs with batch size of 32 to achieve stable training with random anomaly generation. Training with lower parameter values will still work, but might not yield the optimal results.
3845
>
39-
> For supervised learning, refer to the [official code](https://github.com/blaz-r/SuperSimpleNet).
46+
> For weakly, mixed and fully supervised training, refer to the [official code](https://github.com/blaz-r/SuperSimpleNet).
4047
4148
## MVTecAD AD results
4249

4350
The following results were obtained using this Anomalib implementation trained for 300 epochs with seed 0, default params, and batch size 32.
4451

45-
| | **Image AUROC** | **Pixel AUPRO** |
46-
| ---------- | :-------------: | :-------------: |
47-
| Bottle | 1.000 | 0.903 |
48-
| Cable | 0.981 | 0.901 |
49-
| Capsule | 0.989 | 0.931 |
50-
| Carpet | 0.985 | 0.929 |
51-
| Grid | 0.994 | 0.930 |
52-
| Hazelnut | 0.994 | 0.943 |
53-
| Leather | 1.000 | 0.970 |
54-
| Metal_nut | 0.995 | 0.920 |
55-
| Pill | 0.962 | 0.936 |
56-
| Screw | 0.912 | 0.947 |
57-
| Tile | 0.994 | 0.854 |
58-
| Toothbrush | 0.908 | 0.860 |
59-
| Transistor | 1.000 | 0.907 |
60-
| Wood | 0.987 | 0.858 |
61-
| Zipper | 0.995 | 0.928 |
62-
| Average | 0.980 | 0.914 |
63-
64-
For other results on VisA, SensumSODF, and KSDD2, refer to the [paper](https://arxiv.org/pdf/2408.03143).
52+
| Category | AUROC (ICPR) | AUROC (JIMS) | AUPRO (ICPR) | AUPRO (JIMS) |
53+
| ----------- | :----------: | :----------: | :----------: | :----------: |
54+
| Bottle | 1.000 | 1.000 | 0.903 | 0.911 |
55+
| Cable | 0.981 | 0.951 | 0.901 | 0.893 |
56+
| Capsule | 0.989 | 0.992 | 0.931 | 0.919 |
57+
| Carpet | 0.985 | 0.974 | 0.929 | 0.935 |
58+
| Grid | 0.994 | 0.998 | 0.930 | 0.938 |
59+
| Hazelnut | 0.994 | 0.999 | 0.943 | 0.939 |
60+
| Leather | 1.000 | 1.000 | 0.970 | 0.974 |
61+
| Metal_nut | 0.995 | 0.993 | 0.920 | 0.925 |
62+
| Pill | 0.962 | 0.980 | 0.936 | 0.943 |
63+
| Screw | 0.912 | 0.854 | 0.947 | 0.946 |
64+
| Tile | 0.994 | 0.992 | 0.854 | 0.825 |
65+
| Toothbrush | 0.908 | 0.908 | 0.860 | 0.854 |
66+
| Transistor | 1.000 | 1.000 | 0.907 | 0.916 |
67+
| Wood | 0.987 | 0.991 | 0.858 | 0.872 |
68+
| Zipper | 0.995 | 0.999 | 0.928 | 0.944 |
69+
| **Average** | **0.980** | **0.975** | **0.914** | **0.916** |
70+
71+
For other results on VisA, SensumSODF, and KSDD2, refer to the [paper](https://link.springer.com/article/10.1007/s10845-025-02680-8).

src/anomalib/models/image/supersimplenet/anomaly_generator.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,34 +88,39 @@ def generate_perlin(self, batches: int, height: int, width: int) -> torch.Tensor
8888

8989
def forward(
9090
self,
91-
features: torch.Tensor,
92-
mask: torch.Tensor,
91+
input_features: torch.Tensor | None,
92+
adapted_features: torch.Tensor,
93+
masks: torch.Tensor,
9394
labels: torch.Tensor,
94-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
95+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
9596
"""Generate anomaly on features using thresholded perlin noise and Gaussian noise.
9697
9798
Also update GT masks and labels with new anomaly information.
9899
99100
Args:
100-
features (torch.Tensor): input features.
101-
mask (torch.Tensor): GT masks.
101+
input_features (torch.Tensor): input features. Set to None if we only need adapted.
102+
adapted_features (torch.Tensor): adapted input features.
103+
masks (torch.Tensor): GT masks.
102104
labels (torch.Tensor): GT labels.
103105
104106
Returns:
105-
perturbed features, updated GT masks and labels.
107+
perturbed features (if not None), perturbed adapted, updated GT masks and labels.
106108
"""
107-
b, _, h, w = features.shape
109+
b, _, h, w = masks.shape
108110

109111
# duplicate
110-
features = torch.cat((features, features))
111-
mask = torch.cat((mask, mask))
112+
adapted_features = torch.cat((adapted_features, adapted_features))
113+
mask = torch.cat((masks, masks))
112114
labels = torch.cat((labels, labels))
115+
# extended ssn case where cls gets non-adapted
116+
if input_features is not None:
117+
input_features = torch.cat((input_features, input_features))
113118

114119
noise = torch.normal(
115120
mean=self.noise_mean,
116121
std=self.noise_std,
117-
size=features.shape,
118-
device=features.device,
122+
size=adapted_features.shape,
123+
device=adapted_features.device,
119124
requires_grad=False,
120125
)
121126

@@ -126,15 +131,15 @@ def forward(
126131
1,
127132
h,
128133
w,
129-
device=features.device,
134+
device=adapted_features.device,
130135
requires_grad=False,
131136
)
132137

133138
# no overlap: don't apply to already anomalous regions (mask=1 -> bad)
134139
noise_mask = noise_mask * (1 - mask)
135140

136141
# shape of noise is [B * 2, 1, H, W]
137-
perlin_mask = self.generate_perlin(b * 2, h, w).to(features.device)
142+
perlin_mask = self.generate_perlin(b * 2, h, w).to(adapted_features.device)
138143
# only apply where perlin mask is 1
139144
noise_mask = noise_mask * perlin_mask
140145

@@ -150,6 +155,7 @@ def forward(
150155
labels = torch.where(labels > 0, torch.ones_like(labels), torch.zeros_like(labels))
151156

152157
# apply masked noise
153-
perturbed = features + noise * noise_mask
158+
perturbed_adapt = adapted_features + noise * noise_mask
159+
perturbed_feat = input_features + noise * noise_mask if input_features is not None else None
154160

155-
return perturbed, mask, labels
161+
return perturbed_feat, perturbed_adapt, mask, labels

src/anomalib/models/image/supersimplenet/lightning_model.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
# Copyright (C) 2024 Intel Corporation
1+
# Copyright (C) 2024-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection.
4+
"""SuperSimpleNet.
5+
6+
ICPR 2024 -
7+
SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection.
8+
9+
JIMS 2025 - No Label Left Behind: A Unified Surface Defect Detection Model for all Supervision Regimes
510
611
This module implements the SuperSimpleNet model for surface defect / anomaly detection.
712
SuperSimpleNet is a simple yet strong discriminative model consisting of a pretrained feature extractor with upscaling,
@@ -25,9 +30,13 @@
2530
2631
2732
Paper:
28-
Title: SuperSimpleNet: Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection.
33+
Original: SuperSimpleNet:
34+
Unifying Unsupervised and Supervised Learning for Fast and Reliable Surface Defect Detection.
2935
URL: https://arxiv.org/pdf/2408.03143
3036
37+
Extension: No label left behind: a unified surface defect detection model for all supervision regimes
38+
URL: https://link.springer.com/article/10.1007/s10845-025-02680-8
39+
3140
Notes:
3241
This implementation supports both unsupervised and supervised setting,
3342
but Anomalib currently supports only unsupervised learning.
@@ -64,6 +73,7 @@ class Supersimplenet(AnomalibModule):
6473
backbone (str): backbone name. IMPORTANT! use only backbones with torchvision V1 weights ending on ".tv".
6574
layers (list[str]): backbone layers utilised
6675
supervised (bool): whether the model will be trained in supervised mode. False by default (unsupervised).
76+
adapt_cls_features (bool): whether to adapt classification features (ICPR - True, JIMS - False (default)).
6777
pre_processor (PreProcessor | bool, optional): Pre-processor instance or
6878
flag to use default. Defaults to ``True``.
6979
post_processor (PostProcessor | bool, optional): Post-processor instance
@@ -80,6 +90,7 @@ def __init__(
8090
backbone: str = "wide_resnet50_2.tv_in1k", # IMPORTANT: use .tv weights, not tv2
8191
layers: list[str] = ["layer2", "layer3"], # noqa: B006
8292
supervised: bool = False,
93+
adapt_cls_features: bool = False,
8394
pre_processor: PreProcessor | bool = True,
8495
post_processor: PostProcessor | bool = True,
8596
evaluator: Evaluator | bool = True,
@@ -105,6 +116,7 @@ def __init__(
105116
backbone=backbone,
106117
layers=layers,
107118
stop_grad=stop_grad,
119+
adapt_cls_features=adapt_cls_features,
108120
)
109121
self.loss = SSNLoss()
110122

src/anomalib/models/image/supersimplenet/torch_model.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# SPDX-License-Identifier: MIT
55
#
66
# Modified
7-
# Copyright (C) 2024 Intel Corporation
7+
# Copyright (C) 2024-2025 Intel Corporation
88
# SPDX-License-Identifier: Apache-2.0
99

1010
"""PyTorch model for the SuperSimpleNet model implementation.
@@ -19,7 +19,6 @@
1919
import torch
2020
import torch.nn.functional as F # noqa: N812
2121
from torch import nn
22-
from torch.nn import Parameter
2322

2423
from anomalib.data import InferenceBatch
2524
from anomalib.models.components import GaussianBlur2d, TimmFeatureExtractor
@@ -36,6 +35,7 @@ class SupersimplenetModel(nn.Module):
3635
backbone (str): backbone name. IMPORTANT! use only backbones with torchvision V1 weights ending on ".tv".
3736
layers (list[str]): backbone layers utilised
3837
stop_grad (bool): whether to stop gradient from class. to seg. head.
38+
adapt_cls_features (bool): whether to adapt classification features (ICPR - True, JIMS - False (default)).
3939
"""
4040

4141
def __init__(
@@ -44,11 +44,13 @@ def __init__(
4444
backbone: str = "wide_resnet50_2.tv_in1k", # IMPORTANT: use .tv weights, not tv2
4545
layers: list[str] = ["layer2", "layer3"], # noqa: B006
4646
stop_grad: bool = True,
47+
adapt_cls_features: bool = False,
4748
) -> None:
4849
super().__init__()
4950
self.feature_extractor = UpscalingFeatureExtractor(backbone=backbone, layers=layers)
5051

5152
channels = self.feature_extractor.get_channels_dim()
53+
self.adapt_cls_features = adapt_cls_features
5254
self.adaptor = FeatureAdapter(channels)
5355
self.segdec = SegmentationDetectionModule(channel_dim=channels, stop_grad=stop_grad)
5456
self.anomaly_generator = AnomalyGenerator(noise_mean=0, noise_std=0.015, threshold=perlin_threshold)
@@ -80,23 +82,52 @@ def forward(
8082
adapted = self.adaptor(features)
8183

8284
if self.training:
83-
masks = self.downsample_mask(masks, *features.shape[-2:])
85+
if masks is None:
86+
if labels is not None and labels.any():
87+
msg = "Training with anomalous samples without GT masks is currently not supported!"
88+
raise RuntimeError(msg)
89+
b, _, h, w = features.shape
90+
masks = torch.zeros((b, 1, h, w), dtype=torch.float32, device=features.device)
91+
else:
92+
masks = self.downsample_mask(masks, *features.shape[-2:])
8493
# make linter happy :)
8594
if labels is not None:
8695
labels = labels.type(torch.float32)
8796

88-
features, masks, labels = self.anomaly_generator(
89-
adapted,
90-
masks,
91-
labels,
92-
)
93-
94-
anomaly_map, anomaly_score = self.segdec(features)
97+
if self.adapt_cls_features:
98+
# ICPR SuperSimpleNet - add noise to adapted only (since non-adapted are not used)
99+
_, noised_adapt, masks, labels = self.anomaly_generator(
100+
input_features=None,
101+
adapted_features=adapted,
102+
masks=masks,
103+
labels=labels,
104+
)
105+
seg_feats = noised_adapt
106+
cls_feats = noised_adapt
107+
else:
108+
# extension of SuperSimpleNet - add (same) noise to adapted and features
109+
noised_feat, noised_adapt, masks, labels = self.anomaly_generator(
110+
input_features=features,
111+
adapted_features=adapted,
112+
masks=masks,
113+
labels=labels,
114+
)
115+
seg_feats = noised_adapt
116+
cls_feats = noised_feat
117+
118+
anomaly_map, anomaly_score = self.segdec(seg_features=seg_feats, cls_features=cls_feats)
95119
return anomaly_map, anomaly_score, masks, labels
96120

97-
anomaly_map, anomaly_score = self.segdec(adapted)
121+
seg_feats = adapted
122+
# ICPR SuperSimpleNet - cls and seg both use adapted feat, JIMS extension SuperSimpleNet - adapt only seg feats
123+
cls_feats = adapted if self.adapt_cls_features else features
124+
125+
anomaly_map, anomaly_score = self.segdec(seg_features=seg_feats, cls_features=cls_feats)
98126
anomaly_map = self.anomaly_map_generator(anomaly_map, final_size=output_size)
99127

128+
anomaly_score = anomaly_score.sigmoid()
129+
anomaly_map = anomaly_map.sigmoid()
130+
100131
return InferenceBatch(anomaly_map=anomaly_map, pred_score=anomaly_score)
101132

102133
@staticmethod
@@ -296,33 +327,24 @@ def __init__(
296327

297328
self.apply(init_weights)
298329

299-
def get_params(self) -> tuple[list[Parameter], list[Parameter]]:
300-
"""Get segmentation and classification head parameters.
301-
302-
Returns:
303-
seg. head parameters and class. head parameters.
304-
"""
305-
seg_params = list(self.seg_head.parameters())
306-
dec_params = list(self.cls_conv.parameters()) + list(self.cls_fc.parameters())
307-
return seg_params, dec_params
308-
309-
def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
330+
def forward(self, seg_features: torch.Tensor, cls_features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
310331
"""Predict anomaly map and anomaly score.
311332
312333
Args:
313-
features: adapted features.
334+
seg_features: segmentation head features.
335+
cls_features: classification head features.
314336
315337
Returns:
316338
predicted anomaly map and score.
317339
"""
318340
# get anomaly map from seg head
319-
ano_map = self.seg_head(features)
341+
ano_map = self.seg_head(seg_features)
320342

321343
map_dec_copy = ano_map
322344
if self.stop_grad:
323345
map_dec_copy = map_dec_copy.detach()
324346
# dec conv layer takes feat + map
325-
mask_cat = torch.cat((features, map_dec_copy), dim=1)
347+
mask_cat = torch.cat((cls_features, map_dec_copy), dim=1)
326348
dec_out = self.cls_conv(mask_cat)
327349

328350
# conv block result pooling
@@ -340,7 +362,7 @@ def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
340362

341363
# final dec layer: conv channel max and avg and map max and avg
342364
dec_cat = torch.cat((dec_max, dec_avg, map_max, map_avg), dim=1).squeeze()
343-
ano_score = self.cls_fc(dec_cat).squeeze()
365+
ano_score = self.cls_fc(dec_cat).reshape(-1)
344366

345367
return ano_map, ano_score
346368

0 commit comments

Comments
 (0)