From a1fed828e68177ba082c1318b911222aeed0c544 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Sat, 13 Sep 2025 12:46:10 -0300 Subject: [PATCH] Improve error messages and error handling for `bmm`. --- test/test_ops_error_message.py | 66 ++++++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 23 ++++-- torch_xla/csrc/tensor_methods.cpp | 125 +++++++++++++++++++----------- torch_xla/csrc/tensor_methods.h | 11 ++- 4 files changed, 167 insertions(+), 58 deletions(-) diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index 0ef152920861..988196d7e7bc 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -251,6 +251,72 @@ def test(): expect="""clamp(): expected at least one of `min` or `max` arguments to be specified.""" ) + def test_bmm_raises_error_on_non_3D_tensor_input(self): + device = torch_xla.device() + a = torch.rand(2, 3, 4, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test_a(): + torch.bmm(a[0], b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test_a, + expect="""bmm(): expected `input` f32[3,4] (a 2D tensor), the 1st input tensor, to be a 3D tensor.""" + ) + + def test_b(): + torch.bmm(a, b[0]) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test_b, + expect="""bmm(): expected `mat2` f32[4,3] (a 2D tensor), the 2nd input tensor, to be a 3D tensor.""" + ) + + def test_bmm_raises_error_on_different_batch_dimension(self): + device = torch_xla.device() + a = torch.rand(4, 3, 4, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.bmm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[4,3,4] (batch dimension size: 4), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[2,4,3] (batch dimension size: 2), the 2nd input tensor.""" + ) + + def test_bmm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + a = torch.rand(2, 3, 8, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.bmm(a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,8], the 1st input tensor, and to `mat2` f32[2,4,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (8) to be equal the size of dimension 1 of `mat2` (4).""" + ) + + def test_baddbmm_raises_error_on_incompatible_shapes(self): + device = torch_xla.device() + input = torch.rand(3, 3, device=device) + a = torch.rand(2, 3, 8, device=device) + b = torch.rand(2, 4, 3, device=device) + + def test(): + torch.baddbmm(input, a, b) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""baddbmm(): cannot apply batch matrix-multiplication to `batch1` f32[2,3,8], the 2nd input tensor, and to `batch2` f32[2,4,3], the 3rd input tensor. Expected the size of dimension 2 of `batch1` (8) to be equal the size of dimension 1 of `batch2` (4).""" + ) + if __name__ == "__main__": unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index f23e6eb5f7fb..0feb0e840875 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1254,11 +1254,16 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self, const at::Scalar& beta, const at::Scalar& alpha) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch1, bridge::GetXlaTensor(batch1)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch2, bridge::GetXlaTensor(batch2)); - return bridge::AtenFromXlaTensor( + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch1, + bridge::GetXlaTensor(batch1)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch2, + bridge::GetXlaTensor(batch2)); + XLA_ASSIGN_OR_THROW( + absl_nonnull XLATensorPtr output, tensor_methods::baddbmm(xla_self, xla_batch1, xla_batch2, beta, alpha)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::bernoulli( @@ -1338,9 +1343,13 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self, at::Tensor XLANativeFunctions::bmm(const at::Tensor& self, const at::Tensor& mat2) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2)); - return bridge::AtenFromXlaTensor(tensor_methods::bmm(xla_self, xla_mat2)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_mat2, + bridge::GetXlaTensor(mat2)); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output, + tensor_methods::bmm(xla_self, xla_mat2)); + return bridge::AtenFromXlaTensor(std::move(output)); } at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors, diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 2c22169e485d..07e14a37b294 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -166,6 +166,25 @@ namespace torch_xla { namespace tensor_methods { namespace { +struct InputInfo { + const XLATensorPtr& tensor; + std::string_view name; + int position; + + std::string PositionAsStr() const { + switch (position) { + case 1: + return "1st"; + case 2: + return "2nd"; + case 3: + return "3rd"; + default: + return absl::StrCat(position, "th"); + } + } +}; + torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, const xla::Shape& target_shape) { if (GetXlaShape(input).dimensions() == target_shape.dimensions()) { @@ -175,46 +194,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input, input, torch::lazy::ToVector(target_shape.dimensions())); } -void CheckRank(const XLATensorPtr& t, int64_t expected_rank, - const std::string& tag, const std::string& arg_name, - int arg_number) { - int64_t actual_rank = t->shape().get().dimensions_size(); - XLA_CHECK_EQ(actual_rank, expected_rank) - << "Expected " << expected_rank << "-dimensional tensor, but got " - << actual_rank << "-dimensional tensor for " - << "argument #" << arg_number << " '" << arg_name << "'" - << " (while checking arguments for " << tag << ")"; -} - -template -void CheckShapeDimensions(const T& size) { - XLA_CHECK(std::all_of(size.begin(), size.end(), [](int64_t dim) { - return dim >= 0; - })) << "Dimensions cannot be negative numbers"; -} - -void CheckDimensionSize(const XLATensorPtr& t, int64_t dim, - int64_t expected_size, const std::string& tag, - const std::string& arg_name, int arg_number) { - int64_t dim_size = t->size(dim); - XLA_CHECK_EQ(t->size(dim), expected_size) - << "Expected tensor to have size " << expected_size << " at dimension " - << dim << ", but got size " << dim_size << " for " - << "argument #" << arg_number << " '" << arg_name << "'" - << " (while checking arguments for " << tag << ")"; -} - -void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1, - const XLATensorPtr& batch2) { - // Consistent with the checks in bmm_out_or_baddbmm_. - CheckRank(batch1, 3, tag, "batch1", 1); - CheckRank(batch2, 3, tag, "batch2", 2); - CheckDimensionSize(batch2, 0, /*batch_size=*/batch1->size(0), tag, "batch2", - 2); - CheckDimensionSize(batch2, 1, /*contraction_size=*/batch1->size(2), tag, - "batch2", 2); -} - absl::Status CheckExpandValidRank(const XLATensorPtr& input, const absl::Span sizes) { xla::Shape shape = input->shape(); @@ -528,6 +507,18 @@ absl::Status CheckRollShiftsRequired(absl::Span shifts) { return absl::OkStatus(); } +absl::Status CheckInputIs3DTensor(const std::string_view op, + const InputInfo& input) { + int64_t rank = input.tensor->shape().get().dimensions().size(); + if (rank != 3) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): expected `", input.name, "` ", + input.tensor->shape().get().ToString(), " (a ", rank, "D tensor), the ", + input.PositionAsStr(), " input tensor, to be a 3D tensor."))); + } + return absl::OkStatus(); +} + absl::Status CheckRollDimsAndShiftsAreCompatible( absl::Span dims, absl::Span shifts) { if (dims.empty()) { @@ -570,6 +561,39 @@ absl::Status CheckClampMinOrMax(const std::optional& min, return absl::OkStatus(); } +absl::Status CheckBmmInputsAreValid(const std::string_view op, + const InputInfo& input, + const InputInfo& mat2) { + XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, input)); + XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, mat2)); + + if (input.tensor->size(0) != mat2.tensor->size(0)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, + "(): expected the size of the batch dimension (i.e. dimension 0) of `", + input.name, "` ", input.tensor->shape().get().ToString(), + " (batch dimension size: ", input.tensor->size(0), "), the ", + input.PositionAsStr(), + " input tensor, to be the same as the size of the batch dimension of `", + mat2.name, "` ", mat2.tensor->shape().get().ToString(), + " (batch dimension size: ", mat2.tensor->size(0), "), the ", + mat2.PositionAsStr(), " input tensor."))); + } + if (input.tensor->size(2) != mat2.tensor->size(1)) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( + op, "(): cannot apply batch matrix-multiplication to `", input.name, + "` ", input.tensor->shape().get().ToString(), ", the ", + input.PositionAsStr(), " input tensor, and to `", mat2.name, "` ", + mat2.tensor->shape().get().ToString(), ", the ", mat2.PositionAsStr(), + " input tensor. Expected the size of dimension 2 of `", input.name, + "` (", input.tensor->size(2), + ") to be equal the size of dimension 1 of `", mat2.name, "` (", + mat2.tensor->size(1), ")."))); + } + + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -1278,10 +1302,14 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop, count_include_pad)); } -XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1, - const XLATensorPtr& batch2, const at::Scalar& beta, - const at::Scalar& alpha) { - CheckBmmDimension(/*tag=*/"baddbmm", batch1, batch2); +absl::StatusOr baddbmm(const XLATensorPtr& input, + const XLATensorPtr& batch1, + const XLATensorPtr& batch2, + const at::Scalar& beta, + const at::Scalar& alpha) { + XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid( + "baddbmm", {batch1, /* name= */ "batch1", /* position= */ 2}, + {batch2, /* name= */ "batch2", /* position= */ 3})); torch::lazy::Value product_multiplier = XLAGraphExecutor::Get()->GetIrValueForScalar( alpha, batch1->shape().get().element_type(), batch1->GetDevice()); @@ -1331,9 +1359,12 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) { input->GetIrValue(), other->GetIrValue())); } -XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2) { - CheckBmmDimension(/*tag=*/"bmm", batch1, batch2); - return matmul(batch1, batch2); +absl::StatusOr bmm(const XLATensorPtr& input, + const XLATensorPtr& mat2) { + XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid( + "bmm", {input, /* name= */ "input", /* position= */ 1}, + {mat2, /* name= */ "mat2", /* position= */ 2})); + return matmul(input, mat2); } std::vector broadcast_tensors( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index 1dd2208a55d3..fa5a7e0e4c50 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -272,9 +272,11 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop, std::vector padding, bool ceil_mode, bool count_include_pad); -XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1, - const XLATensorPtr& batch2, const at::Scalar& beta, - const at::Scalar& alpha); +absl::StatusOr baddbmm(const XLATensorPtr& input, + const XLATensorPtr& batch1, + const XLATensorPtr& batch2, + const at::Scalar& beta, + const at::Scalar& alpha); XLATensorPtr bernoulli(const XLATensorPtr& input, double probability); XLATensorPtr bernoulli(const XLATensorPtr& input); @@ -297,7 +299,8 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other); // Batch matrix multiplication. Both tensors must be 3D, the batch size must // match and the remaining two dimensions must be compatible for matrix // multiplication. -XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2); +absl::StatusOr bmm(const XLATensorPtr& input, + const XLATensorPtr& mat2); // Broadcasts the given tensors according to broadcasting semantics. std::vector broadcast_tensors(