Commit 8f0b483
Add huber loss function & decompose to primitive ops
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 8cc313f commit 8f0b483

File tree

8 files changed: +272 -0 lines changed
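
For reference, aten::huber_loss computes, per element, 0.5 * (x - y)^2 when |x - y| <= delta, and delta * (|x - y| - 0.5 * delta) otherwise; reduction is an int (0 = none, 1 = mean, the default, 2 = sum) and delta defaults to 1.0. The changes below thread the op through registration, shape/dtype inference, and decomposition, then add e2e tests.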

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 26 additions & 0 deletions
@@ -9532,6 +9532,32 @@ def Torch_AtenKlDivOp : Torch_Op<"aten.kl_div", [
   }];
 }
 
+def Torch_AtenHuberLossOp : Torch_Op<"aten.huber_loss", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction,
+    Torch_FloatType:$delta
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHuberLossOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenHuberLossOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenBincountOp : Torch_Op<"aten.bincount", [
     AllowsTypeRefinement,
     HasValueSemantics,

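Since the generated op mirrors the JIT schema aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor), the raw op can be exercised directly from Python; a minimal sketch (not part of the commit):

import torch

x, y = torch.rand(2, 3), torch.rand(2, 3)
# Positional arguments follow the schema: (Tensor self, Tensor target, int reduction, float delta).
out_default = torch.ops.aten.huber_loss(x, y)       # reduction=1 (mean), delta=1.0
out_none = torch.ops.aten.huber_loss(x, y, 0, 2.3)  # elementwise loss, delta=2.3
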
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 33 additions & 0 deletions
@@ -10717,6 +10717,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.huber_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: Invalid reduction value.\"\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int2 = torch.constant.int 2\n"
+"    %0 = torch.prim.Uninitialized : !torch.list<int>\n"
+"    %1 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
+"      %3 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %3 : !torch.list<int>\n"
+"    } else {\n"
+"      %3 = torch.prim.ListConstruct %int1, %int2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"      %4 = torch.aten.__contains__.int_list %3, %arg2 : !torch.list<int>, !torch.int -> !torch.bool\n"
+"      %5 = torch.prim.If %4 -> (!torch.list<int>) {\n"
+"        %6 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"        torch.prim.If.yield %6 : !torch.list<int>\n"
+"      } else {\n"
+"        torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"        torch.prim.If.yield %0 : !torch.list<int>\n"
+"      }\n"
+"      torch.prim.If.yield %5 : !torch.list<int>\n"
+"    }\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -14608,6 +14633,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.huber_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.float) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mse_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

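The serialized shape function encodes the same rule as the Python definition further down: reduction 0 keeps the input shape, while 1 and 2 reduce to a rank-0 result. This is observable directly in eager PyTorch (a quick sketch, not part of the commit):

import torch

x, y = torch.rand(3, 5, 2), torch.rand(3, 5, 2)
print(torch.ops.aten.huber_loss(x, y, 0, 1.0).shape)  # torch.Size([3, 5, 2]) -- 'none'
print(torch.ops.aten.huber_loss(x, y, 1, 1.0).shape)  # torch.Size([])        -- 'mean'
print(torch.ops.aten.huber_loss(x, y, 2, 1.0).shape)  # torch.Size([])        -- 'sum'
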
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 86 additions & 0 deletions
@@ -10707,6 +10707,91 @@ class DecomposeAtenKlDivOp : public OpRewritePattern<AtenKlDivOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenHuberLossOp : public OpRewritePattern<AtenHuberLossOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHuberLossOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value target = op.getTarget();
+    Value reductionValue = op.getReduction();
+    Value deltaValue = op.getDelta();
+
+    auto selfTy = cast<ValueTensorType>(self.getType());
+    auto targetTy = cast<ValueTensorType>(target.getType());
+    auto outTy = cast<ValueTensorType>(op.getType());
+    if (!selfTy.hasSizes() || !targetTy.hasSizes() || !outTy.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having sizes!");
+    }
+    if (!selfTy.hasDtype() || !targetTy.hasDtype() || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "require self, target and output having dtype!");
+    }
+
+    // Extract delta float value from delta argument
+    // double delta;
+    // if (!matchPattern(deltaValue, m_TorchConstantFloat(&delta))) {
+    //   return rewriter.notifyMatchFailure(op,
+    //                                      "delta should be a constant float!");
+    // }
+
+    // Squared term: 0.5 * (input - target)^2
+    Value constOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value constHalf =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.5));
+    Value inputMinusTarget =
+        rewriter.create<AtenSubTensorOp>(loc, selfTy, self, target, constOne);
+    Value squaredValue =
+        rewriter.create<AtenSquareOp>(loc, selfTy, inputMinusTarget);
+    Value squaredTerm =
+        rewriter.create<AtenMulScalarOp>(loc, selfTy, squaredValue, constHalf);
+
+    // Delta-scaled term: delta * (|input - target| - 0.5 * delta)
+    Value absDiffValue =
+        rewriter.create<AtenAbsOp>(loc, selfTy, inputMinusTarget);
+    Value halfOfDelta = rewriter.create<AtenMulOp>(
+        loc, rewriter.getType<Torch::FloatType>(), constHalf, deltaValue);
+    Value absDiffMinusDeltaHalf = rewriter.create<AtenSubScalarOp>(
+        loc, selfTy, absDiffValue, halfOfDelta, constOne);
+    Value deltaScaledTerm = rewriter.create<AtenMulScalarOp>(
+        loc, selfTy, absDiffMinusDeltaHalf, deltaValue);
+
+    // Select per element based on the condition: |input - target| <= delta
+    ValueTensorType boolTy = ValueTensorType::get(
+        op.getContext(), selfTy.getSizes(), rewriter.getI1Type());
+    Value cmpValue =
+        rewriter.create<AtenLeScalarOp>(loc, boolTy, absDiffValue, deltaValue);
+    Value lossPointwise = rewriter.create<AtenWhereSelfOp>(
+        loc, selfTy, cmpValue, squaredTerm, deltaScaledTerm);
+
+    // Extract reduction int value from reduction argument
+    int64_t reduction;
+    if (!matchPattern(reductionValue, m_TorchConstantInt(&reduction))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "reduction should be a constant int!");
+    }
+    Value loss;
+    Value none = rewriter.create<ConstantNoneOp>(loc);
+    if (reduction == 1) {
+      // reduction: mean
+      loss = rewriter.create<AtenMeanOp>(loc, outTy, lossPointwise, none);
+    } else if (reduction == 2) {
+      // reduction: sum
+      loss = rewriter.create<AtenSumOp>(loc, outTy, lossPointwise, none);
+    } else {
+      // reduction: none
+      loss = lossPointwise;
+    }
+    rewriter.replaceOp(op, loss);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenBinaryCrossEntropyWithLogitsOp
     : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12625,6 +12710,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenKlDivOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenHuberLossOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenArgsortOp>(patterns);

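The pattern expands the op into subtract/square/abs/mul/where plus an optional mean or sum, and only reduction (not delta) needs to be a compile-time constant. The same recipe can be cross-checked numerically against eager PyTorch (a sketch assuming torch >= 1.9, where torch.nn.functional.huber_loss is available; the helper name is made up for illustration):

import torch
import torch.nn.functional as F

def huber_decomposed(x, y, reduction=1, delta=1.0):
    # Mirrors the primitive-op sequence emitted by DecomposeAtenHuberLossOp.
    d = x - y
    quad = 0.5 * d.square()                # 0.5 * (x - y)^2
    lin = delta * (d.abs() - 0.5 * delta)  # delta * (|x - y| - 0.5 * delta)
    loss = torch.where(d.abs() <= delta, quad, lin)
    if reduction == 1:
        return loss.mean()
    if reduction == 2:
        return loss.sum()
    return loss

x, y = torch.rand(3, 5), torch.rand(3, 5)
assert torch.allclose(huber_decomposed(x, y), F.huber_loss(x, y))
assert torch.allclose(huber_decomposed(x, y, 2), F.huber_loss(x, y, reduction="sum"))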
lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -588,6 +588,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenLogaddexpOp>();
   target.addIllegalOp<AtenLogaddexp2Op>();
   target.addIllegalOp<AtenKlDivOp>();
+  target.addIllegalOp<AtenHuberLossOp>();
 
   for (auto &opName : backendLegalOpsSet) {
     target.addLegalOp(

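Marking aten.huber_loss illegal here is what triggers the pattern above: LowerToBackendContract keeps re-running the decomposition pipeline until no illegal ops remain, while any op named in backend-legal-ops is re-legalized in the loop below and passed through to the backend untouched.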
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 16 additions & 0 deletions
@@ -2182,6 +2182,14 @@ def aten〇kl_div〡shape(self: List[int], target: List[int], reduction: int = 1
     else:
         assert False, "Invalid reduction value."
 
+def aten〇huber_loss〡shape(self: List[int], target: List[int], reduction: int = 1, delta: float = 1.) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    elif reduction in [1, 2]:
+        return []
+    else:
+        assert False, "Invalid reduction value."
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
@@ -4568,6 +4576,14 @@ def aten〇kl_div〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: T
     promoted_dtype = promote_dtypes(ranks, dtypes)
     return promoted_dtype
 
+def aten〇huber_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1, delta: float = 1.) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op(
     output_error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
 def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:

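These Python definitions are the source of truth for the shape/dtype library; the AbstractInterpLibrary.cpp hunks above are their serialized form and are regenerated rather than hand-edited (in the usual torch-mlir workflow, via build_tools/update_abstract_interp_lib.sh), so the two files must change together.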
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -765,6 +765,7 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::poisson_nll_loss : (Tensor, Tensor, bool, bool, float, int) -> (Tensor)"
     )
     emit("aten::kl_div : (Tensor, Tensor, int, bool) -> (Tensor)")
+    emit("aten::huber_loss : (Tensor, Tensor, int, float) -> (Tensor)")
     emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")

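Likewise, this emit() line is what produces the Torch_AtenHuberLossOp definition in GeneratedTorchOps.td shown earlier (regenerated via build_tools/update_torch_ods.sh in the usual workflow); the schema string must match the upstream aten::huber_loss signature exactly.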
projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -63,3 +63,4 @@ def register_all_tests():
     from . import meshgrid
     from . import timeout
     from . import kl_div_loss
+    from . import huber_loss

projects/pt1/python/torch_mlir_e2e_test/test_suite/huber_loss.py

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.framework import TestUtils
+from torch_mlir_e2e_test.registry import register_test_case
+from torch_mlir_e2e_test.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class HuberLossModule_default(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_default())
+def HuberLossModule_default_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_reduction_is_none(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, delta=2.3, reduction=0)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_reduction_is_none())
+def HuberLossModule_reduction_is_none_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_mean_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_mean_reduction())
+def HuberLossModule_mean_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================
+
+
+class HuberLossModule_sum_reduction(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.huber_loss(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: HuberLossModule_sum_reduction())
+def HuberLossModule_sum_reduction_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5, 2), tu.rand(3, 5, 2))
+
+
+# ==============================================================================

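With the module registered in test_suite/__init__.py above, the four HuberLossModule_* cases cover the default call and each explicit reduction mode. Assuming the usual torch-mlir e2e workflow, something like projects/pt1/tools/e2e_test.sh --filter HuberLoss would run just these tests against each configured backend.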