Skip to content

Conversation

@Mithil27360
Copy link
Contributor

Fixes #21795

This PR resolves a broadcast error on the PyTorch MPS backend for binary_crossentropy. The crash occurred during backpropagation when y_true and y_pred had incompatible shapes for broadcasting, such as (B, T, 1) and (B, T).

This fix aligns the shapes by squeezing the trailing dimension of 1 from both tensors and calling .contiguous() to ensure the new shape is respected during the backward pass. This resolves the mps.multiply broadcast error.

@google-cla
Copy link

google-cla bot commented Nov 2, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @Mithil27360, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical bug in the PyTorch MPS backend's binary_crossentropy function, which previously led to crashes due to tensor shape mismatches during backpropagation. By explicitly handling and aligning specific tensor dimensions, the change ensures the robustness and reliability of the binary_crossentropy operation when utilizing Apple Silicon's Metal Performance Shaders, thereby improving the overall stability of the Keras backend on this hardware.

Highlights

  • Fix MPS Broadcast Error: Resolved a crash in the PyTorch MPS backend for binary_crossentropy that occurred during backpropagation when y_true and y_pred had incompatible shapes (e.g., (B, T, 1) and (B, T)). The fix involves squeezing the trailing dimension of 1 from both tensors and calling .contiguous() to ensure the new shape is respected during the backward pass, preventing mps.multiply broadcast errors.
  • Tensor Shape Alignment: Introduced conditional logic within binary_crossentropy to detect and correct specific 3D tensor shapes (B, T, 1) by reducing them to (B, T) using torch.squeeze and contiguous() for both target and output tensors.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@codecov-commenter
Copy link

codecov-commenter commented Nov 2, 2025

Codecov Report

❌ Patch coverage is 0% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.66%. Comparing base (6d06085) to head (e65c475).

Files with missing lines Patch % Lines
keras/src/backend/torch/nn.py 0.00% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21816      +/-   ##
==========================================
- Coverage   82.66%   82.66%   -0.01%     
==========================================
  Files         577      577              
  Lines       59419    59422       +3     
  Branches     9313     9314       +1     
==========================================
  Hits        49121    49121              
- Misses       7898     7900       +2     
- Partials     2400     2401       +1     
Flag Coverage Δ
keras 82.48% <0.00%> (-0.01%) ⬇️
keras-jax 63.32% <0.00%> (-0.01%) ⬇️
keras-numpy 57.57% <0.00%> (+<0.01%) ⬆️
keras-openvino 34.34% <0.00%> (-0.01%) ⬇️
keras-tensorflow 64.12% <0.00%> (-0.01%) ⬇️
keras-torch 63.62% <0.00%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Mithil27360 Mithil27360 force-pushed the fix-mps-binary-crossentropy branch from d2f4364 to db2f77e Compare November 2, 2025 10:51
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a broadcast crash in binary_crossentropy on the PyTorch MPS backend. The fix involves squeezing tensors with a trailing dimension of size 1 before the loss calculation. The change is well-commented and correctly uses .contiguous() to work around a suspected view bug in the backward pass. My main feedback is to suggest generalizing the condition to handle tensors of any rank, not just 3D, to make the fix more robust for other use cases like image segmentation.

Comment on lines 763 to 767
if (
target.ndim == 3
and target.shape[-1] == 1
and output.ndim == 3
and output.shape[-1] == 1
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition to apply the squeeze operation is specific to 3D tensors (target.ndim == 3). However, the MPS broadcast issue with trailing dimensions of size 1 might also occur for tensors of other ranks, such as 4D tensors (B, H, W, 1) common in segmentation tasks. To make this fix more robust and future-proof, consider generalizing the condition to apply to any tensor with a rank greater than 1.

Suggested change
if (
target.ndim == 3
and target.shape[-1] == 1
and output.ndim == 3
and output.shape[-1] == 1
):
if (
target.ndim > 1
and output.ndim == target.ndim
and target.shape[-1] == 1
and output.shape[-1] == 1
):

@Mithil27360 Mithil27360 force-pushed the fix-mps-binary-crossentropy branch 2 times, most recently from 206d7e2 to dce518f Compare November 2, 2025 11:29
@Mithil27360 Mithil27360 force-pushed the fix-mps-binary-crossentropy branch from dce518f to e65c475 Compare November 2, 2025 12:03
@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Nov 3, 2025
@fchollet fchollet merged commit 45909f9 into keras-team:master Nov 3, 2025
8 checks passed
@Mithil27360
Copy link
Contributor Author

Hi @fchollet ,

Thank you for reviewing and merging PR #21816! I learned a lot from implementing the MPS broadcast fix and addressing the feedback .Looking forward to contributing more!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

kokoro:force-run ready to pull Ready to be merged into the codebase size:S

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Torch MPS: binary_crossentropy crashes with mps.multiply broadcast ((B,T,1) × (B,T))

4 participants