Skip to content

Commit 206d7e2

Browse files
committed
Fix(backend/torch): Resolve MPS broadcast crash in binary_crossentropy
1 parent 6d06085 commit 206d7e2

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

examples/mre_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
"""Minimal reproduction for the MPS crash in `binary_crossentropy`.

Builds a tiny model whose output has a trailing dimension of 1
(`Dense(1)` + sigmoid over a `(T, F)` input), compiles it with the
`binary_crossentropy` loss and trains it for one epoch on random data
using the torch backend.

The success message is only printed when MPS is actually available;
otherwise the run proves nothing about the MPS code path and the script
says so instead of claiming the fix works.
"""
import os

# Set before importing keras so that the torch backend is selected.
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import keras
from keras import layers, Model
import torch

mps_available = torch.backends.mps.is_available()

print("Keras:", keras.__version__)
print("PyTorch:", torch.__version__)
print("MPS available:", mps_available)

B, T, F = 32, 50, 10
inp = layers.Input(shape=(T, F))
x = layers.Dense(1, activation=None)(inp)  # (B,T,1)
x = layers.Activation("sigmoid")(x)
model = Model(inp, x)

model.compile(optimizer="adam", loss="binary_crossentropy")

X = np.random.randn(B, T, F).astype(np.float32)
y = np.random.randint(0, 2, (B, T, 1)).astype(np.float32)

print("X:", X.shape, "y:", y.shape)
history = model.fit(X, y, epochs=1, batch_size=32, verbose=1)

# A run that "finishes" with a NaN/inf loss is not a successful run.
final_loss = history.history["loss"][-1]
if not np.isfinite(final_loss):
    raise RuntimeError(f"Training produced a non-finite loss: {final_loss}")

if mps_available:
    print("\n✅ Test finished successfully! The fix works.")
else:
    # Without MPS the crash could never have triggered, so do not claim
    # that the fix has been verified.
    print(
        "\n⚠️ Test finished, but MPS is not available on this machine, "
        "so the MPS code path was not exercised."
    )

keras/src/backend/torch/nn.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,12 +755,27 @@ def binary_crossentropy(target, output, from_logits=False):
755755
target = convert_to_tensor(target)
756756
output = convert_to_tensor(output)
757757

758+
# Fix for MPS broadcast error:
759+
# The backward pass for BCELoss on MPS fails if inputs have a
760+
# trailing dim of 1 (e.g., (B, T, 1) or (B, H, W, 1)).
761+
# Squeezing to (B, T) or (B, H, W) resolves the conflict.
762+
# .contiguous() ensures a dense layout (copies only when needed).
763+
if (
764+
target.ndim > 1
765+
and output.ndim == target.ndim
766+
and target.shape[-1] == 1
767+
and output.shape[-1] == 1
768+
):
769+
target = torch.squeeze(target, -1).contiguous()
770+
output = torch.squeeze(output, -1).contiguous()
771+
758772
if target.shape != output.shape:
759773
raise ValueError(
760774
"Arguments `target` and `output` must have the same shape. "
761775
"Received: "
762776
f"target.shape={target.shape}, output.shape={output.shape}"
763777
)
778+
764779
# By default, PyTorch applies a `mean` reduction over all elements,
765780
# change reduction to `none` to keep dim
766781
if from_logits:
@@ -771,7 +786,6 @@ def binary_crossentropy(target, output, from_logits=False):
771786
output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
772787
return tnn.binary_cross_entropy(output, target, reduction="none")
773788

774-
775789
def moments(x, axes, keepdims=False, synchronized=False):
776790
if synchronized:
777791
raise NotImplementedError(

0 commit comments

Comments
 (0)