diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..8d571a7d5bd3 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9532,6 +9532,32 @@ def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
   }];
 }
 
+def Torch_AtenHuberLossOp : Torch_Op<"aten.huber_loss", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_FloatType:$delta
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHuberLossOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenHuberLossOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,
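Note: `aten::huber_loss` computes a pointwise loss that is quadratic for small errors and linear for large ones, with `delta` as the crossover point and `reduction` encoded as an int (0 = none, 1 = mean, 2 = sum). A minimal PyTorch sketch of the intended semantics, for reference only (the helper name is illustrative, not part of this patch):

```python
import torch

def huber_loss_reference(self, target, reduction=1, delta=1.0):
    diff = self - target
    abs_diff = diff.abs()
    # Quadratic branch for |diff| < delta, linear branch otherwise.
    pointwise = torch.where(abs_diff < delta,
                            0.5 * diff * diff,
                            delta * (abs_diff - 0.5 * delta))
    if reduction == 0:  # none
        return pointwise
    if reduction == 1:  # mean
        return pointwise.mean()
    return pointwise.sum()  # sum
```

This is the same structure the decomposition below materializes in Torch dialect ops.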
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 23f1814cc008..7ff675c54a8b 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10721,6 +10721,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.huber_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14612,6 +14637,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.huber_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
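Note: the dtype function simply returns the promoted dtype of `self` and `target`, computed by `promote_dtypes` from (rank, dtype) pairs. A quick illustration of the promotion behavior it models, assuming standard PyTorch promotion semantics:

```python
import torch

# torch.result_type follows the same tensor-tensor promotion rules that
# promote_dtypes approximates with (rank, dtype) pairs.
a = torch.zeros(2, 3, dtype=torch.float32)
b = torch.zeros(2, 3, dtype=torch.float64)
print(torch.result_type(a, b))                                     # torch.float64
print(torch.result_type(a, torch.zeros(2, 3, dtype=torch.int64)))  # torch.float32
```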
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cb49fa97b86a..6d5d99b03bbe 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10707,6 +10707,83 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenHuberLossOp : public OpRewritePattern<AtenHuberLossOp> {
+  using OpRewritePattern<AtenHuberLossOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHuberLossOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value deltaValue = op.getDelta();
+
+    auto selfTy = cast<BaseTensorType>(self.getType());
+    auto targetTy = cast<BaseTensorType>(target.getType());
+    auto outTy = cast<BaseTensorType>(op.getType());
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output to have sizes!");
+    }
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output to have a dtype!");
+    }
+
+    // Squared term: 0.5 * (input - target)^2
+    Value constOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value constHalf =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.5));
+    Value inputMinusTarget =
+        rewriter.create<AtenSubTensorOp>(loc, selfTy, self, target, constOne);
+    Value squaredValue =
+        rewriter.create<AtenSquareOp>(loc, selfTy, inputMinusTarget);
+    Value squaredTerm =
+        rewriter.create<AtenMulScalarOp>(loc, selfTy, squaredValue, constHalf);
+
+    // Delta-scaled term: delta * (|input - target| - 0.5 * delta)
+    Value absDiffValue =
+        rewriter.create<AtenAbsOp>(loc, selfTy, inputMinusTarget);
+    Value halfOfDelta = rewriter.create<AtenMulFloatOp>(
+        loc, rewriter.getType<Torch::FloatType>(), constHalf, deltaValue);
+    Value absDiffMinusDeltaHalf = rewriter.create<AtenSubScalarOp>(
+        loc, selfTy, absDiffValue, halfOfDelta, constOne);
+    Value deltaScaledTerm = rewriter.create<AtenMulScalarOp>(
+        loc, selfTy, absDiffMinusDeltaHalf, deltaValue);
+
+    // Select between the two terms based on: |input - target| < delta
+    ValueTensorType boolTy = ValueTensorType::get(
+        op.getContext(), selfTy.getSizes(), rewriter.getI1Type());
+    Value cmpValue =
+        rewriter.create<AtenLtScalarOp>(loc, boolTy, absDiffValue, deltaValue);
+    Value lossPointwise = rewriter.create<AtenWhereSelfOp>(
+        loc, selfTy, cmpValue, squaredTerm, deltaScaledTerm);
+
+    // Extract the reduction mode from the reduction argument.
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+    Value loss;
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    if (reduction == 1) {
+      // reduction: mean
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+    rewriter.replaceOp(op, loss);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12696,6 +12773,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenHuberLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHannWindowPeriodicOp>(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6d6ed9cad50d..a6956a53f64d 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -589,6 +589,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenBinaryCrossEntropyWithLogitsOp>();
   target.addIllegalOp<AtenPoissonNllLossOp>();
   target.addIllegalOp<AtenKlDivOp>();
+  target.addIllegalOp<AtenHuberLossOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 73fdd937790a..2fa8bf2a2368 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3103,6 +3103,10 @@
     "PoissonNLLLossSumReductionModule_basic",
     "PoissonNLLLossNonDefaultEpsModule_basic",
     "KlDivLossModule_batchmean_reduction_basic",
+    "HuberLossModule_default_basic",
+    "HuberLossModule_reduction_is_none_basic",
+    "HuberLossModule_mean_reduction_basic",
+    "HuberLossModule_sum_reduction_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
@@ -4682,6 +4686,10 @@
     "NllLossModule_ignore_index_out_of_bounds_basic",
     "NllLossModule_mean_basic",
     "NllLossModule_sum_basic",
+    "HuberLossModule_default_basic",
+    "HuberLossModule_reduction_is_none_basic",
+    "HuberLossModule_mean_reduction_basic",
+    "HuberLossModule_sum_reduction_basic",
     "NormScalarComplexModule_basic",
     "NormScalarModule_basic",
     "NormScalarOptDimKeepDimComplexModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6af2292dea57..cdf429e2e578 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2185,6 +2185,14 @@ def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1
     else:
         assert False, "Invalid reduction value."
 
+def aten〇huber_loss〡shape(self: List[int], target: List[int], reduction: int = 1, delta: float = 1.) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
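Note: this Python shape function is the source from which the serialized MLIR above is generated: reduction 0 keeps the input shape, reductions 1 and 2 collapse to a rank-0 tensor, and anything else raises. Eager PyTorch shows the same behavior, which is a quick way to sanity-check it:

```python
import torch

x, y = torch.rand(3, 5, 2), torch.rand(3, 5, 2)
# Positional args: (self, target, reduction, delta).
print(torch.ops.aten.huber_loss(x, y, 0, 1.0).shape)  # torch.Size([3, 5, 2])
print(torch.ops.aten.huber_loss(x, y, 1, 1.0).shape)  # torch.Size([]), mean
print(torch.ops.aten.huber_loss(x, y, 2, 1.0).shape)  # torch.Size([]), sum
```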
@@ -4571,6 +4579,14 @@ def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: T
     promoted_dtype = promote_dtypes(ranks, dtypes)
     return promoted_dtype
 
+def aten〇huber_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, delta: float = 1.) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..84a17d64422f 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -765,6 +765,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
     )
     emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
+    emit("aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
index bd82cd1c11b0..6f075136cce4 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py
@@ -63,3 +63,4 @@ def register_all_tests():
     from . import meshgrid
     from . import timeout
     from . import kl_div_loss
+    from . import huber_loss
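Note: the e2e framework compares each module against eager PyTorch, so `torch.nn.functional.huber_loss` is effectively the reference for the four test modules below. A hedged sketch of the int-to-string reduction mapping that makes the two entry points agree:

```python
import torch
import torch.nn.functional as F

x, y = torch.rand(3, 5, 2), torch.rand(3, 5, 2)
# aten reduction ints map to F.huber_loss strings:
# 0 -> "none", 1 -> "mean", 2 -> "sum".
torch.testing.assert_close(
    torch.ops.aten.huber_loss(x, y, 2, 2.3),
    F.huber_loss(x, y, reduction="sum", delta=2.3),
)
```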
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/huber_loss.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/huber_loss.py
new file mode 100644
index 000000000000..664adda3b0d2
--- /dev/null
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/huber_loss.py
@@ -0,0 +1,108 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class HuberLossModule_default(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_default())
+def HuberLossModule_default_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_reduction_is_none(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, delta=2.3, reduction=0)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_reduction_is_none())
+def HuberLossModule_reduction_is_none_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_mean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_mean_reduction())
+def HuberLossModule_mean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_sum_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_sum_reduction())
+def HuberLossModule_sum_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
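Note: the four configurations exercised above can be sanity-checked in plain PyTorch, without the torch-mlir pipeline, before running the e2e suite (a local sketch, not part of the patch):

```python
import torch

x, y = torch.rand(3, 5, 2), torch.rand(3, 5, 2)
# Mirrors the four test modules: default, none (custom delta), mean, sum.
for kwargs in ({}, {"reduction": 0, "delta": 2.3}, {"reduction": 1}, {"reduction": 2}):
    out = torch.ops.aten.huber_loss(x, y, **kwargs)
    print(kwargs, tuple(out.shape))
```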