Skip to content

Commit 163f8bc

Browse files
authored
🐛 fix(metrics): disable mps for torch metrics (#3019)
🐛 fix(metrics): disable mps for torch metrics (#3018) * move to cpu when device is mps --------- Signed-off-by: Ma, Xiangxiang <xiangxiang.ma@intel.com>
1 parent bb2c066 commit 163f8bc

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

src/anomalib/metrics/threshold/f1_adaptive_threshold.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
"""
3535

3636
import logging
37+
from collections.abc import Generator
38+
from contextlib import contextmanager
3739

3840
import torch
3941

@@ -45,6 +47,30 @@
4547
logger = logging.getLogger(__name__)
4648

4749

50+
@contextmanager
51+
def handle_mac(metric: "_F1AdaptiveThreshold") -> Generator[None, None, None]:
52+
"""Temporarily move tensors to CPU on macOS/MPS and restore after.
53+
54+
This context manager checks whether the provided metric instance has
55+
predictions on an MPS device. If so, it moves both predictions and
56+
targets to CPU for the duration of the context and restores them to
57+
the original device on exit.
58+
"""
59+
# Check if we have any predictions and if they're on MPS
60+
if bool(metric.preds) and metric.preds[0].is_mps:
61+
original_device = metric.preds[0].device
62+
metric.preds = [pred.cpu() for pred in metric.preds]
63+
metric.target = [target.cpu() for target in metric.target]
64+
try:
65+
yield
66+
finally:
67+
# Restore to original device
68+
metric.preds = [pred.to(original_device) for pred in metric.preds]
69+
metric.target = [target.to(original_device) for target in metric.target]
70+
else:
71+
yield
72+
73+
4874
class _F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold):
4975
"""Adaptive threshold that maximizes F1 score.
5076
@@ -94,7 +120,9 @@ def compute(self) -> torch.Tensor:
94120
)
95121
logging.warning(msg)
96122

97-
precision, recall, thresholds = super().compute()
123+
with handle_mac(self):
124+
precision, recall, thresholds = super().compute()
125+
98126
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
99127

100128
# account for special case where recall is 1.0 even for the highest threshold.

0 commit comments

Comments
 (0)