Skip to content

Commit 930a00c

Browse files
committed
[Backend Tester] Add SNR validation
ghstack-source-id: c8c9103 ghstack-comment-id: 3129418402 Pull-Request: #12924
1 parent 47f085e commit 930a00c

File tree

3 files changed

+36
-13
lines changed

3 files changed

+36
-13
lines changed

backends/test/harness/tester.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import random
23
from collections import Counter, OrderedDict
34
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -17,6 +18,7 @@
1718
ToExecutorch,
1819
)
1920
from executorch.exir.dim_order_utils import get_memory_format
21+
from torch.ao.ns.fx.utils import compute_sqnr
2022

2123
from torch.export import ExportedProgram
2224
from torch.testing import FileCheck
@@ -302,13 +304,13 @@ def run_method_and_compare_outputs(
302304
atol=1e-03,
303305
rtol=1e-03,
304306
qtol=0,
307+
snr: float | None = None,
305308
):
306309
number_of_runs = 1 if inputs is not None else num_runs
307310
reference_stage = self.stages[StageType.EXPORT]
308311

309312
stage = stage or self.cur
310313

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312314
for run_iteration in range(number_of_runs):
313315
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314316
input_shapes = [
@@ -328,13 +330,21 @@ def run_method_and_compare_outputs(
328330
# Output from running artifact at stage
329331
stage_output = self.stages[stage].run_artifact(inputs_to_run)
330332
self._compare_outputs(
331-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
333+
reference_output,
334+
stage_output,
335+
quantization_scale,
336+
atol,
337+
rtol,
338+
qtol,
339+
snr,
332340
)
333341

334342
return self
335343

336344
@staticmethod
337-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
345+
def _assert_outputs_equal(
346+
model_output, ref_output, atol=1e-03, rtol=1e-03, snr: float | None = None
347+
):
338348
"""
339349
Helper testing function that asserts that the model output and the reference output
340350
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -359,15 +369,22 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
359369
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
360370
)
361371
else:
362-
assert torch.allclose(
363-
model,
364-
ref,
365-
atol=atol,
366-
rtol=rtol,
367-
equal_nan=True,
372+
computed_snr = compute_sqnr(model.to(torch.float), ref.to(torch.float))
373+
snr = snr or float("-inf")
374+
375+
assert (
376+
torch.allclose(
377+
model,
378+
ref,
379+
atol=atol,
380+
rtol=rtol,
381+
equal_nan=True,
382+
)
383+
and computed_snr >= snr
384+
or math.isnan(computed_snr)
368385
), (
369386
f"Output {i} does not match reference output.\n"
370-
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
387+
f"\tGiven atol: {atol}, rtol: {rtol}, snr: {snr}.\n"
371388
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
372389
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
373390
f"\t-- Model vs. Reference --\n"
@@ -376,6 +393,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
376393
f"\t Mean: {model.to(torch.double).mean()}, {ref.to(torch.double).mean()}\n"
377394
f"\t Max: {model.max()}, {ref.max()}\n"
378395
f"\t Min: {model.min()}, {ref.min()}\n"
396+
f"\t SNR: {computed_snr}\n"
379397
)
380398

381399
@staticmethod
@@ -386,6 +404,7 @@ def _compare_outputs(
386404
atol=1e-03,
387405
rtol=1e-03,
388406
qtol=0,
407+
snr: float | None = None,
389408
):
390409
"""
391410
Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +427,7 @@ def _compare_outputs(
408427
reference_output,
409428
atol=atol,
410429
rtol=rtol,
430+
snr=snr,
411431
)
412432

413433
@staticmethod

backends/test/suite/operators/test_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def test_abs_edge_cases(self, flow: TestFlow) -> None:
4949
# Test edge cases
5050

5151
# Tensor with infinity
52-
x = torch.tensor([float("inf"), float("-inf"), 1.0, -1.0])
52+
x = (torch.tensor([float("inf"), float("-inf"), 1.0, -1.0]),)
5353
self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False)
5454

5555
# Tensor with NaN
56-
x = torch.tensor([float("nan"), 1.0, -1.0])
56+
x = (torch.tensor([float("nan"), 1.0, -1.0]),)
5757
self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False)

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def build_result(
104104
# AssertionErrors to catch output mismatches, but this might catch more than that.
105105
try:
106106
tester.run_method_and_compare_outputs(
107-
inputs=None if generate_random_test_inputs else inputs
107+
inputs=None if generate_random_test_inputs else inputs,
108+
atol=5e-2,
109+
rtol=5e-2,
110+
snr=40,
108111
)
109112
except AssertionError as e:
110113
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)

0 commit comments

Comments
 (0)