Skip to content

Commit f762e41

Browse files
authored
🚀 feat(model): add automatic download of DTD dataset to DRAEM model (#2866)
Added automatic download of DTD dataset to DRAEM model
1 parent f8f14ce commit f762e41

File tree

5 files changed

+18
-8
lines changed

5 files changed

+18
-8
lines changed

‎docs/source/examples‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
../../examples
1+
../../examples

‎examples/configs/model/draem.yaml‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ model:
44
beta: [0.1, 1.0]
55
enable_sspcab: false
66
sspcab_lambda: 0.1
7-
anomaly_source_path: null
7+
dtd_dir: ./datasets/dtd
88

99
trainer:
1010
max_epochs: 700

‎src/anomalib/data/utils/generators/perlin.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class PerlinAnomalyGenerator(v2.Transform):
200200

201201
def __init__(
202202
self,
203-
anomaly_source_path: str | None = None,
203+
anomaly_source_path: Path | str | None = None,
204204
probability: float = 0.5,
205205
blend_factor: float | tuple[float, float] = (0.2, 1.0),
206206
rotation_range: tuple[float, float] = (-90, 90),

‎src/anomalib/models/image/draem/README.md‎

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ Model Type: Segmentation
88

99
DRAEM is a reconstruction based algorithm that consists of a reconstructive subnetwork and a discriminative subnetwork. DRAEM is trained on simulated anomaly images, generated by augmenting normal input images from the training set with a random Perlin noise mask extracted from an unrelated source of image data. The reconstructive subnetwork is an autoencoder architecture that is trained to reconstruct the original input images from the augmented images. The reconstructive submodel is trained using a combination of L2 loss and Structural Similarity loss. The input of the discriminative subnetwork consists of the channel-wise concatenation of the (augmented) input image and the output of the reconstructive subnetwork. The output of the discriminative subnetwork is an anomaly map that contains the predicted anomaly scores for each pixel location. The discriminative subnetwork is trained using Focal Loss.
1010

11-
For optimal results, DRAEM requires specifying the path to a folder of image data that will be used as the source of the anomalous pixel regions in the simulated anomaly images. The path can be specified by editing the value of the `model.anomaly_source_path` parameter in the `config.yaml` file. The authors of the original paper recommend using the [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset as anomaly source.
12-
1311
## Architecture
1412

1513
![DRAEM Architecture](/docs/source/images/draem/architecture.png "DRAEM Architecture")

‎src/anomalib/models/image/draem/lightning_model.py‎

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
from collections.abc import Callable
17+
from pathlib import Path
1718
from typing import Any
1819

1920
import torch
@@ -23,6 +24,7 @@
2324

2425
from anomalib import LearningType
2526
from anomalib.data import Batch
27+
from anomalib.data.utils import DownloadInfo, download_and_extract
2628
from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
2729
from anomalib.metrics import Evaluator
2830
from anomalib.models.components import AnomalibModule
@@ -35,6 +37,12 @@
3537

3638
__all__ = ["Draem"]
3739

40+
DTD_DOWNLOAD_INFO = DownloadInfo(
41+
name="dtd-r1.0.1.tar.gz",
42+
url="https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz",
43+
hashsum="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205",
44+
)
45+
3846

3947
class Draem(AnomalibModule):
4048
"""DRÆM.
@@ -47,6 +55,8 @@ class Draem(AnomalibModule):
4755
2. A discriminative network that learns to identify anomalous regions
4856
4957
Args:
58+
dtd_dir (Path | str): Directory path for the DTD dataset for anomaly deneration.
59+
Defaults to ``./datasets/dtd``.
5060
enable_sspcab (bool, optional): Enable SSPCAB training.
5161
Defaults to ``False``.
5262
sspcab_lambda (float, optional): Weight factor for SSPCAB loss.
@@ -73,9 +83,9 @@ class Draem(AnomalibModule):
7383

7484
def __init__(
7585
self,
86+
dtd_dir: Path | str = "./datasets/dtd",
7687
enable_sspcab: bool = False,
7788
sspcab_lambda: float = 0.1,
78-
anomaly_source_path: str | None = None,
7989
beta: float | tuple[float, float] = (0.1, 1.0),
8090
pre_processor: PreProcessor | bool = True,
8191
post_processor: PostProcessor | bool = True,
@@ -88,8 +98,10 @@ def __init__(
8898
evaluator=evaluator,
8999
visualizer=visualizer,
90100
)
91-
92-
self.augmenter = PerlinAnomalyGenerator(anomaly_source_path=anomaly_source_path, blend_factor=beta)
101+
dtd_dir = Path(dtd_dir)
102+
if not dtd_dir.is_dir():
103+
download_and_extract(dtd_dir, DTD_DOWNLOAD_INFO)
104+
self.augmenter = PerlinAnomalyGenerator(anomaly_source_path=dtd_dir, blend_factor=beta)
93105
self.model = DraemModel(sspcab=enable_sspcab)
94106
self.loss = DraemLoss()
95107
self.sspcab = enable_sspcab

0 commit comments

Comments
 (0)