Skip to content

Commit 82215f3

Browse files
committed
add pixel_unshuffle_fulldynamic test
1 parent 111dfda commit 82215f3

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

test/Dialect/Torch/decompose-complex-ops.mlir

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
851851
// -----
852852

853853

854-
// CHECK-LABEL: func @pixel_unshuffle
854+
// CHECK-LABEL: func @pixel_unshuffle_static
855855
// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
856856
// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
857857
// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
@@ -865,8 +865,39 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
865865
// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
866866
// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
867867
// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[1,32,2,2],f32>
868-
func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
868+
func.func @pixel_unshuffle_static(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
869869
%int2 = torch.constant.int 2
870870
%0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
871871
return %0 : !torch.vtensor<[1,32,2,2],f32>
872872
}
873+
874+
875+
// -----
876+
877+
878+
// CHECK-LABEL: func @pixel_unshuffle_fulldynamic
879+
// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
880+
// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
881+
// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
882+
// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
883+
// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
884+
// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
885+
// CHECK: %[[INC:.*]] = torch.aten.size.int %[[ARG0]], %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
886+
// CHECK: %[[INH:.*]] = torch.aten.size.int %[[ARG0]], %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
887+
// CHECK: %[[INW:.*]] = torch.aten.size.int %[[ARG0]], %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
888+
// CHECK: %[[OUTC:.*]] = torch.aten.mul.int %[[INC]], %[[C4]] : !torch.int, !torch.int -> !torch.int
889+
// CHECK: %[[OUTH:.*]] = torch.aten.floordiv.int %[[INH]], %[[C2]] : !torch.int, !torch.int -> !torch.int
890+
// CHECK: %[[OUTW:.*]] = torch.aten.floordiv.int %[[INW]], %[[C2]] : !torch.int, !torch.int -> !torch.int
891+
// CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
892+
// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
893+
// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[OUTH]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?],f32>
894+
// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[OUTW]] : !torch.vtensor<[?,?,?,2,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?,2],f32>
895+
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,?,2,2,?,?],f32>
896+
// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[?,?,2,2,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,4,?,?],f32>
897+
// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[?,?,4,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
898+
// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[?,?,?,?],f32>
899+
func.func @pixel_unshuffle_fulldynamic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.assume_strict_symbolic_shapes} {
900+
%int2 = torch.constant.int 2
901+
%0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
902+
return %0 : !torch.vtensor<[?,?,?,?],f32>
903+
}

0 commit comments

Comments
 (0)