diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index 7019b734290..4bba6fdb819 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -1,3 +1,4 @@ +import math import random from collections import Counter, OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple @@ -17,6 +18,7 @@ ToExecutorch, ) from executorch.exir.dim_order_utils import get_memory_format +from torch.ao.ns.fx.utils import compute_sqnr from torch.export import ExportedProgram from torch.testing import FileCheck @@ -302,13 +304,13 @@ def run_method_and_compare_outputs( atol=1e-03, rtol=1e-03, qtol=0, + snr: float | None = None, ): number_of_runs = 1 if inputs is not None else num_runs reference_stage = self.stages[StageType.EXPORT] stage = stage or self.cur - print(f"Comparing Stage {stage} with Stage {reference_stage}") for run_iteration in range(number_of_runs): inputs_to_run = inputs if inputs else next(self.generate_random_inputs()) input_shapes = [ @@ -328,13 +330,21 @@ def run_method_and_compare_outputs( # Output from running artifact at stage stage_output = self.stages[stage].run_artifact(inputs_to_run) self._compare_outputs( - reference_output, stage_output, quantization_scale, atol, rtol, qtol + reference_output, + stage_output, + quantization_scale, + atol, + rtol, + qtol, + snr, ) return self @staticmethod - def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): + def _assert_outputs_equal( + model_output, ref_output, atol=1e-03, rtol=1e-03, snr: float | None = None + ): """ Helper testing function that asserts that the model output and the reference output are equal with some tolerance. Due to numerical differences between eager mode and @@ -359,15 +369,22 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" ) else: - assert torch.allclose( - model, - ref, - atol=atol, - rtol=rtol, - equal_nan=True, + computed_snr = compute_sqnr(model.to(torch.float), ref.to(torch.float)) + snr = snr or float("-inf") + + assert ( + torch.allclose( + model, + ref, + atol=atol, + rtol=rtol, + equal_nan=True, + ) + and computed_snr >= snr + or math.isnan(computed_snr) ), ( f"Output {i} does not match reference output.\n" - f"\tGiven atol: {atol}, rtol: {rtol}.\n" + f"\tGiven atol: {atol}, rtol: {rtol}, snr: {snr}.\n" f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n" f"\t-- Model vs. Reference --\n" @@ -376,6 +393,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): f"\t Mean: {model.to(torch.double).mean()}, {ref.to(torch.double).mean()}\n" f"\t Max: {model.max()}, {ref.max()}\n" f"\t Min: {model.min()}, {ref.min()}\n" + f"\t SNR: {computed_snr}\n" ) @staticmethod @@ -386,6 +404,7 @@ def _compare_outputs( atol=1e-03, rtol=1e-03, qtol=0, + snr: float | None = None, ): """ Compares the original of the original nn module with the output of the generated artifact. @@ -408,6 +427,7 @@ def _compare_outputs( reference_output, atol=atol, rtol=rtol, + snr=snr, ) @staticmethod diff --git a/backends/test/suite/operators/test_abs.py b/backends/test/suite/operators/test_abs.py index fdfc6be671e..0f015e9ef68 100644 --- a/backends/test/suite/operators/test_abs.py +++ b/backends/test/suite/operators/test_abs.py @@ -49,9 +49,9 @@ def test_abs_edge_cases(self, flow: TestFlow) -> None: # Test edge cases # Tensor with infinity - x = torch.tensor([float("inf"), float("-inf"), 1.0, -1.0]) + x = (torch.tensor([float("inf"), float("-inf"), 1.0, -1.0]),) self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False) # Tensor with NaN - x = torch.tensor([float("nan"), 1.0, -1.0]) + x = (torch.tensor([float("nan"), 1.0, -1.0]),) self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False) diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index dd6e3586628..92d75b85ece 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -104,7 +104,10 @@ def build_result( # AssertionErrors to catch output mismatches, but this might catch more than that. try: tester.run_method_and_compare_outputs( - inputs=None if generate_random_test_inputs else inputs + inputs=None if generate_random_test_inputs else inputs, + atol=5e-2, + rtol=5e-2, + snr=40, ) except AssertionError as e: return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)