1
+ import math
1
2
import random
2
3
from collections import Counter , OrderedDict
3
4
from typing import Any , Callable , Dict , List , Optional , Tuple
17
18
ToExecutorch ,
18
19
)
19
20
from executorch .exir .dim_order_utils import get_memory_format
21
+ from torch .ao .ns .fx .utils import compute_sqnr
20
22
21
23
from torch .export import ExportedProgram
22
24
from torch .testing import FileCheck
@@ -302,13 +304,13 @@ def run_method_and_compare_outputs(
302
304
atol = 1e-03 ,
303
305
rtol = 1e-03 ,
304
306
qtol = 0 ,
307
+ snr : float | None = None ,
305
308
):
306
309
number_of_runs = 1 if inputs is not None else num_runs
307
310
reference_stage = self .stages [StageType .EXPORT ]
308
311
309
312
stage = stage or self .cur
310
313
311
- print (f"Comparing Stage { stage } with Stage { reference_stage } " )
312
314
for run_iteration in range (number_of_runs ):
313
315
inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
314
316
input_shapes = [
@@ -328,13 +330,21 @@ def run_method_and_compare_outputs(
328
330
# Output from running artifact at stage
329
331
stage_output = self .stages [stage ].run_artifact (inputs_to_run )
330
332
self ._compare_outputs (
331
- reference_output , stage_output , quantization_scale , atol , rtol , qtol
333
+ reference_output ,
334
+ stage_output ,
335
+ quantization_scale ,
336
+ atol ,
337
+ rtol ,
338
+ qtol ,
339
+ snr ,
332
340
)
333
341
334
342
return self
335
343
336
344
@staticmethod
337
- def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
345
+ def _assert_outputs_equal (
346
+ model_output , ref_output , atol = 1e-03 , rtol = 1e-03 , snr : float | None = None
347
+ ):
338
348
"""
339
349
Helper testing function that asserts that the model output and the reference output
340
350
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -359,15 +369,22 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
359
369
f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
360
370
)
361
371
else :
362
- assert torch .allclose (
363
- model ,
364
- ref ,
365
- atol = atol ,
366
- rtol = rtol ,
367
- equal_nan = True ,
372
+ computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
373
+ snr = snr or float ("-inf" )
374
+
375
+ assert (
376
+ torch .allclose (
377
+ model ,
378
+ ref ,
379
+ atol = atol ,
380
+ rtol = rtol ,
381
+ equal_nan = True ,
382
+ )
383
+ and computed_snr >= snr
384
+ or math .isnan (computed_snr )
368
385
), (
369
386
f"Output { i } does not match reference output.\n "
370
- f"\t Given atol: { atol } , rtol: { rtol } .\n "
387
+ f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
371
388
f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
372
389
f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
373
390
f"\t -- Model vs. Reference --\n "
@@ -376,6 +393,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
376
393
f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
377
394
f"\t Max: { model .max ()} , { ref .max ()} \n "
378
395
f"\t Min: { model .min ()} , { ref .min ()} \n "
396
+ f"\t SNR: { computed_snr } \n "
379
397
)
380
398
381
399
@staticmethod
@@ -386,6 +404,7 @@ def _compare_outputs(
386
404
atol = 1e-03 ,
387
405
rtol = 1e-03 ,
388
406
qtol = 0 ,
407
+ snr : float | None = None ,
389
408
):
390
409
"""
391
410
Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +427,7 @@ def _compare_outputs(
408
427
reference_output ,
409
428
atol = atol ,
410
429
rtol = rtol ,
430
+ snr = snr ,
411
431
)
412
432
413
433
@staticmethod
0 commit comments