Skip to content

Commit 3412e9d

Browse files
fix(deploy): Fix discrepancy between lightning and standalone inferencers (#2843)
* Add test + modify anomalib module Signed-off-by: Ashwin Vaidya <ashwin.vaidya@intel.com>
1 parent e4aa9b2 commit 3412e9d

File tree

6 files changed

+132
-17
lines changed

6 files changed

+132
-17
lines changed

src/anomalib/models/components/base/anomalib_module.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,9 @@ def predict_step(
208208
) -> STEP_OUTPUT:
209209
"""Perform prediction step.
210210
211-
This method is called during the predict stage of training. By default,
212-
it calls the validation step.
211+
This method is called during the predict stage of training. It calls
212+
the model's forward method to ensure consistency with exported model behavior,
213+
then merges the predictions into the batch for post-processing.
213214
214215
Args:
215216
batch (Batch): Input batch
@@ -218,17 +219,19 @@ def predict_step(
218219
Defaults to ``0``.
219220
220221
Returns:
221-
STEP_OUTPUT: Model predictions
222+
STEP_OUTPUT: Updated batch with model predictions
222223
"""
223-
del dataloader_idx # These variables are not used.
224+
del dataloader_idx, batch_idx # These variables are not used.
224225

225-
return self.validation_step(batch, batch_idx)
226+
predictions = self.model(batch.image)
227+
return batch.update(**predictions._asdict())
226228

227229
def test_step(self, batch: Batch, batch_idx: int, *args, **kwargs) -> STEP_OUTPUT:
228230
"""Perform test step.
229231
230-
This method is called during the test stage of training. By default,
231-
it calls the predict step.
232+
This method is called during the test stage of training. It calls
233+
the model's forward method to ensure consistency with exported model behavior,
234+
then merges the predictions into the batch for post-processing.
232235
233236
Args:
234237
batch (Batch): Input batch
@@ -237,11 +240,12 @@ def test_step(self, batch: Batch, batch_idx: int, *args, **kwargs) -> STEP_OUTPU
237240
**kwargs: Additional keyword arguments (unused)
238241
239242
Returns:
240-
STEP_OUTPUT: Model predictions
243+
STEP_OUTPUT: Updated batch with model predictions
241244
"""
242-
del args, kwargs # These variables are not used.
245+
del args, kwargs, batch_idx # These variables are not used.
243246

244-
return self.predict_step(batch, batch_idx)
247+
predictions = self.model(batch.image)
248+
return batch.update(**predictions._asdict())
245249

246250
@property
247251
@abstractmethod

src/anomalib/models/image/fre/lightning_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,9 @@ def trainer_arguments(self) -> dict[str, Any]:
185185
Returns:
186186
dict[str, Any]: Dictionary of trainer arguments:
187187
- ``gradient_clip_val``: ``0``
188-
- ``max_epochs``: ``220``
189188
- ``num_sanity_val_steps``: ``0``
190189
"""
191-
return {"gradient_clip_val": 0, "max_epochs": 220, "num_sanity_val_steps": 0}
190+
return {"gradient_clip_val": 0, "num_sanity_val_steps": 0}
192191

193192
@property
194193
def learning_type(self) -> LearningType:

src/anomalib/models/image/vlm_ad/lightning_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ def validation_step(self, batch: ImageBatch, *args, **kwargs) -> ImageBatch:
168168
batch.pred_label = torch.tensor([1.0 if r.startswith("Y") else 0.0 for r in responses], device=self.device)
169169
return batch
170170

171+
def test_step(self, batch: ImageBatch, *args, **kwargs) -> ImageBatch: # type: ignore[override]
172+
"""Redirect to validation step."""
173+
return self.validation_step(batch, *args, **kwargs)
174+
175+
def predict_step(self, batch: ImageBatch, *args, **kwargs) -> ImageBatch: # type: ignore[override]
176+
"""Redirect to validation step."""
177+
return self.validation_step(batch, *args, **kwargs)
178+
171179
@property
172180
def learning_type(self) -> LearningType:
173181
"""Get the learning type of the model.

src/anomalib/post_processing/post_processor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,11 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch:
210210
if predictions.pred_score is None and predictions.anomaly_map is None:
211211
msg = "At least one of pred_score or anomaly_map must be provided."
212212
raise ValueError(msg)
213-
pred_score = predictions.pred_score or torch.amax(predictions.anomaly_map, dim=(-2, -1))
213+
pred_score = (
214+
predictions.pred_score
215+
if predictions.pred_score is not None
216+
else torch.amax(predictions.anomaly_map, dim=(-2, -1))
217+
)
214218

215219
if self.enable_normalization:
216220
pred_score = self._normalize(pred_score, self.image_min, self.image_max, self.image_threshold)

tests/integration/test_task_types.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99

1010
import pytest
1111
import torch
12+
from torch import nn
1213
from torchmetrics import Metric
1314

1415
from anomalib import LearningType
15-
from anomalib.data import AnomalibDataModule, Batch, Folder, ImageDataFormat
16+
from anomalib.data import AnomalibDataModule, Batch, Folder, ImageDataFormat, InferenceBatch
1617
from anomalib.engine import Engine
1718
from anomalib.metrics import AnomalibMetric, Evaluator
1819
from anomalib.models import AnomalibModule
@@ -21,13 +22,29 @@
2122
from tests.helpers.data import DummyImageDatasetGenerator
2223

2324

25+
class _DummyModel(nn.Module):
26+
"""Dummy model for testing."""
27+
28+
@staticmethod
29+
def forward(image_tensor: torch.Tensor) -> InferenceBatch:
30+
"""Dummy forward pass."""
31+
return InferenceBatch(
32+
pred_score=torch.rand(image_tensor.shape[0], device=image_tensor.device),
33+
anomaly_map=torch.rand(image_tensor.shape[0], *image_tensor.shape[-2:], device=image_tensor.device),
34+
)
35+
36+
2437
class DummyBaseModel(AnomalibModule):
2538
"""Dummy model for testing.
2639
2740
No training, and all auxiliary components default to None. This allows testing of the different components
2841
in isolation.
2942
"""
3043

44+
def __init__(self, *args, **kwargs) -> None:
45+
super().__init__(*args, **kwargs)
46+
self.model = _DummyModel()
47+
3148
def training_step(self, *args, **kwargs) -> None:
3249
"""Dummy training step."""
3350

@@ -66,7 +83,7 @@ class DummyClassificationModel(DummyBaseModel):
6683
def validation_step(self, batch: Batch, *args, **kwargs) -> Batch:
6784
"""Validation steps that returns random image-level scores."""
6885
del args, kwargs
69-
batch.pred_score = torch.rand(batch.batch_size, device=self.device)
86+
batch.pred_score = self.model(batch.image).pred_score
7087
return batch
7188

7289

@@ -79,8 +96,9 @@ class DummySegmentationModel(DummyBaseModel):
7996
def validation_step(self, batch: Batch, *args, **kwargs) -> Batch:
8097
"""Validation steps that returns random image- and pixel-level scores."""
8198
del args, kwargs
82-
batch.pred_score = torch.rand(batch.batch_size, device=self.device)
83-
batch.anomaly_map = torch.rand(batch.batch_size, *batch.image.shape[-2:], device=self.device)
99+
result = self.model(batch.image)
100+
batch.pred_score = result.pred_score
101+
batch.anomaly_map = result.anomaly_map
84102
return batch
85103

86104

tests/unit/deploy/test_inferencer.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import numpy as np
1010
import pytest
1111
import torch
12+
from PIL import Image
13+
from torch.utils.data import DataLoader
1214

15+
from anomalib.data import ImageBatch, NumpyImageBatch, PredictDataset
1316
from anomalib.deploy import ExportType, OpenVINOInferencer, TorchInferencer
1417
from anomalib.engine import Engine
1518
from anomalib.models import Padim
@@ -108,3 +111,82 @@ def test_openvino_inference(ckpt_path: Callable[[str], Path]) -> None:
108111
for image in openvino_dataloader():
109112
prediction = openvino_inferencer.predict(image)
110113
assert 0.0 <= prediction.pred_score <= 1.0 # confirm if predicted scores are normalized
114+
115+
116+
def compare_predictions(
117+
pred1: ImageBatch | NumpyImageBatch,
118+
pred2: ImageBatch | NumpyImageBatch,
119+
tolerance: float = 1e-3,
120+
) -> None:
121+
"""Compare predictions from two different inference methods."""
122+
score1 = pred1.pred_score if hasattr(pred1, "pred_score") else None
123+
score2 = pred2.pred_score if hasattr(pred2, "pred_score") else None
124+
125+
map1 = pred1.anomaly_map if hasattr(pred1, "anomaly_map") else None
126+
map2 = pred2.anomaly_map if hasattr(pred2, "anomaly_map") else None
127+
128+
if isinstance(map1, torch.Tensor):
129+
map1 = map1.cpu().numpy()
130+
if isinstance(map2, torch.Tensor):
131+
map2 = map2.cpu().numpy()
132+
133+
if score1 is None and score2 is None and map1 is None and map2 is None:
134+
pytest.fail("No predictions found")
135+
136+
if score1 is not None and score2 is not None:
137+
if isinstance(score1, torch.Tensor):
138+
score1 = score1.cpu().item()
139+
if isinstance(score2, torch.Tensor):
140+
score2 = score2.cpu().item()
141+
142+
if score1 is not None and score2 is not None:
143+
score_diff = abs(score1 - score2)
144+
if score_diff > tolerance:
145+
pytest.fail(f"Anomaly score absolute difference: {score_diff:.3f}")
146+
147+
if map1 is not None and map2 is not None:
148+
map_diff = np.abs(map1 - map2)
149+
if np.mean(map_diff) > tolerance:
150+
pytest.fail(f"Anomaly map mean absolute difference: {np.mean(map_diff):.3f}")
151+
152+
153+
def test_inference_similarity(
154+
ckpt_path: Callable[[str], Path],
155+
project_path: Path,
156+
tmp_path: Path,
157+
monkeypatch: pytest.MonkeyPatch,
158+
) -> None:
159+
"""Test inference result."""
160+
# Set TRUST_REMOTE_CODE environment variable for the test
161+
monkeypatch.setenv("TRUST_REMOTE_CODE", "1")
162+
163+
rng = np.random.default_rng()
164+
image = rng.integers(0, 255, (256, 256, 3), dtype=np.uint8)
165+
image = Image.fromarray(image)
166+
test_image_path = tmp_path / "test_image.png"
167+
image.save(test_image_path)
168+
169+
model = Padim()
170+
engine = Engine(logger=False, default_root_dir=project_path, devices=1)
171+
172+
predict_dataset = PredictDataset(test_image_path)
173+
predict_dataloader = DataLoader(
174+
predict_dataset,
175+
batch_size=1,
176+
collate_fn=predict_dataset.collate_fn,
177+
pin_memory=True,
178+
)
179+
engine_pred: list[ImageBatch] = engine.predict(model, dataloaders=predict_dataloader, ckpt_path=ckpt_path("Padim"))
180+
engine_pred = engine_pred[0]
181+
182+
torch_path = engine.export(model, export_type=ExportType.TORCH, export_root=project_path)
183+
torch_inferencer = TorchInferencer(torch_path, device="cpu")
184+
torch_pred = torch_inferencer.predict(test_image_path)
185+
186+
openvino_path = engine.export(model, export_type=ExportType.OPENVINO, export_root=project_path)
187+
openvino_inferencer = OpenVINOInferencer(openvino_path, device="CPU")
188+
openvino_pred = openvino_inferencer.predict(test_image_path)
189+
190+
compare_predictions(engine_pred, torch_pred)
191+
compare_predictions(engine_pred, openvino_pred)
192+
compare_predictions(torch_pred, openvino_pred)

0 commit comments

Comments
 (0)