34 | 34 | """ |
35 | 35 |
36 | 36 | import logging |
| 37 | +from collections.abc import Generator |
| 38 | +from contextlib import contextmanager |
37 | 39 |
38 | 40 | import torch |
39 | 41 |
45 | 47 | logger = logging.getLogger(__name__) |
46 | 48 |
47 | 49 |
| 50 | +@contextmanager |
| 51 | +def handle_mac(metric: "_F1AdaptiveThreshold") -> Generator[None, None, None]: |
| 52 | + """Temporarily move tensors to CPU on macOS/MPS and restore after. |
| 53 | +
| 54 | + This context manager checks whether the provided metric instance has |
| 55 | + predictions on an MPS device. If so, it moves both predictions and |
| 56 | + targets to CPU for the duration of the context and restores them to |
| 57 | + the original device on exit. |
| 58 | + """ |
| 59 | + # Check if we have any predictions and if they're on MPS |
| 60 | + if bool(metric.preds) and metric.preds[0].is_mps: |
| 61 | + original_device = metric.preds[0].device |
| 62 | + metric.preds = [pred.cpu() for pred in metric.preds] |
| 63 | + metric.target = [target.cpu() for target in metric.target] |
| 64 | + try: |
| 65 | + yield |
| 66 | + finally: |
| 67 | + # Restore to original device |
| 68 | + metric.preds = [pred.to(original_device) for pred in metric.preds] |
| 69 | + metric.target = [target.to(original_device) for target in metric.target] |
| 70 | + else: |
| 71 | + yield |
| 72 | + |
| 73 | + |
48 | 74 | class _F1AdaptiveThreshold(BinaryPrecisionRecallCurve, Threshold): |
49 | 75 | """Adaptive threshold that maximizes F1 score. |
50 | 76 |
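The context manager added above can be exercised on its own. Below is a minimal sketch, assuming handle_mac from this diff is in scope; the _FakeMetric class is a hypothetical stand-in, since handle_mac only relies on the preds and target state lists of the metric it receives.

import torch


class _FakeMetric:
    """Hypothetical stand-in exposing the two state lists that handle_mac touches."""

    def __init__(self, device: torch.device) -> None:
        self.preds = [torch.rand(4, device=device)]
        self.target = [torch.randint(0, 2, (4,), device=device)]


device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
metric = _FakeMetric(device)

with handle_mac(metric):
    # Inside the block the state is on CPU (moved there if it started on MPS),
    # so operations without MPS kernels can run safely.
    assert metric.preds[0].device.type == "cpu"

# On exit the tensors are back on the device they started on.
assert metric.preds[0].device.type == device.type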
@@ -94,7 +120,9 @@ def compute(self) -> torch.Tensor: |
94 | 120 | ) |
95 | 121 | logging.warning(msg) |
96 | 122 |
97 | | - precision, recall, thresholds = super().compute() |
| 123 | + with handle_mac(self): |
| 124 | + precision, recall, thresholds = super().compute() |
| 125 | + |
98 | 126 | f1_score = (2 * precision * recall) / (precision + recall + 1e-10) |
99 | 127 |
100 | 128 | # account for special case where recall is 1.0 even for the highest threshold. |
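For reference, the f1_score line above is the harmonic mean of precision and recall evaluated at every candidate threshold, with a small epsilon guarding against division by zero; per the class docstring, the adaptive threshold is then the one that maximizes this score. A toy sketch with made-up values:

import torch

precision = torch.tensor([0.50, 0.75, 1.00])  # illustrative values only
recall = torch.tensor([1.00, 0.60, 0.00])
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
# tensor([0.6667, 0.6667, 0.0000]) -> the threshold with the highest F1 is kept.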