@@ -851,7 +851,7 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
851
851
// -----
852
852
853
853
854
- // CHECK-LABEL: func @pixel_unshuffle
854
+ // CHECK-LABEL: func @pixel_unshuffle_static
855
855
// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
856
856
// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
857
857
// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
@@ -865,8 +865,39 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
865
865
// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
866
866
// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
867
867
// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[1,32,2,2],f32>
868
- func.func @pixel_unshuffle (%arg0: !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >) -> !torch.vtensor <[1 ,32 ,2 ,2 ],f32 > attributes {torch.assume_strict_symbolic_shapes } {
868
+ func.func @pixel_unshuffle_static (%arg0: !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >) -> !torch.vtensor <[1 ,32 ,2 ,2 ],f32 > attributes {torch.assume_strict_symbolic_shapes } {
869
869
%int2 = torch.constant.int 2
870
870
%0 = torch.aten.pixel_unshuffle %arg0 , %int2 : !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >, !torch.int -> !torch.vtensor <[1 ,32 ,2 ,2 ],f32 >
871
871
return %0 : !torch.vtensor <[1 ,32 ,2 ,2 ],f32 >
872
872
}
873
+
874
+
875
+ // -----
876
+
877
+
878
+ // CHECK-LABEL: func @pixel_unshuffle_fulldynamic
879
+ // CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
880
+ // CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
881
+ // CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
882
+ // CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
883
+ // CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
884
+ // CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
885
+ // CHECK: %[[INC:.*]] = torch.aten.size.int %[[ARG0]], %[[C1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
886
+ // CHECK: %[[INH:.*]] = torch.aten.size.int %[[ARG0]], %[[C2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
887
+ // CHECK: %[[INW:.*]] = torch.aten.size.int %[[ARG0]], %[[C3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
888
+ // CHECK: %[[OUTC:.*]] = torch.aten.mul.int %[[INC]], %[[C4]] : !torch.int, !torch.int -> !torch.int
889
+ // CHECK: %[[OUTH:.*]] = torch.aten.floordiv.int %[[INH]], %[[C2]] : !torch.int, !torch.int -> !torch.int
890
+ // CHECK: %[[OUTW:.*]] = torch.aten.floordiv.int %[[INW]], %[[C2]] : !torch.int, !torch.int -> !torch.int
891
+ // CHECK: %[[SIZE0:.*]] = torch.aten.size.int %[[ARG0]], %[[C0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int
892
+ // CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
893
+ // CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[OUTH]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?],f32>
894
+ // CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[OUTW]] : !torch.vtensor<[?,?,?,2,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,2,?,2],f32>
895
+ // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[?,?,?,2,?,2],f32>, !torch.list<int> -> !torch.vtensor<[?,?,2,2,?,?],f32>
896
+ // CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[?,?,2,2,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,4,?,?],f32>
897
+ // CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[?,?,4,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>
898
+ // CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[?,?,?,?],f32>
899
+ func.func @pixel_unshuffle_fulldynamic (%arg0: !torch.vtensor <[?,?,?,?],f32 >) -> !torch.vtensor <[?,?,?,?],f32 > attributes {torch.assume_strict_symbolic_shapes } {
900
+ %int2 = torch.constant.int 2
901
+ %0 = torch.aten.pixel_unshuffle %arg0 , %int2 : !torch.vtensor <[?,?,?,?],f32 >, !torch.int -> !torch.vtensor <[?,?,?,?],f32 >
902
+ return %0 : !torch.vtensor <[?,?,?,?],f32 >
903
+ }
0 commit comments