Skip to content

Commit dce518f

Browse files
committed
Fix(backend/torch): Resolve MPS broadcast crash in binary_crossentropy
1 parent 6d06085 commit dce518f

File tree

1 file changed

+15
-0
lines changed
  • keras/src/backend/torch

1 file changed

+15
-0
lines changed

keras/src/backend/torch/nn.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,12 +755,27 @@ def binary_crossentropy(target, output, from_logits=False):
755755
target = convert_to_tensor(target)
756756
output = convert_to_tensor(output)
757757

758+
# Fix for MPS broadcast error:
759+
# The backward pass for BCELoss on MPS fails if inputs have a
760+
# trailing dim of 1 (e.g., (B, T, 1) or (B, H, W, 1)).
761+
# Squeezing to (B, T) or (B, H, W) resolves the conflict.
762+
# .contiguous() ensures a contiguous layout (copies only if needed).
763+
if (
764+
target.ndim > 1
765+
and output.ndim == target.ndim
766+
and target.shape[-1] == 1
767+
and output.shape[-1] == 1
768+
):
769+
target = torch.squeeze(target, -1).contiguous()
770+
output = torch.squeeze(output, -1).contiguous()
771+
758772
if target.shape != output.shape:
759773
raise ValueError(
760774
"Arguments `target` and `output` must have the same shape. "
761775
"Received: "
762776
f"target.shape={target.shape}, output.shape={output.shape}"
763777
)
778+
764779
# By default, PyTorch does reduction of `sum` over all rows,
765780
# change reduction to `none` to keep dim
766781
if from_logits:

0 commit comments

Comments
 (0)