|
1 | | -# Copyright (C) 2023-2024 Intel Corporation |
| 1 | +# Copyright (C) 2023-2025 Intel Corporation |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | 4 | """DSR - A Dual Subspace Re-Projection Network for Surface Anomaly Detection. |
|
39 | 39 |
|
40 | 40 | import torch |
41 | 41 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler |
42 | | -from torchvision.transforms.v2 import Compose, Resize |
| 42 | +from torchvision.transforms.v2 import Compose, Normalize, Resize |
43 | 43 |
|
44 | 44 | from anomalib import LearningType |
45 | 45 | from anomalib.data import Batch |
| 46 | +from anomalib.data.transforms.utils import extract_transforms_by_type |
46 | 47 | from anomalib.data.utils import DownloadInfo, download_and_extract |
47 | 48 | from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator |
48 | 49 | from anomalib.metrics import Evaluator |
@@ -182,7 +183,19 @@ def configure_optimizers( |
182 | 183 | return ({"optimizer": optimizer_d, "lr_scheduler": scheduler_d}, {"optimizer": optimizer_u}) |
183 | 184 |
|
184 | 185 | def on_train_start(self) -> None: |
185 | | - """Load pretrained weights of the discrete model when starting training.""" |
| 186 | + """Set up model before training begins. |
| 187 | +
|
| 188 | + Performs the following steps: |
| 189 | + 1. Validates that pre_processor uses no normalization |
| 190 | + 2. Load pretrained weights of the discrete model |
| 191 | +
|
| 192 | + Raises: |
| 193 | + ValueError: If transforms contain normalization. |
| 194 | + """ |
| 195 | + if self.pre_processor and extract_transforms_by_type(self.pre_processor.transform, Normalize): |
| 196 | + msg = "Transforms for DSR should not contain Normalize." |
| 197 | + raise ValueError(msg) |
| 198 | + |
186 | 199 | ckpt: Path = self.prepare_pretrained_model() |
187 | 200 | self.model.load_pretrained_discrete_model_weights(ckpt, self.device) |
188 | 201 |
|
|
0 commit comments