Fix(backend/torch): Resolved MPS broadcast crash in binary_crossentropy #21816
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View the failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Summary of Changes: Hello @Mithil27360, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request addresses a critical bug in the PyTorch MPS backend's binary_crossentropy implementation.
Codecov Report: ❌ Patch coverage is

Coverage diff against master:

```
@@            Coverage Diff             @@
##           master   #21816      +/-   ##
==========================================
- Coverage   82.66%   82.66%   -0.01%
==========================================
  Files         577      577
  Lines       59419    59422       +3
  Branches     9313     9314       +1
==========================================
  Hits        49121    49121
- Misses       7898     7900       +2
- Partials     2400     2401       +1
```
Code Review
This pull request addresses a broadcast crash in binary_crossentropy on the PyTorch MPS backend. The fix involves squeezing tensors with a trailing dimension of size 1 before the loss calculation. The change is well-commented and correctly uses .contiguous() to work around a suspected view bug in the backward pass. My main feedback is to suggest generalizing the condition to handle tensors of any rank, not just 3D, to make the fix more robust for other use cases like image segmentation.
```python
if (
    target.ndim == 3
    and target.shape[-1] == 1
    and output.ndim == 3
    and output.shape[-1] == 1
):
```
The condition to apply the squeeze operation is specific to 3D tensors (target.ndim == 3). However, the MPS broadcast issue with trailing dimensions of size 1 might also occur for tensors of other ranks, such as 4D tensors (B, H, W, 1) common in segmentation tasks. To make this fix more robust and future-proof, consider generalizing the condition to apply to any tensor with a rank greater than 1.
```diff
-if (
-    target.ndim == 3
-    and target.shape[-1] == 1
-    and output.ndim == 3
-    and output.shape[-1] == 1
-):
+if (
+    target.ndim > 1
+    and output.ndim == target.ndim
+    and target.shape[-1] == 1
+    and output.shape[-1] == 1
+):
```
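The generalized guard above can be isolated as a small shape-level predicate. The sketch below is hypothetical (the helper name `should_squeeze` is not part of the PR); it only mirrors the suggested condition so the rank-generic behavior is easy to check in isolation:

```python
def should_squeeze(target_shape, output_shape):
    """Hypothetical helper mirroring the suggested guard.

    Returns True when both shapes have the same rank (> 1) and a
    trailing dimension of size 1, i.e. when squeezing the last axis
    of both tensors is safe and aligns their shapes.
    """
    return (
        len(target_shape) > 1
        and len(output_shape) == len(target_shape)
        and target_shape[-1] == 1
        and output_shape[-1] == 1
    )
```

Unlike the 3D-only check, this predicate also fires for the 4D segmentation case `(B, H, W, 1)` while leaving already-aligned 2D shapes such as `(B, T)` untouched.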
Fixes #21795

This PR resolves a broadcast error on the PyTorch MPS backend for `binary_crossentropy`. The crash occurred during backpropagation when `y_true` and `y_pred` had incompatible shapes for broadcasting, such as `(B, T, 1)` and `(B, T)`.

The fix aligns the shapes by squeezing the trailing dimension of size 1 from both tensors and calling `.contiguous()` to ensure the new shape is respected during the backward pass. This resolves the `mps.multiply` broadcast error.
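The shape mismatch described above can be reproduced at the shape level with NumPy, which shares PyTorch's right-aligned broadcasting rules (a minimal sketch; the concrete shapes `(2, 5, 1)` and `(2, 5)` are hypothetical stand-ins for `(B, T, 1)` and `(B, T)`):

```python
import numpy as np

# Stand-ins for y_true with shape (B, T, 1) and y_pred with shape (B, T).
y_true = np.zeros((2, 5, 1))
y_pred = np.zeros((2, 5))

# Right-aligned broadcasting pads (2, 5) to (1, 2, 5), so the middle
# dimensions (5 vs 2) clash and the elementwise product fails.
try:
    _ = y_true * y_pred
    broadcast_ok = True
except ValueError:
    broadcast_ok = False
assert not broadcast_ok

# Squeezing the trailing size-1 axis aligns the shapes, as the fix does
# (in torch, the fix additionally calls .contiguous() on the result).
y_true_squeezed = np.squeeze(y_true, axis=-1)
assert (y_true_squeezed * y_pred).shape == (2, 5)
```

This also illustrates why the reviewer's generalization matters: the same trailing-dimension clash arises for any rank, not just 3D tensors.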