Skip to content

Commit fb80dc1

Browse files
committed
Add MeanVarNorm lit tests
Signed-off-by: Zahid Wakeel <zahid.wakeel@multicorewareinc.com>
1 parent 75bbd1c commit fb80dc1

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,57 @@ func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch
15951595

15961596
// -----
15971597

1598+
// CHECK-LABEL: func.func @test_meanvarnorm(
1599+
func.func @test_meanvarnorm(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1600+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1601+
// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
1602+
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
1603+
// CHECK: %[[VAL_2:.*]] = torch.constant.none
1604+
// CHECK: %[[VAL_3:.*]] = torch.constant.int 0
1605+
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
1606+
// CHECK: %[[VAL_5:.*]] = torch.constant.int 3
1607+
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
1608+
// CHECK: %[[VAL_7:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_6]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,5,1,1],f32>
1609+
// CHECK: %[[VAL_8:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_6]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,5,1,1],f32>
1610+
// CHECK: %[[VAL_9:.*]] = torch.constant.int 1
1611+
// CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e-09
1612+
// CHECK: %[[VAL_11:.*]] = torch.aten.add.Scalar %[[VAL_8]], %[[VAL_10]], %[[VAL_9]] : !torch.vtensor<[1,5,1,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[1,5,1,1],f32>
1613+
// CHECK: %[[VAL_12:.*]] = torch.aten.sqrt %[[VAL_11]] : !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[1,5,1,1],f32>
1614+
// CHECK: %[[VAL_13:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_7]], %[[VAL_9]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
1615+
// CHECK: %[[VAL_14:.*]] = torch.aten.div.Tensor %[[VAL_13]], %[[VAL_12]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[1,5,1,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
1616+
// CHECK: return %[[VAL_14]] : !torch.vtensor<[3,5,2,2],f32>
1617+
// CHECK: }
1618+
%0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
1619+
return %0 : !torch.vtensor<[3,5,2,2],f32>
1620+
}
1621+
1622+
// -----
1623+
1624+
// CHECK-LABEL: func.func @test_meanvarnorm_axes(
1625+
func.func @test_meanvarnorm_axes(%arg0: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1626+
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
1627+
// CHECK: %[[VAL_0:.*]] = torch.constant.bool true
1628+
// CHECK: %[[VAL_1:.*]] = torch.constant.bool false
1629+
// CHECK: %[[VAL_2:.*]] = torch.constant.none
1630+
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
1631+
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
1632+
// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
1633+
// CHECK: %[[VAL_6:.*]] = torch.aten.mean.dim %[[ARG0]], %[[VAL_5]], %[[VAL_0]], %[[VAL_2]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2,1],f32>
1634+
// CHECK: %[[VAL_7:.*]] = torch.aten.var.dim %[[ARG0]], %[[VAL_5]], %[[VAL_1]], %[[VAL_0]] : !torch.vtensor<[3,5,2,2],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,1,2,1],f32>
1635+
// CHECK: %[[VAL_8:.*]] = torch.constant.int 1
1636+
// CHECK: %[[VAL_9:.*]] = torch.constant.float 1.000000e-09
1637+
// CHECK: %[[VAL_10:.*]] = torch.aten.add.Scalar %[[VAL_7]], %[[VAL_9]], %[[VAL_8]] : !torch.vtensor<[3,1,2,1],f32>, !torch.float, !torch.int -> !torch.vtensor<[3,1,2,1],f32>
1638+
// CHECK: %[[VAL_11:.*]] = torch.aten.sqrt %[[VAL_10]] : !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,1,2,1],f32>
1639+
// CHECK: %[[VAL_12:.*]] = torch.aten.sub.Tensor %[[ARG0]], %[[VAL_6]], %[[VAL_8]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32>, !torch.int -> !torch.vtensor<[3,5,2,2],f32>
1640+
// CHECK: %[[VAL_13:.*]] = torch.aten.div.Tensor %[[VAL_12]], %[[VAL_11]] : !torch.vtensor<[3,5,2,2],f32>, !torch.vtensor<[3,1,2,1],f32> -> !torch.vtensor<[3,5,2,2],f32>
1641+
// CHECK: return %[[VAL_13]] : !torch.vtensor<[3,5,2,2],f32>
1642+
// CHECK: }
1643+
%0 = torch.operator "onnx.MeanVarianceNormalization"(%arg0) {torch.onnx.axes = [1 : si64, 3 : si64]} : (!torch.vtensor<[3,5,2,2],f32>) -> !torch.vtensor<[3,5,2,2],f32>
1644+
return %0 : !torch.vtensor<[3,5,2,2],f32>
1645+
}
1646+
1647+
// -----
1648+
15981649
// CHECK-LABEL: func.func @test_not_2d
15991650
func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
16001651
// CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1>

0 commit comments

Comments
 (0)