Skip to content

Commit a1fed82

Browse files
committed
Improve error messages and error handling for bmm.
1 parent 87e631a commit a1fed82

File tree

4 files changed

+167
-58
lines changed

4 files changed

+167
-58
lines changed

test/test_ops_error_message.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,72 @@ def test():
251251
expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
252252
)
253253

254+
def test_bmm_raises_error_on_non_3D_tensor_input(self):
255+
device = torch_xla.device()
256+
a = torch.rand(2, 3, 4, device=device)
257+
b = torch.rand(2, 4, 3, device=device)
258+
259+
def test_a():
260+
torch.bmm(a[0], b)
261+
262+
self.assertExpectedRaisesInline(
263+
exc_type=RuntimeError,
264+
callable=test_a,
265+
expect="""bmm(): expected `input` f32[3,4] (a 2D tensor), the 1st input tensor, to be a 3D tensor."""
266+
)
267+
268+
def test_b():
269+
torch.bmm(a, b[0])
270+
271+
self.assertExpectedRaisesInline(
272+
exc_type=RuntimeError,
273+
callable=test_b,
274+
expect="""bmm(): expected `mat2` f32[4,3] (a 2D tensor), the 2nd input tensor, to be a 3D tensor."""
275+
)
276+
277+
def test_bmm_raises_error_on_different_batch_dimension(self):
278+
device = torch_xla.device()
279+
a = torch.rand(4, 3, 4, device=device)
280+
b = torch.rand(2, 4, 3, device=device)
281+
282+
def test():
283+
torch.bmm(a, b)
284+
285+
self.assertExpectedRaisesInline(
286+
exc_type=RuntimeError,
287+
callable=test,
288+
expect="""bmm(): expected the size of the batch dimension (i.e. dimension 0) of `input` f32[4,3,4] (batch dimension size: 4), the 1st input tensor, to be the same as the size of the batch dimension of `mat2` f32[2,4,3] (batch dimension size: 2), the 2nd input tensor."""
289+
)
290+
291+
def test_bmm_raises_error_on_incompatible_shapes(self):
292+
device = torch_xla.device()
293+
a = torch.rand(2, 3, 8, device=device)
294+
b = torch.rand(2, 4, 3, device=device)
295+
296+
def test():
297+
torch.bmm(a, b)
298+
299+
self.assertExpectedRaisesInline(
300+
exc_type=RuntimeError,
301+
callable=test,
302+
expect="""bmm(): cannot apply batch matrix-multiplication to `input` f32[2,3,8], the 1st input tensor, and to `mat2` f32[2,4,3], the 2nd input tensor. Expected the size of dimension 2 of `input` (8) to be equal the size of dimension 1 of `mat2` (4)."""
303+
)
304+
305+
def test_baddbmm_raises_error_on_incompatible_shapes(self):
306+
device = torch_xla.device()
307+
input = torch.rand(3, 3, device=device)
308+
a = torch.rand(2, 3, 8, device=device)
309+
b = torch.rand(2, 4, 3, device=device)
310+
311+
def test():
312+
torch.baddbmm(input, a, b)
313+
314+
self.assertExpectedRaisesInline(
315+
exc_type=RuntimeError,
316+
callable=test,
317+
expect="""baddbmm(): cannot apply batch matrix-multiplication to `batch1` f32[2,3,8], the 2nd input tensor, and to `batch2` f32[2,4,3], the 3rd input tensor. Expected the size of dimension 2 of `batch1` (8) to be equal the size of dimension 1 of `batch2` (4)."""
318+
)
319+
254320

255321
if __name__ == "__main__":
256322
unittest.main()

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1254,11 +1254,16 @@ at::Tensor XLANativeFunctions::baddbmm(const at::Tensor& self,
12541254
const at::Scalar& beta,
12551255
const at::Scalar& alpha) {
12561256
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
1257-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
1258-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch1, bridge::GetXlaTensor(batch1));
1259-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_batch2, bridge::GetXlaTensor(batch2));
1260-
return bridge::AtenFromXlaTensor(
1257+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
1258+
bridge::GetXlaTensor(self));
1259+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch1,
1260+
bridge::GetXlaTensor(batch1));
1261+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_batch2,
1262+
bridge::GetXlaTensor(batch2));
1263+
XLA_ASSIGN_OR_THROW(
1264+
absl_nonnull XLATensorPtr output,
12611265
tensor_methods::baddbmm(xla_self, xla_batch1, xla_batch2, beta, alpha));
1266+
return bridge::AtenFromXlaTensor(std::move(output));
12621267
}
12631268

12641269
at::Tensor XLANativeFunctions::bernoulli(
@@ -1338,9 +1343,13 @@ at::Tensor XLANativeFunctions::bitwise_xor(const at::Tensor& self,
13381343
at::Tensor XLANativeFunctions::bmm(const at::Tensor& self,
13391344
const at::Tensor& mat2) {
13401345
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
1341-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
1342-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2));
1343-
return bridge::AtenFromXlaTensor(tensor_methods::bmm(xla_self, xla_mat2));
1346+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
1347+
bridge::GetXlaTensor(self));
1348+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_mat2,
1349+
bridge::GetXlaTensor(mat2));
1350+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
1351+
tensor_methods::bmm(xla_self, xla_mat2));
1352+
return bridge::AtenFromXlaTensor(std::move(output));
13441353
}
13451354

13461355
at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 78 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,25 @@ namespace torch_xla {
166166
namespace tensor_methods {
167167
namespace {
168168

169+
struct InputInfo {
170+
const XLATensorPtr& tensor;
171+
std::string_view name;
172+
int position;
173+
174+
std::string PositionAsStr() const {
175+
switch (position) {
176+
case 1:
177+
return "1st";
178+
case 2:
179+
return "2nd";
180+
case 3:
181+
return "3rd";
182+
default:
183+
return absl::StrCat(position, "th");
184+
}
185+
}
186+
};
187+
169188
torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
170189
const xla::Shape& target_shape) {
171190
if (GetXlaShape(input).dimensions() == target_shape.dimensions()) {
@@ -175,46 +194,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
175194
input, torch::lazy::ToVector<int64_t>(target_shape.dimensions()));
176195
}
177196

178-
void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
179-
const std::string& tag, const std::string& arg_name,
180-
int arg_number) {
181-
int64_t actual_rank = t->shape().get().dimensions_size();
182-
XLA_CHECK_EQ(actual_rank, expected_rank)
183-
<< "Expected " << expected_rank << "-dimensional tensor, but got "
184-
<< actual_rank << "-dimensional tensor for "
185-
<< "argument #" << arg_number << " '" << arg_name << "'"
186-
<< " (while checking arguments for " << tag << ")";
187-
}
188-
189-
template <typename T>
190-
void CheckShapeDimensions(const T& size) {
191-
XLA_CHECK(std::all_of(size.begin(), size.end(), [](int64_t dim) {
192-
return dim >= 0;
193-
})) << "Dimensions cannot be negative numbers";
194-
}
195-
196-
void CheckDimensionSize(const XLATensorPtr& t, int64_t dim,
197-
int64_t expected_size, const std::string& tag,
198-
const std::string& arg_name, int arg_number) {
199-
int64_t dim_size = t->size(dim);
200-
XLA_CHECK_EQ(t->size(dim), expected_size)
201-
<< "Expected tensor to have size " << expected_size << " at dimension "
202-
<< dim << ", but got size " << dim_size << " for "
203-
<< "argument #" << arg_number << " '" << arg_name << "'"
204-
<< " (while checking arguments for " << tag << ")";
205-
}
206-
207-
void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
208-
const XLATensorPtr& batch2) {
209-
// Consistent with the checks in bmm_out_or_baddbmm_.
210-
CheckRank(batch1, 3, tag, "batch1", 1);
211-
CheckRank(batch2, 3, tag, "batch2", 2);
212-
CheckDimensionSize(batch2, 0, /*batch_size=*/batch1->size(0), tag, "batch2",
213-
2);
214-
CheckDimensionSize(batch2, 1, /*contraction_size=*/batch1->size(2), tag,
215-
"batch2", 2);
216-
}
217-
218197
absl::Status CheckExpandValidRank(const XLATensorPtr& input,
219198
const absl::Span<const int64_t> sizes) {
220199
xla::Shape shape = input->shape();
@@ -528,6 +507,18 @@ absl::Status CheckRollShiftsRequired(absl::Span<const int64_t> shifts) {
528507
return absl::OkStatus();
529508
}
530509

510+
absl::Status CheckInputIs3DTensor(const std::string_view op,
511+
const InputInfo& input) {
512+
int64_t rank = input.tensor->shape().get().dimensions().size();
513+
if (rank != 3) {
514+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
515+
op, "(): expected `", input.name, "` ",
516+
input.tensor->shape().get().ToString(), " (a ", rank, "D tensor), the ",
517+
input.PositionAsStr(), " input tensor, to be a 3D tensor.")));
518+
}
519+
return absl::OkStatus();
520+
}
521+
531522
absl::Status CheckRollDimsAndShiftsAreCompatible(
532523
absl::Span<const int64_t> dims, absl::Span<const int64_t> shifts) {
533524
if (dims.empty()) {
@@ -570,6 +561,39 @@ absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
570561
return absl::OkStatus();
571562
}
572563

564+
absl::Status CheckBmmInputsAreValid(const std::string_view op,
565+
const InputInfo& input,
566+
const InputInfo& mat2) {
567+
XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, input));
568+
XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, mat2));
569+
570+
if (input.tensor->size(0) != mat2.tensor->size(0)) {
571+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
572+
op,
573+
"(): expected the size of the batch dimension (i.e. dimension 0) of `",
574+
input.name, "` ", input.tensor->shape().get().ToString(),
575+
" (batch dimension size: ", input.tensor->size(0), "), the ",
576+
input.PositionAsStr(),
577+
" input tensor, to be the same as the size of the batch dimension of `",
578+
mat2.name, "` ", mat2.tensor->shape().get().ToString(),
579+
" (batch dimension size: ", mat2.tensor->size(0), "), the ",
580+
mat2.PositionAsStr(), " input tensor.")));
581+
}
582+
if (input.tensor->size(2) != mat2.tensor->size(1)) {
583+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
584+
op, "(): cannot apply batch matrix-multiplication to `", input.name,
585+
"` ", input.tensor->shape().get().ToString(), ", the ",
586+
input.PositionAsStr(), " input tensor, and to `", mat2.name, "` ",
587+
mat2.tensor->shape().get().ToString(), ", the ", mat2.PositionAsStr(),
588+
" input tensor. Expected the size of dimension 2 of `", input.name,
589+
"` (", input.tensor->size(2),
590+
") to be equal the size of dimension 1 of `", mat2.name, "` (",
591+
mat2.tensor->size(1), ").")));
592+
}
593+
594+
return absl::OkStatus();
595+
}
596+
573597
} // namespace
574598

575599
//////////////////////////////////////////////////////////////////////////////
@@ -1278,10 +1302,14 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
12781302
count_include_pad));
12791303
}
12801304

1281-
XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
1282-
const XLATensorPtr& batch2, const at::Scalar& beta,
1283-
const at::Scalar& alpha) {
1284-
CheckBmmDimension(/*tag=*/"baddbmm", batch1, batch2);
1305+
absl::StatusOr<absl_nonnull XLATensorPtr> baddbmm(const XLATensorPtr& input,
1306+
const XLATensorPtr& batch1,
1307+
const XLATensorPtr& batch2,
1308+
const at::Scalar& beta,
1309+
const at::Scalar& alpha) {
1310+
XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid(
1311+
"baddbmm", {batch1, /* name= */ "batch1", /* position= */ 2},
1312+
{batch2, /* name= */ "batch2", /* position= */ 3}));
12851313
torch::lazy::Value product_multiplier =
12861314
XLAGraphExecutor::Get()->GetIrValueForScalar(
12871315
alpha, batch1->shape().get().element_type(), batch1->GetDevice());
@@ -1331,9 +1359,12 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) {
13311359
input->GetIrValue(), other->GetIrValue()));
13321360
}
13331361

1334-
XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2) {
1335-
CheckBmmDimension(/*tag=*/"bmm", batch1, batch2);
1336-
return matmul(batch1, batch2);
1362+
absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
1363+
const XLATensorPtr& mat2) {
1364+
XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid(
1365+
"bmm", {input, /* name= */ "input", /* position= */ 1},
1366+
{mat2, /* name= */ "mat2", /* position= */ 2}));
1367+
return matmul(input, mat2);
13371368
}
13381369

13391370
std::vector<XLATensorPtr> broadcast_tensors(

torch_xla/csrc/tensor_methods.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,11 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
272272
std::vector<int64_t> padding, bool ceil_mode,
273273
bool count_include_pad);
274274

275-
XLATensorPtr baddbmm(const XLATensorPtr& input, const XLATensorPtr& batch1,
276-
const XLATensorPtr& batch2, const at::Scalar& beta,
277-
const at::Scalar& alpha);
275+
absl::StatusOr<absl_nonnull XLATensorPtr> baddbmm(const XLATensorPtr& input,
276+
const XLATensorPtr& batch1,
277+
const XLATensorPtr& batch2,
278+
const at::Scalar& beta,
279+
const at::Scalar& alpha);
278280

279281
XLATensorPtr bernoulli(const XLATensorPtr& input, double probability);
280282
XLATensorPtr bernoulli(const XLATensorPtr& input);
@@ -297,7 +299,8 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other);
297299
// Batch matrix multiplication. Both tensors must be 3D, the batch size must
298300
// match and the remaining two dimensions must be compatible for matrix
299301
// multiplication.
300-
XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2);
302+
absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
303+
const XLATensorPtr& mat2);
301304

302305
// Broadcasts the given tensors according to broadcasting semantics.
303306
std::vector<XLATensorPtr> broadcast_tensors(

0 commit comments

Comments
 (0)