1
- import random
2
1
from collections import Counter , OrderedDict
2
+ from torch .ao .ns .fx .utils import compute_sqnr
3
3
from typing import Any , Callable , Dict , List , Optional , Tuple
4
4
5
+ import math
6
+ import random
5
7
import torch
6
8
7
9
from executorch .backends .test .harness .stages import (
@@ -302,17 +304,18 @@ def run_method_and_compare_outputs(
302
304
atol = 1e-03 ,
303
305
rtol = 1e-03 ,
304
306
qtol = 0 ,
307
+ snr : float | None = None ,
305
308
):
306
309
number_of_runs = 1 if inputs is not None else num_runs
307
310
reference_stage = self .stages [StageType .EXPORT ]
308
311
309
312
stage = stage or self .cur
310
313
311
- print (f"Comparing Stage { stage } with Stage { reference_stage } " )
314
+ # print(f"Comparing Stage {stage} with Stage {reference_stage}")
312
315
for run_iteration in range (number_of_runs ):
313
316
inputs_to_run = inputs if inputs else next (self .generate_random_inputs ())
314
317
input_shapes = [generated_input .shape for generated_input in inputs_to_run ]
315
- print (f"Run { run_iteration } with input shapes: { input_shapes } " )
318
+ # print(f"Run {run_iteration} with input shapes: {input_shapes}")
316
319
317
320
# Reference output (and quantization scale)
318
321
(
@@ -325,13 +328,13 @@ def run_method_and_compare_outputs(
325
328
# Output from running artifact at stage
326
329
stage_output = self .stages [stage ].run_artifact (inputs_to_run )
327
330
self ._compare_outputs (
328
- reference_output , stage_output , quantization_scale , atol , rtol , qtol
331
+ reference_output , stage_output , quantization_scale , atol , rtol , qtol , snr
329
332
)
330
333
331
334
return self
332
335
333
336
@staticmethod
334
- def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 ):
337
+ def _assert_outputs_equal (model_output , ref_output , atol = 1e-03 , rtol = 1e-03 , snr : float | None = None ):
335
338
"""
336
339
Helper testing function that asserts that the model output and the reference output
337
340
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -356,15 +359,18 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
356
359
f"\t Mismatched count: { (model != ref ).sum ().item ()} / { model .numel ()} \n "
357
360
)
358
361
else :
362
+ computed_snr = compute_sqnr (model .to (torch .float ), ref .to (torch .float ))
363
+ snr = snr or float ("-inf" )
364
+
359
365
assert torch .allclose (
360
366
model ,
361
367
ref ,
362
368
atol = atol ,
363
369
rtol = rtol ,
364
370
equal_nan = True ,
365
- ), (
371
+ ) and computed_snr >= snr or math . isnan ( computed_snr ) , (
366
372
f"Output { i } does not match reference output.\n "
367
- f"\t Given atol: { atol } , rtol: { rtol } .\n "
373
+ f"\t Given atol: { atol } , rtol: { rtol } , snr: { snr } .\n "
368
374
f"\t Output tensor shape: { model .shape } , dtype: { model .dtype } \n "
369
375
f"\t Difference: max: { torch .max (model - ref )} , abs: { torch .max (torch .abs (model - ref ))} , mean abs error: { torch .mean (torch .abs (model - ref ).to (torch .double ))} .\n "
370
376
f"\t -- Model vs. Reference --\n "
@@ -373,8 +379,10 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
373
379
f"\t Mean: { model .to (torch .double ).mean ()} , { ref .to (torch .double ).mean ()} \n "
374
380
f"\t Max: { model .max ()} , { ref .max ()} \n "
375
381
f"\t Min: { model .min ()} , { ref .min ()} \n "
382
+ f"\t SNR: { computed_snr } \n "
376
383
)
377
384
385
+
378
386
@staticmethod
379
387
def _compare_outputs (
380
388
reference_output ,
@@ -383,6 +391,7 @@ def _compare_outputs(
383
391
atol = 1e-03 ,
384
392
rtol = 1e-03 ,
385
393
qtol = 0 ,
394
+ snr : float | None = None ,
386
395
):
387
396
"""
388
397
Compares the original of the original nn module with the output of the generated artifact.
@@ -405,6 +414,7 @@ def _compare_outputs(
405
414
reference_output ,
406
415
atol = atol ,
407
416
rtol = rtol ,
417
+ snr = snr ,
408
418
)
409
419
410
420
@staticmethod
0 commit comments