
Commit af32ad6

🚀 feat(model): add UniNet (#2797)
* stash changes
* Add initial code
* Working model
* Working model
* Fix pre-commit
* Minor refactor
* Refactor components
* single line
* de_resnet->resnet_decoder
* bug fixes
* openvino export + pr comments
* Add documentation
* Remove else branch
* Move copyright
* PR comments
  - Add comments in AttentionBlock
  - Refactor torch module

Signed-off-by: Ashwin Vaidya <ashwinnitinvaidya@gmail.com>
Signed-off-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
1 parent dfbffe2 commit af32ad6

File tree: 19 files changed (+1294 −32 lines)

Lines changed: 37 additions & 0 deletions
# UniNet

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.lightning_model
    :members:
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.torch_model
    :members:
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.loss
    :members:
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.anomaly_map
    :members:
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.attention_bottleneck
    :members:
    :show-inheritance:
```

```{eval-rst}
.. automodule:: anomalib.models.image.uninet.components.dfs
    :members:
    :show-inheritance:
```

examples/configs/model/uninet.yaml

Lines changed: 15 additions & 0 deletions
model:
  class_path: anomalib.models.UniNet
  init_args:
    student_backbone: wide_resnet50_2
    teacher_backbone: wide_resnet50_2
    temperature: 0.1

trainer:
  max_epochs: 100
  callbacks:
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        patience: 20
        monitor: image_AUROC
        mode: max
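For reference, a minimal Python sketch that mirrors this configuration through the anomalib API (illustrative only: `Engine`, `MVTecAD`, and the chosen category are assumed from the usual anomalib entry points and are not part of this diff):

```python
from lightning.pytorch.callbacks import EarlyStopping

from anomalib.data import MVTecAD
from anomalib.engine import Engine
from anomalib.models import UniNet

# Same settings as examples/configs/model/uninet.yaml.
model = UniNet(
    student_backbone="wide_resnet50_2",
    teacher_backbone="wide_resnet50_2",
    temperature=0.1,
)
datamodule = MVTecAD(category="bottle")  # category is an arbitrary example
engine = Engine(
    max_epochs=100,
    callbacks=[EarlyStopping(monitor="image_AUROC", mode="max", patience=20)],
)
engine.fit(model=model, datamodule=datamodule)
```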

src/anomalib/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@
     Stfpm,
     Supersimplenet,
     Uflow,
+    UniNet,
     VlmAd,
     WinClip,
 )
@@ -110,6 +111,7 @@ class UnknownModelError(ModuleNotFoundError):
     "Stfpm",
     "Supersimplenet",
     "Uflow",
+    "UniNet",
     "VlmAd",
     "WinClip",
     "AiVad",
Lines changed: 17 additions & 0 deletions
"""Common backbone models.

Example:
    >>> from anomalib.models.components.backbone import get_decoder
    >>> decoder = get_decoder("resnet18")

See Also:
    - :func:`anomalib.models.components.backbone.resnet_decoder`:
        Decoder network implementation
"""

# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .resnet_decoder import get_decoder

__all__ = ["get_decoder"]

src/anomalib/models/image/reverse_distillation/components/de_resnet.py renamed to src/anomalib/models/components/backbone/resnet_decoder.py

Lines changed: 24 additions & 26 deletions
@@ -7,9 +7,9 @@
 # Copyright (C) 2022-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0

-"""PyTorch model defining the decoder network for Reverse Distillation.
+"""PyTorch model defining the decoder network for Reverse Distillation and UniNet.

-This module implements the decoder network used in the Reverse Distillation model
+This module implements the decoder network used in the Reverse Distillation and UniNet models
 architecture. The decoder reconstructs features from the bottleneck representation
 back to the original feature space.

@@ -19,17 +19,15 @@
 - Full decoder network architecture

 Example:
-    >>> from anomalib.models.image.reverse_distillation.components.de_resnet import (
+    >>> from anomalib.models.components.backbone.resnet_decoder import (
     ...     get_decoder
     ... )
     >>> decoder = get_decoder()
     >>> features = torch.randn(32, 512, 28, 28)
     >>> reconstructed = decoder(features)

 See Also:
-    - :class:`anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel`:
-        Main model implementation using this decoder
-    - :class:`anomalib.models.image.reverse_distillation.components.DecoderBasicBlock`:
+    - :class:`anomalib.models.components.backbone.resnet_decoder.DecoderBasicBlock`:
         Basic building block for the decoder network
 """

@@ -157,8 +155,8 @@ class DecoderBottleneck(nn.Module):
     """Bottleneck block for the decoder network.

     This module implements a bottleneck block used in the decoder part of the Reverse
-    Distillation model. It performs upsampling and feature reconstruction through a series of
-    convolutional layers.
+    Distillation and UniNet models. It performs upsampling and feature reconstruction
+    through a series of convolutional layers.

     The block consists of three convolution layers:
         1. 1x1 conv to adjust channels
@@ -186,7 +184,7 @@ class DecoderBottleneck(nn.Module):
     Example:
         >>> import torch
-        >>> from anomalib.models.image.reverse_distillation.components.de_resnet import (
+        >>> from anomalib.models.components.backbone.resnet_decoder import (
         ...     DecoderBottleneck
         ... )
         >>> layer = DecoderBottleneck(256, 64)
@@ -269,7 +267,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor:
         return self.relu(out)


-class ResNet(nn.Module):
+class ResNetDecoder(nn.Module):
     """Decoder ResNet model for feature reconstruction.

     This module implements a decoder version of the ResNet architecture, which
@@ -297,11 +295,11 @@ class ResNet(nn.Module):
         layer to use. If ``None``, uses ``BatchNorm2d``. Defaults to ``None``.

     Example:
-        >>> from anomalib.models.image.reverse_distillation.components import (
+        >>> from anomalib.models.components.backbone.resnet_decoder import (
         ...     DecoderBasicBlock,
-        ...     ResNet
+        ...     ResNetDecoder
         ... )
-        >>> model = ResNet(
+        >>> model = ResNetDecoder(
         ...     block=DecoderBasicBlock,
         ...     layers=[2, 2, 2, 2]
         ... )
@@ -437,63 +435,63 @@ def forward(self, batch: torch.Tensor) -> list[torch.Tensor]:
         return [feature_c, feature_b, feature_a]


-def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNet:
-    return ResNet(block, layers, **kwargs)
+def _resnet(block: type[DecoderBasicBlock | DecoderBottleneck], layers: list[int], **kwargs) -> ResNetDecoder:
+    return ResNetDecoder(block, layers, **kwargs)


-def de_resnet18() -> ResNet:
+def de_resnet18() -> ResNetDecoder:
     """ResNet-18 model."""
     return _resnet(DecoderBasicBlock, [2, 2, 2, 2])


-def de_resnet34() -> ResNet:
+def de_resnet34() -> ResNetDecoder:
     """ResNet-34 model."""
     return _resnet(DecoderBasicBlock, [3, 4, 6, 3])


-def de_resnet50() -> ResNet:
+def de_resnet50() -> ResNetDecoder:
     """ResNet-50 model."""
     return _resnet(DecoderBottleneck, [3, 4, 6, 3])


-def de_resnet101() -> ResNet:
+def de_resnet101() -> ResNetDecoder:
     """ResNet-101 model."""
     return _resnet(DecoderBottleneck, [3, 4, 23, 3])


-def de_resnet152() -> ResNet:
+def de_resnet152() -> ResNetDecoder:
     """ResNet-152 model."""
     return _resnet(DecoderBottleneck, [3, 8, 36, 3])


-def de_resnext50_32x4d() -> ResNet:
+def de_resnext50_32x4d() -> ResNetDecoder:
     """ResNeXt-50 32x4d model."""
     return _resnet(DecoderBottleneck, [3, 4, 6, 3], groups=32, width_per_group=4)


-def de_resnext101_32x8d() -> ResNet:
+def de_resnext101_32x8d() -> ResNetDecoder:
     """ResNeXt-101 32x8d model."""
     return _resnet(DecoderBottleneck, [3, 4, 23, 3], groups=32, width_per_group=8)


-def de_wide_resnet50_2() -> ResNet:
+def de_wide_resnet50_2() -> ResNetDecoder:
     """Wide ResNet-50-2 model."""
     return _resnet(DecoderBottleneck, [3, 4, 6, 3], width_per_group=128)


-def de_wide_resnet101_2() -> ResNet:
+def de_wide_resnet101_2() -> ResNetDecoder:
     """Wide ResNet-101-2 model."""
     return _resnet(DecoderBottleneck, [3, 4, 23, 3], width_per_group=128)


-def get_decoder(name: str) -> ResNet:
+def get_decoder(name: str) -> ResNetDecoder:
     """Get decoder model based on the name of the backbone.

     Args:
         name (str): Name of the backbone.

     Returns:
-        ResNet: Decoder ResNet architecture.
+        ResNetDecoder: Decoder ResNet architecture.
     """
     decoder_map = {
         "resnet18": de_resnet18,

src/anomalib/models/image/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,7 @@
 from .stfpm import Stfpm
 from .supersimplenet import Supersimplenet
 from .uflow import Uflow
+from .uninet import UniNet
 from .vlm_ad import VlmAd
 from .winclip import WinClip

@@ -83,6 +84,7 @@
     "Stfpm",
     "Supersimplenet",
     "Uflow",
+    "UniNet",
     "VlmAd",
     "WinClip",
     "Dinomaly",

src/anomalib/models/image/reverse_distillation/components/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -10,7 +10,6 @@
 through distillation and reconstruction:

 - Bottleneck layer: Compresses features into a lower dimensional space
-- Decoder network: Reconstructs features from the bottleneck representation

 Example:
     >>> from anomalib.models.image.reverse_distillation.components import (
@@ -23,11 +22,8 @@
 See Also:
     - :func:`anomalib.models.image.reverse_distillation.components.bottleneck`:
         Bottleneck layer implementation
-    - :func:`anomalib.models.image.reverse_distillation.components.de_resnet`:
-        Decoder network implementation
 """

 from .bottleneck import get_bottleneck_layer
-from .de_resnet import get_decoder

-__all__ = ["get_bottleneck_layer", "get_decoder"]
+__all__ = ["get_bottleneck_layer"]

src/anomalib/models/image/reverse_distillation/torch_model.py

Lines changed: 2 additions & 1 deletion
@@ -39,9 +39,10 @@

 from anomalib.data import InferenceBatch
 from anomalib.models.components import TimmFeatureExtractor
+from anomalib.models.components.backbone import get_decoder

 from .anomaly_map import AnomalyMapGenerationMode, AnomalyMapGenerator
-from .components import get_bottleneck_layer, get_decoder
+from .components import get_bottleneck_layer

 if TYPE_CHECKING:
     from anomalib.data.utils.tiler import Tiler
Lines changed: 29 additions & 0 deletions
Copyright (c) 2025 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Some files in this folder are based on the original UniNet implementation by Shun Wei, Jielin Jiang, and Xiaolong Xu.

Original license:
-----------------

MIT License

Copyright (c) 2025 Shun Wei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Lines changed: 65 additions & 0 deletions
# UniNet: Unified Contrastive Learning Framework for Anomaly Detection

This is the implementation of the UniNet model, a unified contrastive learning framework for anomaly detection presented at CVPR 2025.

Model Type: Classification and Segmentation

## Description

UniNet is a contrastive learning-based anomaly detection model that uses a teacher-student architecture with attention bottleneck mechanisms. The model is designed for diverse domains and supports both supervised and unsupervised anomaly detection scenarios. It focuses on multi-class anomaly detection and leverages domain-related feature selection to improve performance across different categories.

The model consists of:

- **Teacher Networks**: Pre-trained backbone networks that provide reference features for normal samples
- **Student Network**: A decoder network that learns to reconstruct normal patterns
- **Attention Bottleneck**: Mechanisms that help focus on relevant features
- **Domain-Related Feature Selection**: Adaptive feature selection for different domains
- **Contrastive Loss**: Temperature-controlled similarity computation between student and teacher features

During training, the student network learns to match teacher features for normal samples while being trained to distinguish anomalous patterns through contrastive learning. The model uses a weighted decision mechanism during inference to combine multi-scale features for final anomaly scoring.
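The actual loss lives in `components/loss.py` (documented above but not shown in this diff). Purely as an illustration of the temperature-controlled similarity idea described here, and explicitly not the UniNet loss itself, a toy version over paired teacher/student feature maps might look like:

```python
import torch
import torch.nn.functional as F


def toy_similarity_loss(
    student_feats: list[torch.Tensor],
    teacher_feats: list[torch.Tensor],
    temperature: float = 0.1,
) -> torch.Tensor:
    """Toy temperature-scaled similarity loss over multi-scale feature maps."""
    losses = []
    for student, teacher in zip(student_feats, teacher_feats, strict=True):
        # Cosine similarity per spatial location: shape (B, H, W), values in [-1, 1].
        similarity = F.cosine_similarity(student, teacher, dim=1)
        # Penalise dissimilarity, sharpened by the temperature.
        losses.append((1.0 - similarity).div(temperature).mean())
    return torch.stack(losses).mean()


# Example with random multi-scale features.
student = [torch.randn(2, 256, 64, 64), torch.randn(2, 512, 32, 32)]
teacher = [torch.randn(2, 256, 64, 64), torch.randn(2, 512, 32, 32)]
loss = toy_similarity_loss(student, teacher, temperature=0.1)
```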
## Architecture

The UniNet architecture leverages contrastive learning with teacher-student networks, incorporating attention bottlenecks and domain-specific feature selection for robust anomaly detection across diverse domains.

## Usage

`anomalib train --model UniNet --data MVTecAD --data.category <category>`
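The YAML added under `examples/configs/model/uninet.yaml` in this commit should work with the config-driven entry point as well, e.g. `anomalib train --config examples/configs/model/uninet.yaml --data MVTecAD --data.category <category>` (invocation assumed from the standard anomalib CLI; only the YAML file itself is part of this diff).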
## Benchmark

All results gathered with seed `42`.

## [MVTecAD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

### Image-Level AUC

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.956 | 0.999 | 0.982 | 0.939 | 0.896 | 0.996 | 0.999 | 1.000 | 1.000 | 0.816 | 0.919 | 0.970 | 1.000 | 0.984 | 0.993 | 0.945 |

### Pixel-Level AUC

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.976 | 0.989 | 0.983 | 0.985 | 0.973 | 0.992 | 0.987 | 0.993 | 0.984 | 0.964 | 0.992 | 0.965 | 0.992 | 0.923 | 0.961 | 0.984 |

### Image F1 Score

| Avg | Bottle | Cable | Capsule | Carpet | Grid | Hazelnut | Leather | Metal Nut | Pill | Screw | Tile | Toothbrush | Transistor | Wood | Zipper |
| :---: | :----: | :---: | :-----: | :----: | :---: | :------: | :-----: | :-------: | :---: | :---: | :---: | :--------: | :--------: | :---: | :----: |
| 0.957 | 0.984 | 0.944 | 0.964 | 0.883 | 0.973 | 0.986 | 0.995 | 0.989 | 0.921 | 0.905 | 0.946 | 0.983 | 0.961 | 0.974 | 0.959 |

## Model Features

- **Contrastive Learning**: Uses temperature-controlled contrastive loss for effective anomaly detection
- **Multi-Domain Support**: Designed to work across diverse domains with domain-related feature selection
- **Flexible Training**: Supports both supervised and unsupervised training modes
- **Attention Mechanisms**: Incorporates attention bottlenecks for focused feature learning
- **Multi-Scale Features**: Leverages multi-scale feature matching for robust detection

## Parameters

- `student_backbone`: Backbone model for student network (default: "wide_resnet50_2")
- `teacher_backbone`: Backbone model for teacher network (default: "wide_resnet50_2")
- `temperature`: Temperature parameter for contrastive loss (default: 0.1)
