 from executorch.exir.passes import dead_code_elimination_pass
 
 from parameterized.parameterized import parameterized
-from torch._ops import OpOverload
 from torch.fx.passes.infra.pass_base import PassResult
 
 
@@ -87,36 +86,46 @@ def assertTargetCountsEqual(
 
     @parameterized.expand(
         [
-            # Regular MM
-            [(64, 33), (33, 128)],
-            # Batched MM
-            [(2, 48, 48), (2, 48, 48)],
-        ]
+            (
+                "regular",
+                (64, 33),  # x_shape
+                (33, 128),  # y_shape
+            ),
+            (
+                "batched",
+                (2, 48, 48),  # x_shape
+                (2, 48, 48),  # y_shape
+            ),
+        ],
     )
     @torch.no_grad()
     def test_replace_matmul_with_transposed_matmul(
         self,
+        _,
         x_shape: Tuple[int],
         y_shape: Tuple[int],
     ) -> None:
-        class MatMul(torch.nn.Module):
-            def __init__(self) -> None:
-                super(MatMul, self).__init__()
-
-            def forward(self, x, y):
-                return torch.matmul(x, y)
-
-        model = MatMul()
-        X = torch.randn(x_shape)
-        Y = torch.randn(y_shape)
-        p = ReplaceMatmulWithTransposedMatmulPass()
-        inputs = (X, Y)
-        graph_module = (
-            quantize_and_export_to_edge(model, inputs).exported_program().graph_module
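+        # Build the edge graph directly with GraphBuilder instead of exporting
+        # an nn.Module: declare placeholders, add one operator node, set outputs.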
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32))
+        matmul = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_matmul.default,
+            args=(
+                x,
+                0,  # X_zero_point
+                y,
+                0,  # Y_zero_point
+                None,  # bias
+                1,  # out_multiplier
+                0,  # out_shift
+                0,  # out_zero_point
+                False,  # transposed=False
+            ),
         )
-        # pyre-fixme[16]: Optional type has no attribute `graph_module`
-        graph_after_passes = p(graph_module).graph_module
-
+        builder.output([matmul])
+        original = builder.get_graph_module()
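+        # With transposed=False in the node above, the pass should rewrite it to
+        # the transposed form and insert a transpose_copy on the second operand.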
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
             1,
@@ -130,33 +139,24 @@ def forward(self, x, y):
 
     @parameterized.expand(
         [
-            [(3, 5), (0, 0)],
-            [
-                (20, 1, 80),
-                (0, 0),
-            ],
-        ]
+            ("2d", (3, 5), [0, 0]),  # shape, padding
+            ("3d", (20, 1, 80), [0, 0, 0]),  # shape, padding
+        ],
     )
     @torch.no_grad()
     def test_replace_constant_pad_nd_with_slice(
-        self, shape: Tuple[int], padding: Tuple[int]
+        self, _, shape: Tuple[int], padding: Tuple[int]
     ):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        pad = builder.call_operator(
+            op=exir_ops.edge.aten.constant_pad_nd.default,
+            args=(x, [0, 0, 0, 0]),
+        )
+        builder.output([pad])
+        original = builder.get_graph_module()
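+        # The padding is all zeros, so the pass can replace the no-op pad with a
+        # single slice covering the whole tensor.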
         p = ReplaceConstantPadNdWithSlicePass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
             1,
@@ -169,142 +169,140 @@ def forward(self, x: torch.Tensor):
 
     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Add(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.add.Scalar(x, other)
-
-        model = Add()
+    def test_add_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.add.Scalar,
+            args=(x, other),
+        )
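+        # single_op_builder wraps the lone add.Scalar node in a graph module; the
+        # pass should replace it with add.Tensor, promoting `other` to a tensor.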
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
             0,
         )
 
     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(10), 42949],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Sub(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.sub.Scalar(x, other)
-
-        model = Sub()
+    def test_sub_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.sub.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
             0,
         )
 
     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(513), 3],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
-    def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Mul(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.mul.Scalar(x, other)
-
-        model = Mul()
+    def test_mul_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.mul.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
             0,
         )
 
     @parameterized.expand(
         [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
         ]
     )
     @torch.no_grad()
     def test_div_replace_scalar_with_tensor_arg(
         self,
+        _,
         shape: Tuple[int],
         other: float,
     ):
-        class Div(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.div.Scalar(x, other)
-
-        model = Div()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        x = torch.randn(*shape)
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.div.Scalar,
+            args=(x, other),
+        )
         p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
             0,
         )
 
     @parameterized.expand(
         [
-            [(2, 3, 5, 6)],
-            [(7, 6, 5)],
-            [(4, 4)],
-            [(316)],
+            ["4d", (2, 3, 5, 6)],
+            ["3d", (7, 6, 5)],
+            ["2d", (4, 4)],
+            ["1d", (316,)],
         ]
     )
     @torch.no_grad()
-    def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
-        model = torch.nn.ReLU()
+    def test_replace_functionally_equivalent_op_targets_relu(
+        self, _, shape: Tuple[int]
+    ):
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.relu_.default,
+            args=(x,),
+        )
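+        # In-place relu_ is functionally equivalent to relu; the pass should
+        # swap in the out-of-place target.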
         p = ReplaceFunctionallyEquivalentOpTargets()
+        graph_after_passes = cast(PassResult, p(original)).graph_module
 
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
             1,
@@ -315,56 +313,29 @@ def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]
         )
 
     @parameterized.expand(
-        [
-            # split the only dimension
-            [(50,), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the leading dim
-            [(10, 2, 3), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the trailing dim
-            [(3, 3, 6), i, 2]
-            for i in range(2, 6)
-        ]
-        + [
-            # split the dim in the middle
-            [(3, 5, 14, 2, 3), i, 2]
-            for i in range(2, 7)
-        ]
+        [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)]
+        + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)]
+        + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)]
+        + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)]
     )
     @torch.no_grad()
     def test_replace_functionally_equivalent_op_targets_unsafe_split(
-        self, shape: Tuple[int], split_size: int, dim: int
+        self, _, shape: Tuple[int], split_size: int, dim: int
     ):
-        class TensorSplitWithSizes(torch.nn.Module):
-            def __init__(self, split_size: int, dim: int, op: OpOverload):
-                super().__init__()
-                self.split_size = split_size
-                self.dim = dim
-                self.op = op
-
-            def forward(self, x: torch.Tensor):
-                return self.op(x, self.split_size, self.dim)
-
         x = torch.randn(shape)
-        model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.unsafe_split.Tensor,
+            args=(x, split_size, dim),
+        )
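+        # unsafe_split differs from split only in aliasing guarantees, so the
+        # pass should replace it with the equivalent split_copy.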
         p = ReplaceFunctionallyEquivalentOpTargets()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
-            count_node(
-                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
-            ),
+            count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor),
             1,
         )
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
-            0,
+            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0
         )
 
     @parameterized.expand(