
Commit 74a2421

Authored by andrerom and SunMarc
Add bfloat16 support detection for MPS in is_torch_bf16_gpu_available() (#40458)
* Add bfloat16 support detection for MPS (Apple Silicon) in is_torch_bf16_gpu_available(). bfloat16 appears to have been supported for a few years now in Metal and torch.mps. Allow it instead of throwing "Your setup doesn't support bf16/gpu." from TrainingArguments on bf16 usage.

* Check bf16 support for MPS using a torch method. The method actually exists: https://github.com/pytorch/pytorch/blob/5859edf1130cec5a021ace5d5b0e18144808f757/torch/_dynamo/device_interface.py#L519 It simply checks whether you are on macOS 14 or higher.

* Document Metal emulation for bf16 support. Add a note that Metal emulates bf16 on hardware without native support (M1/M2).

* Update the bf16 support check for the MPS backend. is_bf16_supported() is not exposed even though it is defined on the MPS device interface, so use the same approach as in the accelerate PR.

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
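The commit message above describes a strict branch order in the availability check: CUDA first, then NPU, then MPS, else False. A minimal sketch of that control flow, with stand-in arguments replacing the runtime probes (`bf16_gpu_available` and its parameters are hypothetical names for illustration, not part of the transformers API):

```python
def bf16_gpu_available(cuda_bf16_ok, npu_bf16, mps_macos_14_plus):
    """Mirror the branch order of is_torch_bf16_gpu_available().

    Each argument stands in for a runtime probe; None means the
    corresponding backend is not present on this machine.
    """
    if cuda_bf16_ok:
        # CUDA device with bf16 support wins immediately
        return True
    if npu_bf16 is not None:
        # Ascend NPU: defer to its own bf16 capability flag
        return bool(npu_bf16)
    if mps_macos_14_plus is not None:
        # MPS: the new branch, gated on macOS >= 14
        return bool(mps_macos_14_plus)
    return False
```

This makes the precedence explicit: an MPS machine only reaches the macOS version check if neither a CUDA nor an NPU backend claimed bf16 first.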
1 parent ffdd10f commit 74a2421

File tree

1 file changed: +3 -0 lines changed

src/transformers/utils/import_utils.py

Lines changed: 3 additions & 0 deletions
@@ -636,6 +636,9 @@ def is_torch_bf16_gpu_available() -> bool:
         return True
     if is_torch_npu_available():
         return torch.npu.is_bf16_supported()
+    if is_torch_mps_available():
+        # Note: Emulated in software by Metal using fp32 for hardware without native support (like M1/M2)
+        return torch.backends.mps.is_macos_or_newer(14, 0)
     return False
0 commit comments