@@ -755,12 +755,27 @@ def binary_crossentropy(target, output, from_logits=False):
     target = convert_to_tensor(target)
     output = convert_to_tensor(output)
 
+    # Fix for MPS broadcast error:
+    # The backward pass for BCELoss on MPS fails if inputs have a
+    # trailing dim of 1 (e.g., (B, T, 1) or (B, H, W, 1)).
+    # Squeezing to (B, T) or (B, H, W) resolves the conflict.
+    # .contiguous() is added to force a new tensor copy.
+    if (
+        target.ndim > 1
+        and output.ndim == target.ndim
+        and target.shape[-1] == 1
+        and output.shape[-1] == 1
+    ):
+        target = torch.squeeze(target, -1).contiguous()
+        output = torch.squeeze(output, -1).contiguous()
+
     if target.shape != output.shape:
         raise ValueError(
             "Arguments `target` and `output` must have the same shape. "
             "Received: "
             f"target.shape={target.shape}, output.shape={output.shape}"
         )
+
     # By default, PyTorch does a `mean` reduction over all elements;
     # change reduction to `none` to keep dims
     if from_logits:
@@ -771,7 +786,6 @@ def binary_crossentropy(target, output, from_logits=False):
         output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
         return tnn.binary_cross_entropy(output, target, reduction="none")
 
-
 def moments(x, axes, keepdims=False, synchronized=False):
     if synchronized:
         raise NotImplementedError(
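
For reference, here is a minimal standalone sketch of the same workaround; the shapes, the random tensors, and the `eps` constant standing in for `backend.epsilon()` are illustrative assumptions, not taken from the Keras test suite. It shows that squeezing the shared trailing unit dim before `binary_cross_entropy` lets the backward pass complete while leaving the per-element loss values unchanged:

```python
import torch
import torch.nn.functional as F

# Illustrative (B, T, 1) tensors; a shared trailing dim of 1 is the
# shape that trips the MPS backward pass this patch works around.
target = torch.randint(0, 2, (4, 8, 1)).float()
output = torch.rand(4, 8, 1, requires_grad=True)

# Same guard as the patch: squeeze only when both tensors end in a
# dim of 1, so every other shape passes through untouched.
if (
    target.ndim > 1
    and output.ndim == target.ndim
    and target.shape[-1] == 1
    and output.shape[-1] == 1
):
    target = torch.squeeze(target, -1).contiguous()
    probs = torch.squeeze(output, -1).contiguous()
else:
    probs = output

eps = 1e-7  # illustrative stand-in for backend.epsilon()
probs = torch.clip(probs, eps, 1.0 - eps)
loss = F.binary_cross_entropy(probs, target, reduction="none")
loss.sum().backward()

print(loss.shape)         # torch.Size([4, 8])
print(output.grad.shape)  # torch.Size([4, 8, 1]): grads reach the original tensor
```

The guard deliberately mirrors the patch: tensors whose ranks differ or whose last dim is not 1 fall through untouched, so the squeeze only fires in the failing case, and gradients still flow back to the original unsqueezed tensor through `squeeze` and `contiguous`.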