@@ -166,6 +166,25 @@ namespace torch_xla {
166166namespace tensor_methods {
167167namespace {
168168
169+ struct InputInfo {
170+ const XLATensorPtr& tensor;
171+ std::string_view name;
172+ int position;
173+
174+ std::string PositionAsStr () const {
175+ switch (position) {
176+ case 1 :
177+ return " 1st" ;
178+ case 2 :
179+ return " 2nd" ;
180+ case 3 :
181+ return " 3rd" ;
182+ default :
183+ return absl::StrCat (position, " th" );
184+ }
185+ }
186+ };
187+
169188torch::lazy::Value MaybeExpand (const torch::lazy::Value& input,
170189 const xla::Shape& target_shape) {
171190 if (GetXlaShape (input).dimensions () == target_shape.dimensions ()) {
@@ -175,46 +194,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
175194 input, torch::lazy::ToVector<int64_t >(target_shape.dimensions ()));
176195}
177196
178- void CheckRank (const XLATensorPtr& t, int64_t expected_rank,
179- const std::string& tag, const std::string& arg_name,
180- int arg_number) {
181- int64_t actual_rank = t->shape ().get ().dimensions_size ();
182- XLA_CHECK_EQ (actual_rank, expected_rank)
183- << " Expected " << expected_rank << " -dimensional tensor, but got "
184- << actual_rank << " -dimensional tensor for "
185- << " argument #" << arg_number << " '" << arg_name << " '"
186- << " (while checking arguments for " << tag << " )" ;
187- }
188-
189- template <typename T>
190- void CheckShapeDimensions (const T& size) {
191- XLA_CHECK (std::all_of (size.begin (), size.end (), [](int64_t dim) {
192- return dim >= 0 ;
193- })) << " Dimensions cannot be negative numbers" ;
194- }
195-
196- void CheckDimensionSize (const XLATensorPtr& t, int64_t dim,
197- int64_t expected_size, const std::string& tag,
198- const std::string& arg_name, int arg_number) {
199- int64_t dim_size = t->size (dim);
200- XLA_CHECK_EQ (t->size (dim), expected_size)
201- << " Expected tensor to have size " << expected_size << " at dimension "
202- << dim << " , but got size " << dim_size << " for "
203- << " argument #" << arg_number << " '" << arg_name << " '"
204- << " (while checking arguments for " << tag << " )" ;
205- }
206-
// Validates the two batch operands of a bmm-style operation: both must
// be rank-3, and batch2's batch and row dimensions must line up with
// batch1. Aborts (via the Check* helpers) with an error tagged `tag`
// on the first mismatch; check order determines which error is reported.
void CheckBmmDimension(const std::string& tag, const XLATensorPtr& batch1,
                       const XLATensorPtr& batch2) {
  // Consistent with the checks in bmm_out_or_baddbmm_.
  CheckRank(batch1, 3, tag, "batch1", 1);
  CheckRank(batch2, 3, tag, "batch2", 2);
  // batch2's batch dimension (dim 0) must match batch1's.
  CheckDimensionSize(batch2, 0, /*batch_size=*/batch1->size(0), tag, "batch2",
                     2);
  // batch2's dim 1 must match batch1's dim 2 (the contraction dimension).
  CheckDimensionSize(batch2, 1, /*contraction_size=*/batch1->size(2), tag,
                     "batch2", 2);
}
217-
218197absl::Status CheckExpandValidRank (const XLATensorPtr& input,
219198 const absl::Span<const int64_t > sizes) {
220199 xla::Shape shape = input->shape ();
@@ -528,6 +507,18 @@ absl::Status CheckRollShiftsRequired(absl::Span<const int64_t> shifts) {
528507 return absl::OkStatus ();
529508}
530509
510+ absl::Status CheckInputIs3DTensor (const std::string_view op,
511+ const InputInfo& input) {
512+ int64_t rank = input.tensor ->shape ().get ().dimensions ().size ();
513+ if (rank != 3 ) {
514+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
515+ op, " (): expected `" , input.name , " ` " ,
516+ input.tensor ->shape ().get ().ToString (), " (a " , rank, " D tensor), the " ,
517+ input.PositionAsStr (), " input tensor, to be a 3D tensor." )));
518+ }
519+ return absl::OkStatus ();
520+ }
521+
531522absl::Status CheckRollDimsAndShiftsAreCompatible (
532523 absl::Span<const int64_t > dims, absl::Span<const int64_t > shifts) {
533524 if (dims.empty ()) {
@@ -570,6 +561,39 @@ absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
570561 return absl::OkStatus ();
571562}
572563
// Validates the operands of a batched matrix multiplication:
//   1. both `input` and `mat2` must be 3D tensors;
//   2. their batch dimensions (dimension 0) must have the same size;
//   3. `input`'s dimension 2 must equal `mat2`'s dimension 1 (the
//      contraction dimension).
// Returns an InvalidArgumentError describing the first violated
// condition, or OkStatus() when all hold. `op` is the user-facing
// operation name used in the error messages.
absl::Status CheckBmmInputsAreValid(const std::string_view op,
                                    const InputInfo& input,
                                    const InputInfo& mat2) {
  XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, input));
  XLA_RETURN_IF_ERROR(CheckInputIs3DTensor(op, mat2));

  // Batch dimensions must agree.
  if (input.tensor->size(0) != mat2.tensor->size(0)) {
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
        op,
        "(): expected the size of the batch dimension (i.e. dimension 0) of `",
        input.name, "` ", input.tensor->shape().get().ToString(),
        " (batch dimension size: ", input.tensor->size(0), "), the ",
        input.PositionAsStr(),
        " input tensor, to be the same as the size of the batch dimension of `",
        mat2.name, "` ", mat2.tensor->shape().get().ToString(),
        " (batch dimension size: ", mat2.tensor->size(0), "), the ",
        mat2.PositionAsStr(), " input tensor.")));
  }
  // Contraction dimensions must agree: input is [b, n, m], mat2 is
  // [b, m, p], so input's dim 2 must equal mat2's dim 1.
  if (input.tensor->size(2) != mat2.tensor->size(1)) {
    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
        op, "(): cannot apply batch matrix-multiplication to `", input.name,
        "` ", input.tensor->shape().get().ToString(), ", the ",
        input.PositionAsStr(), " input tensor, and to `", mat2.name, "` ",
        mat2.tensor->shape().get().ToString(), ", the ", mat2.PositionAsStr(),
        " input tensor. Expected the size of dimension 2 of `", input.name,
        "` (", input.tensor->size(2),
        ") to be equal the size of dimension 1 of `", mat2.name, "` (",
        mat2.tensor->size(1), ").")));
  }

  return absl::OkStatus();
}
596+
573597} // namespace
574598
575599// ////////////////////////////////////////////////////////////////////////////
@@ -1278,10 +1302,14 @@ XLATensorPtr avg_pool_nd_backward(const XLATensorPtr& out_backprop,
12781302 count_include_pad));
12791303}
12801304
1281- XLATensorPtr baddbmm (const XLATensorPtr& input, const XLATensorPtr& batch1,
1282- const XLATensorPtr& batch2, const at::Scalar& beta,
1283- const at::Scalar& alpha) {
1284- CheckBmmDimension (/* tag=*/ " baddbmm" , batch1, batch2);
1305+ absl::StatusOr<absl_nonnull XLATensorPtr> baddbmm (const XLATensorPtr& input,
1306+ const XLATensorPtr& batch1,
1307+ const XLATensorPtr& batch2,
1308+ const at::Scalar& beta,
1309+ const at::Scalar& alpha) {
1310+ XLA_RETURN_IF_ERROR (CheckBmmInputsAreValid (
1311+ " baddbmm" , {batch1, /* name= */ " batch1" , /* position= */ 2 },
1312+ {batch2, /* name= */ " batch2" , /* position= */ 3 }));
12851313 torch::lazy::Value product_multiplier =
12861314 XLAGraphExecutor::Get ()->GetIrValueForScalar (
12871315 alpha, batch1->shape ().get ().element_type (), batch1->GetDevice ());
@@ -1331,9 +1359,12 @@ XLATensorPtr bitwise_xor(const XLATensorPtr& input, const XLATensorPtr& other) {
13311359 input->GetIrValue (), other->GetIrValue ()));
13321360}
13331361
1334- XLATensorPtr bmm (const XLATensorPtr& batch1, const XLATensorPtr& batch2) {
1335- CheckBmmDimension (/* tag=*/ " bmm" , batch1, batch2);
1336- return matmul (batch1, batch2);
// Batched matrix multiplication of two 3D tensors (torch.bmm lowering).
// Validates the operands' ranks and dimension compatibility first and
// propagates an InvalidArgumentError on failure instead of crashing;
// on success delegates to matmul().
absl::StatusOr<absl_nonnull XLATensorPtr> bmm(const XLATensorPtr& input,
                                              const XLATensorPtr& mat2) {
  // Names/positions mirror the user-facing torch.bmm(input, mat2)
  // signature; they are only used to build error messages.
  XLA_RETURN_IF_ERROR(CheckBmmInputsAreValid(
      "bmm", {input, /*name=*/"input", /*position=*/1},
      {mat2, /*name=*/"mat2", /*position=*/2}));
  return matmul(input, mat2);
}
13381369
13391370std::vector<XLATensorPtr> broadcast_tensors (
0 commit comments