Skip to content

Commit d2f4364

Browse files
committed
Fix(backend/torch): Resolve MPS broadcast crash in binary_crossentropy
1 parent 6d06085 commit d2f4364

File tree

1 file changed

+15
-1
lines changed
  • keras/src/backend/torch

1 file changed

+15
-1
lines changed

keras/src/backend/torch/nn.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,12 +755,27 @@ def binary_crossentropy(target, output, from_logits=False):
755755
target = convert_to_tensor(target)
756756
output = convert_to_tensor(output)
757757

758+
# Fix for MPS broadcast error:
759+
# The backward pass for BCELoss on MPS fails if inputs are (B, T, 1).
760+
# We squeeze both to (B, T).
761+
# .contiguous() is added to force a new tensor copy, as the backward
762+
# pass seems to be using the original tensor's shape (a view bug).
763+
if (
764+
target.ndim == 3
765+
and target.shape[-1] == 1
766+
and output.ndim == 3
767+
and output.shape[-1] == 1
768+
):
769+
target = torch.squeeze(target, -1).contiguous()
770+
output = torch.squeeze(output, -1).contiguous()
771+
758772
if target.shape != output.shape:
759773
raise ValueError(
760774
"Arguments `target` and `output` must have the same shape. "
761775
"Received: "
762776
f"target.shape={target.shape}, output.shape={output.shape}"
763777
)
778+
764779
# By default, PyTorch does reduction of `sum` over all rows,
765780
# change reduction to `none` to keep dim
766781
if from_logits:
@@ -771,7 +786,6 @@ def binary_crossentropy(target, output, from_logits=False):
771786
output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
772787
return tnn.binary_cross_entropy(output, target, reduction="none")
773788

774-
775789
def moments(x, axes, keepdims=False, synchronized=False):
776790
if synchronized:
777791
raise NotImplementedError(

0 commit comments

Comments
 (0)