Skip to content

Commit 1b11fd3

Browse files
committed
[Backend Tester] Add SNR validation
ghstack-source-id: 8c6f72f ghstack-comment-id: 3129418402 Pull-Request: #12924
1 parent 73a523d commit 1b11fd3

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

backends/test/harness/tester.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
import random
21
from collections import Counter, OrderedDict
2+
from torch.ao.ns.fx.utils import compute_sqnr
33
from typing import Any, Callable, Dict, List, Optional, Tuple
44

5+
import math
6+
import random
57
import torch
68

79
from executorch.backends.test.harness.stages import (
@@ -302,17 +304,18 @@ def run_method_and_compare_outputs(
302304
atol=1e-03,
303305
rtol=1e-03,
304306
qtol=0,
307+
snr: float | None = None,
305308
):
306309
number_of_runs = 1 if inputs is not None else num_runs
307310
reference_stage = self.stages[StageType.EXPORT]
308311

309312
stage = stage or self.cur
310313

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
314+
#print(f"Comparing Stage {stage} with Stage {reference_stage}")
312315
for run_iteration in range(number_of_runs):
313316
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314317
input_shapes = [generated_input.shape for generated_input in inputs_to_run]
315-
print(f"Run {run_iteration} with input shapes: {input_shapes}")
318+
#print(f"Run {run_iteration} with input shapes: {input_shapes}")
316319

317320
# Reference output (and quantization scale)
318321
(
@@ -325,13 +328,13 @@ def run_method_and_compare_outputs(
325328
# Output from running artifact at stage
326329
stage_output = self.stages[stage].run_artifact(inputs_to_run)
327330
self._compare_outputs(
328-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
331+
reference_output, stage_output, quantization_scale, atol, rtol, qtol, snr
329332
)
330333

331334
return self
332335

333336
@staticmethod
334-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
337+
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03, snr: float | None = None):
335338
"""
336339
Helper testing function that asserts that the model output and the reference output
337340
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -356,15 +359,18 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
356359
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
357360
)
358361
else:
362+
computed_snr = compute_sqnr(model.to(torch.float), ref.to(torch.float))
363+
snr = snr or float("-inf")
364+
359365
assert torch.allclose(
360366
model,
361367
ref,
362368
atol=atol,
363369
rtol=rtol,
364370
equal_nan=True,
365-
), (
371+
) and computed_snr >= snr or math.isnan(computed_snr), (
366372
f"Output {i} does not match reference output.\n"
367-
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
373+
f"\tGiven atol: {atol}, rtol: {rtol}, snr: {snr}.\n"
368374
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
369375
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
370376
f"\t-- Model vs. Reference --\n"
@@ -373,8 +379,10 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
373379
f"\t Mean: {model.to(torch.double).mean()}, {ref.to(torch.double).mean()}\n"
374380
f"\t Max: {model.max()}, {ref.max()}\n"
375381
f"\t Min: {model.min()}, {ref.min()}\n"
382+
f"\t SNR: {computed_snr}\n"
376383
)
377384

385+
378386
@staticmethod
379387
def _compare_outputs(
380388
reference_output,
@@ -383,6 +391,7 @@ def _compare_outputs(
383391
atol=1e-03,
384392
rtol=1e-03,
385393
qtol=0,
394+
snr: float | None = None,
386395
):
387396
"""
388397
Compares the output of the original nn module with the output of the generated artifact.
@@ -405,6 +414,7 @@ def _compare_outputs(
405414
reference_output,
406415
atol=atol,
407416
rtol=rtol,
417+
snr=snr,
408418
)
409419

410420
@staticmethod

backends/test/suite/operators/test_abs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def test_abs_edge_cases(self, flow: TestFlow) -> None:
8686
)
8787

8888
# Tensor with infinity
89-
x = torch.tensor([float("inf"), float("-inf"), 1.0, -1.0])
89+
x = (torch.tensor([float("inf"), float("-inf"), 1.0, -1.0]),)
9090
self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False)
9191

9292
# Tensor with NaN
93-
x = torch.tensor([float("nan"), 1.0, -1.0])
93+
x = (torch.tensor([float("nan"), 1.0, -1.0]),)
9494
self._test_op(AbsModel(), (x,), flow, generate_random_test_inputs=False)
9595

9696
def test_abs_scalar(self, flow: TestFlow) -> None:

backends/test/suite/runner.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,10 @@ def build_result(
104104
# AssertionErrors to catch output mismatches, but this might catch more than that.
105105
try:
106106
tester.run_method_and_compare_outputs(
107-
None if generate_random_test_inputs else inputs
107+
inputs=None if generate_random_test_inputs else inputs,
108+
atol=5e-2,
109+
rtol=5e-2,
110+
snr=40,
108111
)
109112
except AssertionError as e:
110113
return build_result(TestResult.OUTPUT_MISMATCH_FAIL, e)

0 commit comments

Comments
 (0)