Add bfloat16 support detection for MPS in is_torch_bf16_gpu_available() (#40458)
* Add bfloat16 support detection for MPS (Apple Silicon) in is_torch_bf16_gpu_available
bfloat16 appears to have been supported in Metal and torch.mps for a few years now. Allow it so TrainingArguments no longer raises "Your setup doesn't support bf16/gpu." when bf16 is requested on MPS.
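A minimal sketch of where the change lands, assuming a simplified `is_torch_bf16_gpu_available()` (the real function handles more accelerators; the CUDA branch shown is existing behavior, and the unconditional MPS `True` is the first-pass version refined in the follow-up commits):

```python
import torch

from transformers.utils import is_torch_mps_available


def is_torch_bf16_gpu_available() -> bool:
    # Simplified sketch: keep the existing CUDA capability check,
    # and stop rejecting MPS outright for bf16.
    if torch.cuda.is_available():
        return torch.cuda.is_bf16_supported()
    if is_torch_mps_available():
        # First pass: assume bf16 works on MPS; later commits replace
        # this with a macOS version check.
        return True
    return False
```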
* Check bf16 support for MPS using torch method
It turns out such a method already exists: https://github.com/pytorch/pytorch/blob/5859edf1130cec5a021ace5d5b0e18144808f757/torch/_dynamo/device_interface.py#L519
It simply checks whether you are running macOS 14 or higher.
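In code, that check amounts to the following sketch. Whether the linked method delegates to exactly this helper is an assumption; `torch.backends.mps.is_macos_or_newer` is a public helper that may be absent on older torch releases:

```python
import torch

# MpsInterface.is_bf16_supported() effectively reduces to this:
# bf16 on MPS requires macOS 14 or newer.
supported = torch.backends.mps.is_available() and torch.backends.mps.is_macos_or_newer(14, 0)
print(supported)
```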
* Document Metal emulation for bf16 support
Add note about Metal emulation for bf16 support on M1/M2.
* Update bf16 support check for MPS backend
is_bf16_supported() is not publicly exposed even though it is defined on MpsInterface, so use the same approach as in the Accelerate PR.
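A hedged sketch of the resulting check (the exact Accelerate code is not reproduced here; the getattr fallback is a defensive choice for older torch builds where the version helper may be missing):

```python
import torch


def mps_supports_bf16() -> bool:
    # Hypothetical helper replicating MpsInterface.is_bf16_supported()
    # through public APIs, since the method itself is not exposed.
    if not torch.backends.mps.is_available():
        return False
    is_macos_or_newer = getattr(torch.backends.mps, "is_macos_or_newer", None)
    # Fall back to "unsupported" rather than raising on old torch builds.
    return bool(is_macos_or_newer and is_macos_or_newer(14, 0))
```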
---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>