From fd51374f78defbfb948118651e76caeaab3d2e62 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Wed, 2 Jul 2025 17:58:54 -0400 Subject: [PATCH 1/9] Support decomposition of torch.broadcast_tensors --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 23 ++++ .../torch-mlir/Dialect/Torch/Utils/Utils.h | 4 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 6 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 52 ++++++- .../Transforms/LowerToBackendContract.cpp | 1 + lib/Dialect/Torch/Utils/Utils.cpp | 129 ++++++++++++------ projects/pt1/e2e_testing/xfail_sets.py | 2 + .../build_tools/abstract_interp_lib_gen.py | 19 +++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 51 +++++++ test/Dialect/Torch/decompose-complex-ops.mlir | 20 +++ 11 files changed, 255 insertions(+), 53 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ebe4347a2aca..fc787d355d2a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11973,6 +11973,29 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ let hasFolder = 1; } +def Torch_AtenBroadcastTensorsOp : Torch_Op<"aten.broadcast_tensors", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::broadcast_tensors : (Tensor[]) -> (Tensor[])`"; + let arguments = (ins + AnyTorchListOfTensorType:$tensors + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBroadcastTensorsOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenBroadcastTensorsOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index a000b7ab2f98..cd93af1fd8d4 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -60,10 +60,10 @@ Type getBuiltInTypeForTorchScalar(Type type); Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Type dtype); -// Checks whether the `inputA` and `inputB` are broadcast compatible or not. If +// Checks whether the inputs are broadcast compatible or not. If // yes, then computes the final broadcast shape. 
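+// For example, inputs with static shapes [1, 3] and [2, 1] are broadcast
+// compatible and yield resultShape [2, 3], together with the corresponding
+// torch.int shape values in resultShapeValue.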
void computeBroadcastShape(PatternRewriter &rewriter, Location loc, - Value inputA, Value inputB, + SmallVector inputs, SmallVector &resultShape, SmallVector &resultShapeValue); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index fbaf8a1f756b..d2e3c94733e9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1065,9 +1065,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } else { SmallVector resultBroadcastShapeInt; SmallVector resultBroadcastShapeValue; - Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr, - valList[i], resultBroadcastShapeInt, - resultBroadcastShapeValue); + Torch::computeBroadcastShape( + rewriter, binder.getLoc(), {curr, valList[i]}, + resultBroadcastShapeInt, resultBroadcastShapeValue); auto baseType = Torch::ValueTensorType::get( binder.op->getContext(), resultBroadcastShapeInt, resultType.getOptionalDtype()); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 16b8ee2ebca5..210b6c918c1d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -24,7 +24,6 @@ #include "llvm/ADT/StringSet.h" #include #include - using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -3415,7 +3414,7 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { // calculate common shape for broadcast SmallVector broadcastShape; SmallVector broadcastShapeValue; - computeBroadcastShape(rewriter, loc, self, other, broadcastShape, + computeBroadcastShape(rewriter, loc, {self, other}, broadcastShape, broadcastShapeValue); Type broadcastType = ValueTensorType::get( @@ -8962,7 +8961,7 @@ class DecomposeAtenCosineSimilarityOp // Broadcast x1 and x2 to the same shape SmallVector indexBroadcastShapeInt; SmallVector indexBroadcastShapeValue; - computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt, + computeBroadcastShape(rewriter, loc, {x1, x2}, indexBroadcastShapeInt, indexBroadcastShapeValue); Type dtype = cast(x1.getType()).getOptionalDtype(); Type broadcastType = ValueTensorType::get( @@ -12203,6 +12202,52 @@ class DecomposeAtenRoundDecimalsOp }; } // namespace +namespace { +class DecomposeAtenBroadcastTensorsOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenBroadcastTensorsOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + SmallVector tensors; + if (!getListConstructElements(op.getTensors(), tensors)) + return rewriter.notifyMatchFailure(op, "Unable to get tensors"); + int64_t numTensors = tensors.size(); + + SmallVector broadcastShape; + SmallVector broadcastShapeValue; + + computeBroadcastShape(rewriter, loc, tensors, broadcastShape, + broadcastShapeValue); + + auto resType = cast(tensors[0].getType()); + auto dtype = resType.getDtype(); + Type broadcastType = ValueTensorType::get( + op.getContext(), llvm::ArrayRef(broadcastShape), dtype); + + Value broadcastShapeTorchList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + broadcastShapeValue); + + SmallVector broadcastedValues; + for (int64_t i = 0; i < numTensors; i++) { + auto inputTensor = tensors[i]; + auto broadcastedVal = rewriter.create( + loc, broadcastType, inputTensor, broadcastShapeTorchList); + 
broadcastedValues.push_back(broadcastedVal); + } + + auto broadcastedValuesList = rewriter.create( + loc, Torch::ListType::get(broadcastType), broadcastedValues); + + rewriter.replaceOp(op, broadcastedValuesList); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -12403,6 +12448,7 @@ class DecomposeComplexOpsPass DecomposeAtenAdaptivePool2dOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenAdaptivePool2dOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index dac4721c7772..a6cf02c88299 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -518,6 +518,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 388e31353571..f7ab6c7fe53b 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -479,78 +479,117 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, return unsqueezed; } -// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If +// Checks whether the inputs are broadcast compatible or not. If // yes, then computes the final broadcast shape. void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, - Value inputA, Value inputB, + SmallVector inputs, SmallVector &resultShape, SmallVector &resultShapeValue) { - SmallVector shapeA{ - cast(inputA.getType()).getSizes()}; - SmallVector shapeB{ - cast(inputB.getType()).getSizes()}; - unsigned rankA = shapeA.size(); - unsigned rankB = shapeB.size(); - unsigned minRank = rankA > rankB ? rankB : rankA; + + SmallVector> shapes; + SmallVector ranks; + + for (auto input : inputs) { + SmallVector shape{ + cast(input.getType()).getSizes()}; + shapes.push_back(shape); + ranks.push_back(shape.size()); + } + + unsigned maxRank = *std::max_element(ranks.begin(), ranks.end()); + // Check whether the shapes of the tensors are broadcastable or not. // Two tensors are “broadcastable” if the following rules hold: // 1.) Each tensor has at least one dimension. // 2.) When iterating over the dimension sizes, starting at the trailing // dimension, the dimension sizes must either be equal, one of them is 1, or // one of them does not exist. 
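+  // For example, shapes [3, 1, 5] and [4, 1] are broadcast compatible and
+  // produce the broadcast shape [3, 4, 5], whereas [2, 3] and [3, 2] are not,
+  // since their trailing dimensions are neither equal nor 1.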
- for (unsigned i = 0; i < minRank; i++) { - Value sizeDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(rankA - i - 1)); - Value sizeDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(rankB - i - 1)); - Value sizeInputA = - rewriter.createOrFold(loc, inputA, sizeDimA); - Value sizeInputB = - rewriter.createOrFold(loc, inputB, sizeDimB); + for (unsigned i = 0; i < maxRank; i++) { + + SmallVector sizeInputs; + for (auto [idx, input] : llvm::enumerate(inputs)) { + int sizeDimIdx = ranks[idx] - i - 1; + if (sizeDimIdx >= 0) { + auto sizeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(sizeDimIdx)); + sizeInputs.push_back( + rewriter.createOrFold(loc, input, sizeDim)); + } + } + Value torchCstOne = rewriter.create( loc, rewriter.getI64IntegerAttr(1)); - Value cmpSizeAEqualsSizeB = - rewriter.create(loc, sizeInputA, sizeInputB); - Value cmpSizeAEqualsOne = - rewriter.create(loc, sizeInputA, torchCstOne); - Value cmpSizeBEqualsOne = - rewriter.create(loc, sizeInputB, torchCstOne); + SmallVector predicates; + for (auto sizeVal : sizeInputs) { + Value cmpSizeEquals = + rewriter.create(loc, sizeVal, sizeInputs.front()); + predicates.push_back(cmpSizeEquals); + Value cmpSizeEqualsOne = + rewriter.create(loc, sizeVal, torchCstOne); + predicates.push_back(cmpSizeEqualsOne); + } + Value anyBoolOpList = rewriter.create( - loc, Torch::ListType::get(cmpSizeAEqualsOne.getType()), - SmallVector{cmpSizeAEqualsSizeB, cmpSizeAEqualsOne, - cmpSizeBEqualsOne}); + loc, Torch::ListType::get(predicates.front().getType()), predicates); Value cmp = rewriter.create(loc, anyBoolOpList); rewriter.create( loc, cmp, "tensors are not broadcast compatible"); } + // If we reach here then it means both the shapes are broadcast compatible. - resultShape = rankA >= rankB ? shapeA : shapeB; - Value shapeTensor = rankA >= rankB ? 
inputA : inputB; + auto maxRankIdx = + std::max_element(ranks.begin(), ranks.end()) - ranks.begin(); + resultShape = shapes[maxRankIdx]; + Value shapeTensor = inputs[maxRankIdx]; + for (unsigned i = 0; i < resultShape.size(); i++) { Value sizeDim = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); resultShapeValue.push_back( rewriter.createOrFold(loc, shapeTensor, sizeDim)); } - unsigned resultRank = resultShape.size(); - for (unsigned i = 0; i < minRank; i++) { - Value sizeDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(rankA - i - 1)); - Value sizeDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(rankB - i - 1)); - Value sizeInputA = - rewriter.createOrFold(loc, inputA, sizeDimA); - Value sizeInputB = - rewriter.createOrFold(loc, inputB, sizeDimB); - resultShapeValue[resultRank - i - 1] = - rewriter.create(loc, sizeInputA, sizeInputB); - if (shapeA[rankA - i - 1] == kUnknownSize || - shapeB[rankB - i - 1] == kUnknownSize) { + for (unsigned i = 0; i < maxRank; i++) { + + SmallVector sizeInputs; + for (auto [idx, input] : llvm::enumerate(inputs)) { + int sizeDimIdx = ranks[idx] - i - 1; + if (sizeDimIdx >= 0) { + auto sizeDim = rewriter.create( + loc, rewriter.getI64IntegerAttr(sizeDimIdx)); + sizeInputs.push_back( + rewriter.createOrFold(loc, input, sizeDim)); + } + } + + // Compute shape value of broadcast result, + // which is the maximum of dimension sizes across all inputs + Value maxShapeVal = sizeInputs.front(); + for (auto sizeInput : sizeInputs) { + maxShapeVal = rewriter.create(loc, maxShapeVal, sizeInput); + } + resultShapeValue[resultRank - i - 1] = maxShapeVal; + + // Compute result shape if all input shapes are known + bool unknownSize = false; + for (auto [idx, shape] : llvm::enumerate(shapes)) { + if (ranks[idx] - i - 1 < shape.size() && + shape[ranks[idx] - i - 1] == kUnknownSize) { + unknownSize = true; + } + } + + if (unknownSize) { resultShape[resultRank - i - 1] = kUnknownSize; } else { - resultShape[resultRank - i - 1] = - std::max(shapeA[rankA - i - 1], shapeB[rankB - i - 1]); + + int64_t maxShape = 1; + for (auto [idx, shape] : llvm::enumerate(shapes)) { + if (ranks[idx] - i - 1 < shape.size()) { + maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]); + } + } + resultShape[resultRank - i - 1] = maxShape; } } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e7833fd9ac33..605b28a5ac40 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -4205,6 +4205,8 @@ "BoolIntTrueModule_basic", "BroadcastDynamicDimModule_basic", "BroadcastToModule_basic", + "BroadcastTensorsModule_basic", + "BroadcastTensorsModuleList_multiple_ranks", "BucketizeTensorFloatModule_basic", "BucketizeTensorModule_basic", "BucketizeTensorOutInt32RightModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 50ea52abdba9..a5b52784e4ed 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -949,6 +949,10 @@ def aten〇expand_as〡shape(self: List[int], other: List[int]) -> List[int]: def aten〇broadcast_to〡shape(self: List[int], size: List[int]) -> List[int]: return upstream_shape_functions.expand(self, size) +def aten〇broadcast_tensors〡shape(tensors: List[List[int]]) -> List[List[int]]: + out_shape: 
torch.Size = upstream_shape_functions.broadcast_shapes(tensors) + return out_shape + def aten〇view〡shape(self: List[int], size: List[int]) -> List[int]: return upstream_shape_functions.view(self, size) @@ -3127,6 +3131,21 @@ def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), + Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), + Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), + NonZeroDTensorWithDtype(torch.complex64)])]) +def aten〇broadcast_tensors〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: + ranks: List[Optional[int]] = [] + dtypes: List[int] = [] + assert len(tensors_rank_dtype) != 0 + for tensor_rank_dtype in tensors_rank_dtype: + tensor_rank, tensor_dtype = tensor_rank_dtype + ranks.append(tensor_rank) + dtypes.append(tensor_dtype) + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=2,dim=0, error_types={torch.complex128, torch.complex64, *all_integer_dtypes()})) def aten〇cosine_similarity〡dtype(x1_rank_dtype: Tuple[int, int], x2_rank_dtype: Tuple[int, int], dim: int = 1, eps: float = 1e-08) -> int: diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6a173877b0b0..0290071a55da 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -895,6 +895,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") emit("aten::expand_as : (Tensor, Tensor) -> (Tensor)") emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True) + emit("aten::broadcast_tensors : (Tensor[]) -> (Tensor[])") emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 1ad698db9cc1..ad506a6e6cdd 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -2146,6 +2146,57 @@ def BroadcastToModule_basic(module, tu: TestUtils): # ============================================================================== +class BroadcastTensorsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, x, y): + x1, y1 = torch.broadcast_tensors(x, y) + return x1, y1 + + +@register_test_case(module_factory=lambda: BroadcastTensorsModule()) +def BroadcastTensorsModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3), tu.rand(2, 1)) + + +# ============================================================================== + + +class BroadcastTensorsModuleList(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([3], torch.float32, True), + ([2, 1], torch.float32, True), + ([2, 1, 1], torch.float32, True), + ] + ) 
+ def forward(self, x, y, z): + x1, y1, z1 = torch.broadcast_tensors(x, y, z) + return x1, y1, z1 + + +@register_test_case(module_factory=lambda: BroadcastTensorsModuleList()) +def BroadcastTensorsModuleList_multiple_ranks(module, tu: TestUtils): + module.forward(tu.rand(3), tu.rand(2, 1), tu.rand(2, 1, 1)) + + +# ============================================================================== + + class BroadcastToSameRankStaticModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 7644c00de069..fe6a2ac7b7fd 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -846,3 +846,23 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32> return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_tensors( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>) -> !torch.list> +// CHECK: %[[VAR1:.*]] = torch.constant.int 2 +// CHECK: %[[VAR2:.*]] = torch.constant.int 3 +// CHECK: %[[VAR3:.*]] = torch.constant.bool true +// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible" +// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible" +// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[VAR1]], %[[VAR2]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[ARG0:.*]], %[[VAR4]] : !torch.vtensor<[1,3],f32>, !torch.list -> !torch.vtensor<[2,3],f32> +// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[ARG1:.*]], %[[VAR4]] : !torch.vtensor<[2,1],f32>, !torch.list -> !torch.vtensor<[2,3],f32> +// CHECK: %[[VAR7:.*]] = torch.prim.ListConstruct %[[VAR5]], %[[VAR6]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list> +// CHECK: return %[[VAR7]] : !torch.list> +func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list> { + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list + %1 = torch.aten.broadcast_tensors %0 : !torch.list -> !torch.list> + return %1 : !torch.list> +} From dbef8f452a1e6bd8ef587c13a17ef3aa9ac598c9 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Wed, 2 Jul 2025 23:47:52 -0400 Subject: [PATCH 2/9] Fix shape and dtype functions and add test to expected crashing set --- .../Transforms/AbstractInterpLibrary.cpp | 60 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 31 +++++----- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index fc65f7f1653a..cfc324207e1f 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7796,6 +7796,37 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call 
@__torch__.torch.jit._shape_functions.expand(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.broadcast_tensors\"(%arg0: !torch.list>) -> !torch.list> {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.list>) {\n" +" %3 = torch.prim.ListConstruct : () -> !torch.list>\n" +" torch.prim.If.yield %3 : !torch.list>\n" +" } else {\n" +" %3 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list>, !torch.int -> !torch.list\n" +" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %5 = torch.aten.__range_length %int1, %4, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %6 = torch.prim.Loop %5, %true, init(%3) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.list):\n" +" %9 = torch.aten.__derive_index %arg1, %int1, %int1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" +" %10 = torch.aten.__getitem__.t %arg0, %9 : !torch.list>, !torch.int -> !torch.list\n" +" %11 = func.call @__torch__.torch.jit._shape_functions.broadcast(%arg2, %10) : (!torch.list, !torch.list) -> !torch.list\n" +" torch.prim.Loop.condition %true, iter(%11 : !torch.list)\n" +" } : (!torch.int, !torch.bool, !torch.list) -> !torch.list\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %8 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %9 = torch.aten.append.t %7, %6 : !torch.list>, !torch.list -> !torch.list>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" torch.prim.If.yield %7 : !torch.list>\n" +" }\n" +" return %2 : !torch.list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.view\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.view(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12407,6 +12438,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.broadcast_tensors\"(%arg0: !torch.list>) -> !torch.list> {\n" +" %true = torch.constant.bool true\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" %1 = torch.prim.Loop %0, %true, init(%int0) {\n" +" ^bb0(%arg1: !torch.int, %arg2: !torch.int):\n" +" %4 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" +" %5 = torch.prim.TupleIndex %4, %int0 : !torch.tuple, !torch.int -> !torch.int\n" +" %6 = torch.aten.gt.int %5, %arg2 : !torch.int, !torch.int -> !torch.bool\n" +" %7 = torch.prim.If %6 -> (!torch.int) {\n" +" %8 = torch.prim.TupleIndex %4, %int0 : !torch.tuple, !torch.int -> !torch.int\n" +" torch.prim.If.yield %8 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %arg2 : !torch.int\n" +" }\n" +" torch.prim.Loop.condition %true, iter(%7 : !torch.int)\n" +" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list>\n" +" %3 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg1: !torch.int):\n" +" %4 = 
torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" +" %5:2 = torch.prim.TupleUnpack %4 : !torch.tuple -> !torch.int, !torch.int\n" +" %6 = torch.prim.TupleConstruct %1, %5#1 : !torch.int, !torch.int -> !torch.tuple\n" +" %7 = torch.aten.append.t %2, %6 : !torch.list>, !torch.tuple -> !torch.list>\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" return %2 : !torch.list>\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cosine_similarity\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n" " %int7 = torch.constant.int 7\n" " %int6 = torch.constant.int 6\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 605b28a5ac40..b4cd874769f7 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3446,6 +3446,7 @@ "StdCorrectionEmptyDimModule_basic", "VarCorrectionEmptyDimModule_basic", "VarDimEmptyDimModule_basic", + "BroadcastTensorsModule_basic", # Runtime op verification: rank mismatch in memref.cast "ViewSizeFromOtherTensor_basic", "SliceOutOfLowerBoundEndIndexModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a5b52784e4ed..587f53397271 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -950,8 +950,15 @@ def aten〇broadcast_to〡shape(self: List[int], size: List[int]) -> List[int]: return upstream_shape_functions.expand(self, size) def aten〇broadcast_tensors〡shape(tensors: List[List[int]]) -> List[List[int]]: - out_shape: torch.Size = upstream_shape_functions.broadcast_shapes(tensors) - return out_shape + if len(tensors) == 0: + return [] + result = tensors[0] + for i in range(1, len(tensors)): + result = upstream_shape_functions.broadcast(result, tensors[i]) + out: List[List[int]] = [] + for _ in tensors: + out.append(result) + return out def aten〇view〡shape(self: List[int], size: List[int]) -> List[int]: return upstream_shape_functions.view(self, size) @@ -3131,20 +3138,16 @@ def aten〇broadcast_to〡dtype(self_rank_dtype: Tuple[int, int], size: List[int self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function( - [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), - Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), - Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), - NonZeroDTensorWithDtype(torch.complex64)])]) -def aten〇broadcast_tensors〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: - ranks: List[Optional[int]] = [] - dtypes: List[int] = [] - assert len(tensors_rank_dtype) != 0 +def aten〇broadcast_tensors〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + max_rank = 0 + for rd in tensors_rank_dtype: + if rd[0] > max_rank: + max_rank = rd[0] + out: List[Tuple[int, int]] = [] for tensor_rank_dtype in tensors_rank_dtype: tensor_rank, tensor_dtype = tensor_rank_dtype - ranks.append(tensor_rank) - dtypes.append(tensor_dtype) - return promote_dtypes(ranks, dtypes) + out.append((max_rank, tensor_dtype)) + return out @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=2,dim=0, 
error_types={torch.complex128, torch.complex64, *all_integer_dtypes()})) From 8ae8aee9af1e504d4065f30b7221602fcc029ca0 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Fri, 11 Jul 2025 13:03:53 -0400 Subject: [PATCH 3/9] Refactor computeBroadcastShape and update unit test --- lib/Dialect/Torch/Utils/Utils.cpp | 53 +++++++++---------- test/Dialect/Torch/decompose-complex-ops.mlir | 29 +++++----- 2 files changed, 42 insertions(+), 40 deletions(-) diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index f7ab6c7fe53b..bc11f10f8fce 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -488,6 +488,7 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, SmallVector> shapes; SmallVector ranks; + SmallVector maxShapeValues; for (auto input : inputs) { SmallVector shape{ @@ -496,6 +497,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, ranks.push_back(shape.size()); } + Value torchCstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); unsigned maxRank = *std::max_element(ranks.begin(), ranks.end()); // Check whether the shapes of the tensors are broadcastable or not. @@ -517,23 +520,34 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, } } - Value torchCstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + // Compute shape value of broadcast result, + // which is the maximum of dimension sizes across all inputs + Value maxShapeVal = sizeInputs.front(); + for (auto sizeInput : sizeInputs) { + maxShapeVal = rewriter.create(loc, maxShapeVal, sizeInput); + } + maxShapeValues.push_back(maxShapeVal); + SmallVector predicates; for (auto sizeVal : sizeInputs) { Value cmpSizeEquals = - rewriter.create(loc, sizeVal, sizeInputs.front()); - predicates.push_back(cmpSizeEquals); + rewriter.create(loc, sizeVal, maxShapeVal); Value cmpSizeEqualsOne = rewriter.create(loc, sizeVal, torchCstOne); - predicates.push_back(cmpSizeEqualsOne); + Value anyBoolOpList = rewriter.create( + loc, Torch::ListType::get(cmpSizeEquals.getType()), + SmallVector{cmpSizeEquals, cmpSizeEqualsOne}); + Value cmp = rewriter.create(loc, anyBoolOpList); + predicates.push_back(cmp); } - Value anyBoolOpList = rewriter.create( - loc, Torch::ListType::get(predicates.front().getType()), predicates); - Value cmp = rewriter.create(loc, anyBoolOpList); - rewriter.create( - loc, cmp, "tensors are not broadcast compatible"); + if (!predicates.empty()) { + Value anyBoolOpList = rewriter.create( + loc, Torch::ListType::get(predicates.front().getType()), predicates); + Value cmp = rewriter.create(loc, anyBoolOpList); + rewriter.create( + loc, cmp, "tensors are not broadcast compatible"); + } } // If we reach here then it means both the shapes are broadcast compatible. 
@@ -551,24 +565,7 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, unsigned resultRank = resultShape.size(); for (unsigned i = 0; i < maxRank; i++) { - SmallVector sizeInputs; - for (auto [idx, input] : llvm::enumerate(inputs)) { - int sizeDimIdx = ranks[idx] - i - 1; - if (sizeDimIdx >= 0) { - auto sizeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(sizeDimIdx)); - sizeInputs.push_back( - rewriter.createOrFold(loc, input, sizeDim)); - } - } - - // Compute shape value of broadcast result, - // which is the maximum of dimension sizes across all inputs - Value maxShapeVal = sizeInputs.front(); - for (auto sizeInput : sizeInputs) { - maxShapeVal = rewriter.create(loc, maxShapeVal, sizeInput); - } - resultShapeValue[resultRank - i - 1] = maxShapeVal; + resultShapeValue[resultRank - i - 1] = maxShapeValues[i]; // Compute result shape if all input shapes are known bool unknownSize = false; diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index fe6a2ac7b7fd..dc589bffe835 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -849,18 +849,23 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf // ----- -// CHECK-LABEL: func.func @torch.aten.broadcast_tensors( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[2,1],f32>) -> !torch.list> -// CHECK: %[[VAR1:.*]] = torch.constant.int 2 -// CHECK: %[[VAR2:.*]] = torch.constant.int 3 -// CHECK: %[[VAR3:.*]] = torch.constant.bool true -// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible" -// CHECK: torch.runtime.assert %[[VAR3]], "tensors are not broadcast compatible" -// CHECK: %[[VAR4:.*]] = torch.prim.ListConstruct %[[VAR1]], %[[VAR2]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAR5:.*]] = torch.aten.broadcast_to %[[ARG0:.*]], %[[VAR4]] : !torch.vtensor<[1,3],f32>, !torch.list -> !torch.vtensor<[2,3],f32> -// CHECK: %[[VAR6:.*]] = torch.aten.broadcast_to %[[ARG1:.*]], %[[VAR4]] : !torch.vtensor<[2,1],f32>, !torch.list -> !torch.vtensor<[2,3],f32> -// CHECK: %[[VAR7:.*]] = torch.prim.ListConstruct %[[VAR5]], %[[VAR6]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list> -// CHECK: return %[[VAR7]] : !torch.list> +// CHECK-LABEL: func.func @torch.aten.broadcast_tensors +// CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,1],f32> +// CHECK: %[[INT2:.*]] = torch.constant.int 2 +// CHECK: %[[INT3:.*]] = torch.constant.int 3 +// CHECK: %[[TRUE1:.*]] = torch.constant.bool true +// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[TRUE1]], %[[TRUE1]] : (!torch.bool, !torch.bool) -> !torch.list +// CHECK: %[[ALL1:.*]] = torch.aten.all.bool %[[LIST1]] : !torch.list -> !torch.bool +// CHECK: torch.runtime.assert %[[ALL1]], "tensors are not broadcast compatible" +// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[TRUE1]], %[[TRUE1]] : (!torch.bool, !torch.bool) -> !torch.list +// CHECK: %[[ALL2:.*]] = torch.aten.all.bool %[[LIST2]] : !torch.list -> !torch.bool +// CHECK: torch.runtime.assert %[[ALL2]], "tensors are not broadcast compatible" +// CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[B0:.*]] = torch.aten.broadcast_to %[[ARG0]], %[[SHAPE]] : !torch.vtensor<[1,3],f32>, !torch.list -> !torch.vtensor<[2,3],f32> +// CHECK: %[[B1:.*]] = 
torch.aten.broadcast_to %[[ARG1]], %[[SHAPE]] : !torch.vtensor<[2,1],f32>, !torch.list -> !torch.vtensor<[2,3],f32> +// CHECK: %[[OUTLIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list> +// CHECK: return %[[OUTLIST]] : !torch.list> func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list> { %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list %1 = torch.aten.broadcast_tensors %0 : !torch.list -> !torch.list> From 45a574abac9a2264af4d0e96b4e2795ae988e303 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Fri, 11 Jul 2025 14:30:36 -0400 Subject: [PATCH 4/9] Resolve merge conflict --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 806600004fbe..b75b227adde1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11328,7 +11328,7 @@ class DecomposeAtenHeaviside : public OpRewritePattern { auto resultTy = dyn_cast(op.getType()); SmallVector broadcastShape; SmallVector broadcastShapeValue; - computeBroadcastShape(rewriter, loc, input, value, broadcastShape, + computeBroadcastShape(rewriter, loc, {input, value}, broadcastShape, broadcastShapeValue); auto broadcastType = ValueTensorType::get( From 7d2831c485453ac9cd2fb16038f16d679d33f5b7 Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Fri, 18 Jul 2025 12:33:42 -0400 Subject: [PATCH 5/9] Add folder for aten.all.bool --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 18 ++++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 10 ++++++++++ test/Dialect/Torch/decompose-complex-ops.mlir | 18 +++++++----------- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5da2a1b621f8..65ec383410b7 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11041,6 +11041,7 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e50be5ff97ae..7f1490a25a01 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2838,6 +2838,24 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenAllBoolOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) { + auto inputConstruct = getSelf().getDefiningOp(); + if (!inputConstruct || isListPotentiallyMutated(inputConstruct)) + return nullptr; + // If all operands are a constant true, return true. 
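+  // For example, aten.all.bool of a prim.ListConstruct holding only constant
+  // `true` values folds to the constant true; a `false` or non-constant
+  // element leaves the op in place.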
+ for (auto operand : inputConstruct.getOperands()) { + bool b = true; + if (!matchPattern(operand, m_TorchConstantBool(&b)) || !b) { + return nullptr; + } + } + return getI1IntegerAttr(getContext(), true); +} + //===----------------------------------------------------------------------===// // AtenFloatScalarOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 7ab7c817b7ae..9bfc574b8393 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -847,7 +847,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::isneginf : (Tensor) -> (Tensor)") emit("aten::isposinf : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") - emit("aten::all.bool : (bool[]) -> (bool)") + emit("aten::all.bool : (bool[]) -> (bool)", has_folder=True) emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index a025ec09726d..0a529fee772f 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2685,6 +2685,16 @@ func.func @torch.aten.any.bool$fold() -> !torch.bool { return %0 : !torch.bool } +// CHECK-LABEL: func.func @torch.aten.all.bool$fold() -> !torch.bool { +// CHECK: %[[CST_TRUE:.*]] = torch.constant.bool true +// CHECK: return %[[CST_TRUE]] : !torch.bool +func.func @torch.aten.all.bool$fold() -> !torch.bool { + %true = torch.constant.bool true + %input = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list + %0 = torch.aten.all.bool %input : !torch.list -> !torch.bool + return %0 : !torch.bool +} + // CHECK-LABEL: func.func @torch.aten.floor$canonicalize // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],si64> // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?,?],si64> diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index dc589bffe835..3e921247cdc2 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -852,20 +852,16 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf // CHECK-LABEL: func.func @torch.aten.broadcast_tensors // CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32> // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,1],f32> -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[TRUE1:.*]] = torch.constant.bool true -// CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[TRUE1]], %[[TRUE1]] : (!torch.bool, !torch.bool) -> !torch.list -// CHECK: %[[ALL1:.*]] = torch.aten.all.bool %[[LIST1]] : !torch.list -> !torch.bool -// CHECK: torch.runtime.assert %[[ALL1]], "tensors are not broadcast compatible" -// CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[TRUE1]], %[[TRUE1]] : (!torch.bool, !torch.bool) -> !torch.list -// CHECK: %[[ALL2:.*]] = torch.aten.all.bool %[[LIST2]] : !torch.list -> !torch.bool -// CHECK: torch.runtime.assert %[[ALL2]], "tensors are not broadcast compatible" +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3 +// CHECK-DAG: 
%[[TRUE:.*]] = torch.constant.bool true +// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible" +// CHECK: torch.runtime.assert %[[TRUE]], "tensors are not broadcast compatible" // CHECK: %[[SHAPE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[B0:.*]] = torch.aten.broadcast_to %[[ARG0]], %[[SHAPE]] : !torch.vtensor<[1,3],f32>, !torch.list -> !torch.vtensor<[2,3],f32> // CHECK: %[[B1:.*]] = torch.aten.broadcast_to %[[ARG1]], %[[SHAPE]] : !torch.vtensor<[2,1],f32>, !torch.list -> !torch.vtensor<[2,3],f32> -// CHECK: %[[OUTLIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list> -// CHECK: return %[[OUTLIST]] : !torch.list> +// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[B0]], %[[B1]] : (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>) -> !torch.list> +// CHECK: return %[[LIST]] : !torch.list> func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: !torch.vtensor<[2,1],f32>) -> !torch.list> { %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list %1 = torch.aten.broadcast_tensors %0 : !torch.list -> !torch.list> From f6a2cf800a70294b6ae02924300e860b0daa4f5b Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Wed, 13 Aug 2025 10:45:55 -0400 Subject: [PATCH 6/9] Refactor implementation to avoid redundant computations --- lib/Dialect/Torch/IR/TorchOps.cpp | 2 +- lib/Dialect/Torch/Utils/Utils.cpp | 25 +++++++++++-------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 7f1490a25a01..4db9e47b836e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2848,7 +2848,7 @@ OpFoldResult AtenAllBoolOp::fold(FoldAdaptor adaptor) { return nullptr; // If all operands are a constant true, return true. for (auto operand : inputConstruct.getOperands()) { - bool b = true; + bool b; if (!matchPattern(operand, m_TorchConstantBool(&b)) || !b) { return nullptr; } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index bc11f10f8fce..0012c40b53c3 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -499,7 +499,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, Value torchCstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - unsigned maxRank = *std::max_element(ranks.begin(), ranks.end()); + auto maxRankItr = std::max_element(ranks.begin(), ranks.end()); + unsigned maxRank = *maxRankItr; // Check whether the shapes of the tensors are broadcastable or not. // Two tensors are “broadcastable” if the following rules hold: @@ -550,9 +551,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, } } - // If we reach here then it means both the shapes are broadcast compatible. - auto maxRankIdx = - std::max_element(ranks.begin(), ranks.end()) - ranks.begin(); + // If we reach here then it means all shapes are broadcast compatible. 
+ auto maxRankIdx = maxRankItr - ranks.begin(); resultShape = shapes[maxRankIdx]; Value shapeTensor = inputs[maxRankIdx]; @@ -569,23 +569,20 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, // Compute result shape if all input shapes are known bool unknownSize = false; + int64_t maxShape = 1; for (auto [idx, shape] : llvm::enumerate(shapes)) { - if (ranks[idx] - i - 1 < shape.size() && - shape[ranks[idx] - i - 1] == kUnknownSize) { - unknownSize = true; + if (ranks[idx] - i - 1 < shape.size()) { + if (shape[ranks[idx] - i - 1] == kUnknownSize) { + unknownSize = true; + } else { + maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]); + } } } if (unknownSize) { resultShape[resultRank - i - 1] = kUnknownSize; } else { - - int64_t maxShape = 1; - for (auto [idx, shape] : llvm::enumerate(shapes)) { - if (ranks[idx] - i - 1 < shape.size()) { - maxShape = std::max(maxShape, shape[ranks[idx] - i - 1]); - } - } resultShape[resultRank - i - 1] = maxShape; } } From 8c2b76287cac5a9b4134e0df70f7a7b8f3ed205c Mon Sep 17 00:00:00 2001 From: vinitdeodhar Date: Wed, 13 Aug 2025 11:20:51 -0400 Subject: [PATCH 7/9] Resolve merge error for DecomposeComplexOps.cpp --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a295c35ac12f..6d4f4dacaeb5 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -12620,6 +12620,7 @@ class DecomposeAtenBroadcastTensorsOp loc, Torch::ListType::get(broadcastType), broadcastedValues); rewriter.replaceOp(op, broadcastedValuesList); + return success(); } }; } // namespace From bd69b6ef4af3a9ebef6da96480e60a4b580dd312 Mon Sep 17 00:00:00 2001 From: vinitdeodhar Date: Wed, 13 Aug 2025 11:33:13 -0400 Subject: [PATCH 8/9] Resolve merge error #2 for DecomposeComplexOps.cpp --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6d4f4dacaeb5..2f33c100381b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -12758,7 +12758,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { // calculate common shape for broadcast SmallVector broadcastShape; SmallVector broadcastShapeValue; - computeBroadcastShape(rewriter, loc, finalIndices, index, broadcastShape, + computeBroadcastShape(rewriter, loc, {finalIndices, index}, broadcastShape, broadcastShapeValue); Type broadcastType = ValueTensorType::get( context, llvm::ArrayRef(broadcastShape), si64Type); From f3c0f5c69920c12f8497a73e072803a1b7730d0c Mon Sep 17 00:00:00 2001 From: Vinit Deodhar Date: Wed, 13 Aug 2025 14:47:05 -0400 Subject: [PATCH 9/9] Resolve test failure due to bad merge --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 4 ++-- test/Dialect/Torch/decompose-complex-ops.mlir | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 2f33c100381b..68d716e8f004 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -12758,8 +12758,8 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { // calculate 
common shape for broadcast SmallVector broadcastShape; SmallVector broadcastShapeValue; - computeBroadcastShape(rewriter, loc, {finalIndices, index}, broadcastShape, - broadcastShapeValue); + computeBroadcastShape(rewriter, loc, {finalIndices, index}, + broadcastShape, broadcastShapeValue); Type broadcastType = ValueTensorType::get( context, llvm::ArrayRef(broadcastShape), si64Type); diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index c01313abc663..38d0e97d6e33 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -849,6 +849,7 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf // ----- + // CHECK-LABEL: func.func @torch.aten.broadcast_tensors // CHECK-SAME: (%[[ARG0:.*]]: !torch.vtensor<[1,3],f32> // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[2,1],f32> @@ -866,6 +867,7 @@ func.func @torch.aten.broadcast_tensors(%arg0: !torch.vtensor<[1,3],f32>, %arg1: %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[1,3],f32>, !torch.vtensor<[2,1],f32>) -> !torch.list %1 = torch.aten.broadcast_tensors %0 : !torch.list -> !torch.list> return %1 : !torch.list> +} // -----
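For reference, a minimal eager-mode sketch of the semantics exercised by the new
decomposition and the BroadcastTensorsModuleList e2e test; it assumes a stock
PyTorch install and is illustrative only, not part of the patch itself:

    import torch

    x = torch.rand(3)        # shape (3,)
    y = torch.rand(2, 1)     # shape (2, 1)
    z = torch.rand(2, 1, 1)  # shape (2, 1, 1)

    # broadcast_tensors expands every input to the common broadcast shape,
    # here (2, 2, 3); the decomposition mirrors this by emitting one
    # aten.broadcast_to per input against the shared shape list.
    x1, y1, z1 = torch.broadcast_tensors(x, y, z)
    assert x1.shape == y1.shape == z1.shape == (2, 2, 3)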