Commit a726d55

Fix: Unit test CTE failures not being captured
1 parent a94c4f0 commit a726d55

File tree

  sqlmesh/core/console.py
  sqlmesh/core/test/definition.py
  sqlmesh/core/test/result.py
  tests/core/test_test.py

4 files changed: +153 -42 lines changed

sqlmesh/core/console.py

Lines changed: 19 additions & 33 deletions
@@ -2163,13 +2163,12 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
         self._print("-" * divider_length)
         self._print("Test Failure Summary", style="red")
         self._print("=" * divider_length)
-        failures = len(result.failures) + len(result.errors)
+        fail_and_error_tests = result.get_fail_and_error_tests()
         self._print(f"{message} \n")

-        self._print(f"Failed tests ({failures}):")
-        for test, _ in result.failures + result.errors:
-            if isinstance(test, ModelTest):
-                self._print(f" • {test.path}::{test.test_name}")
+        self._print(f"Failed tests ({len(fail_and_error_tests)}):")
+        for test in fail_and_error_tests:
+            self._print(f" • {test.path}::{test.test_name}")
         self._print("=" * divider_length, end="\n\n")

     def _captured_unit_test_results(self, result: ModelTextTestResult) -> str:
@@ -2721,28 +2720,15 @@ def _log_test_details(
         Args:
             result: The unittest test result that contains metrics like num success, fails, ect.
         """
-
         if result.wasSuccessful():
             self._print("\n", end="")
             return

-        errors = result.errors
-        failures = result.failures
-        skipped = result.skipped
-
-        infos = []
-        if failures:
-            infos.append(f"failures={len(failures)}")
-        if errors:
-            infos.append(f"errors={len(errors)}")
-        if skipped:
-            infos.append(f"skipped={skipped}")
-
         if unittest_char_separator:
             self._print(f"\n{unittest.TextTestResult.separator1}\n\n", end="")

         for (test_case, failure), test_failure_tables in zip_longest(  # type: ignore
-            failures, result.failure_tables
+            result.failures, result.failure_tables
         ):
             self._print(unittest.TextTestResult.separator2)
             self._print(f"FAIL: {test_case}")
@@ -2758,7 +2744,7 @@ def _log_test_details(
                 self._print(failure_table)
                 self._print("\n", end="")

-        for test_case, error in errors:
+        for test_case, error in result.errors:
             self._print(unittest.TextTestResult.separator2)
             self._print(f"ERROR: {test_case}")
             self._print(f"{unittest.TextTestResult.separator2}")
@@ -3080,27 +3066,27 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
         fail_shared_style = {**shared_style, **fail_color}
         header = str(h("span", {"style": fail_shared_style}, "-" * divider_length))
         message = str(h("span", {"style": fail_shared_style}, "Test Failure Summary"))
+        fail_and_error_tests = result.get_fail_and_error_tests()
         failed_tests = [
             str(
                 h(
                     "span",
                     {"style": fail_shared_style},
-                    f"Failed tests ({len(result.failures) + len(result.errors)}):",
+                    f"Failed tests ({len(fail_and_error_tests)}):",
                 )
             )
         ]

-        for test, _ in result.failures + result.errors:
-            if isinstance(test, ModelTest):
-                failed_tests.append(
-                    str(
-                        h(
-                            "span",
-                            {"style": fail_shared_style},
-                            f" • {test.model.name}::{test.test_name}",
-                        )
+        for test in fail_and_error_tests:
+            failed_tests.append(
+                str(
+                    h(
+                        "span",
+                        {"style": fail_shared_style},
+                        f" • {test.model.name}::{test.test_name}",
                     )
                 )
+            )
         failures = "<br>".join(failed_tests)
         footer = str(h("span", {"style": fail_shared_style}, "=" * divider_length))
         error_output = widgets.Textarea(output, layout={"height": "300px", "width": "100%"})
@@ -3508,10 +3494,10 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
         self._log_test_details(result, unittest_char_separator=False)
         self._print("```\n\n")

-        failures = len(result.failures) + len(result.errors)
+        fail_and_error_tests = result.get_fail_and_error_tests()
         self._print(f"**{message}**\n")
-        self._print(f"**Failed tests ({failures}):**")
-        for test, _ in result.failures + result.errors:
+        self._print(f"**Failed tests ({len(fail_and_error_tests)}):**")
+        for test in fail_and_error_tests:
             if isinstance(test, ModelTest):
                 self._print(f" • `{test.model.name}`::`{test.test_name}`\n\n")

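Every summary path above now counts and prints result.get_fail_and_error_tests() instead of raw (test, traceback) entries. The reason is that, with CTE subtest failures now routed into result.failures (see sqlmesh/core/test/result.py below), a single test can appear once per failed CTE plus once for the main query. A minimal sketch of the dedup, using a hypothetical FakeModelTest stand-in rather than the real ModelTest and ModelTextTestResult classes:

# Illustrative only: FakeModelTest stands in for sqlmesh's ModelTest; the dedup
# mirrors ModelTextTestResult.get_fail_and_error_tests from this commit.
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeModelTest:
    path: str
    test_name: str


def get_fail_and_error_tests(failures, errors):
    # One test may appear several times (one entry per failed CTE subtest,
    # plus one for the main query), so collapse by test_name before printing.
    return list({test.test_name: test for test, _ in failures + errors}.values())


test = FakeModelTest("tests/test_foo.yaml", "test_foo")
failures = [(test, "CTE mismatch traceback"), (test, "query mismatch traceback")]

failed = get_fail_and_error_tests(failures, [])
print(f"Failed tests ({len(failed)}):")  # Failed tests (1):
for failed_test in failed:
    print(f" • {failed_test.path}::{failed_test.test_name}")  # • tests/test_foo.yaml::test_foo
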
sqlmesh/core/test/definition.py

Lines changed: 13 additions & 4 deletions
@@ -317,6 +317,13 @@ def _to_hashable(x: t.Any) -> t.Any:
         #
         # This is a bit of a hack, but it's a way to get the best of both worlds.
         args: t.List[t.Any] = []
+
+        failed_subtest = ""
+
+        if subtest := getattr(self, "_subtest", None):
+            if cte := subtest.params.get("cte"):
+                failed_subtest = f" (CTE {cte})"
+
         if expected.shape != actual.shape:
             _raise_if_unexpected_columns(expected.columns, actual.columns)

@@ -325,13 +332,13 @@ def _to_hashable(x: t.Any) -> t.Any:
             missing_rows = _row_difference(expected, actual)
             if not missing_rows.empty:
                 args[0] += f"\n\nMissing rows:\n\n{missing_rows}"
-                args.append(df_to_table("Missing rows", missing_rows))
+                args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows))

             unexpected_rows = _row_difference(actual, expected)

             if not unexpected_rows.empty:
                 args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
-                args.append(df_to_table("Unexpected rows", unexpected_rows))
+                args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows))

         else:
             diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"})
@@ -341,7 +348,8 @@ def _to_hashable(x: t.Any) -> t.Any:
             diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True)
             if self.verbosity == Verbosity.DEFAULT:
                 args.extend(
-                    df_to_table("Data mismatch", df) for df in _split_df_by_column_pairs(diff)
+                    df_to_table(f"Data mismatch{failed_subtest}", df)
+                    for df in _split_df_by_column_pairs(diff)
                 )
             else:
                 from pandas import MultiIndex
@@ -351,7 +359,8 @@ def _to_hashable(x: t.Any) -> t.Any:
                     col_diff = diff[col]
                     if not col_diff.empty:
                         table = df_to_table(
-                            f"[bold red]Column '{col}' mismatch[/bold red]", col_diff
+                            f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]",
+                            col_diff,
                         )
                         args.append(table)

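The failed_subtest suffix above leans on a unittest implementation detail: while a `with self.subTest(...)` block is active, the test case carries a private _subtest object whose params mapping holds the keyword arguments passed to subTest, which is why the diff reaches for it defensively with getattr. A small standalone sketch of the same technique (the class and method names here are invented for illustration):

import unittest


class SubtestLabelSketch(unittest.TestCase):
    def _label(self) -> str:
        # Same pattern as the diff: only label output when a subTest carrying a
        # "cte" parameter is currently active; _subtest is a private attribute.
        if subtest := getattr(self, "_subtest", None):
            if cte := subtest.params.get("cte"):
                return f" (CTE {cte})"
        return ""

    def test_label(self) -> None:
        self.assertEqual(self._label(), "")  # no active subTest, no suffix
        with self.subTest(cte="model_cte"):
            self.assertEqual(self._label(), " (CTE model_cte)")


if __name__ == "__main__":
    unittest.main()
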
sqlmesh/core/test/result.py

Lines changed: 17 additions & 2 deletions
@@ -4,6 +4,8 @@
 import typing as t
 import unittest

+from sqlmesh.core.test.definition import ModelTest
+
 if t.TYPE_CHECKING:
     ErrorType = t.Union[
         t.Tuple[type[BaseException], BaseException, types.TracebackType],
@@ -42,7 +44,10 @@ def addSubTest(
             exctype, value, tb = err
             err = (exctype, value, None)  # type: ignore

-        super().addSubTest(test, subtest, err)
+        if err[0] and issubclass(err[0], test.failureException):
+            self.addFailure(test, err)
+        else:
+            self.addError(test, err)

     def _print_char(self, char: str) -> None:
         from sqlmesh.core.console import TerminalConsole
@@ -117,4 +122,14 @@ def merge(self, other: ModelTextTestResult) -> None:
             skipped_args = other.skipped[0]
             self.addSkip(skipped_args[0], skipped_args[1])

-        self.testsRun += 1
+        self.testsRun += other.testsRun
+
+    def get_fail_and_error_tests(self) -> t.List[ModelTest]:
+        # If tests contain failed subtests (e.g testing CTE outputs) we don't want
+        # to report it as different test failures
+        test_name_to_test = {
+            test.test_name: test
+            for test, _ in self.failures + self.errors
+            if isinstance(test, ModelTest)
+        }
+        return list(test_name_to_test.values())

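The addSubTest override is the core of the fix. The stock unittest.TestResult.addSubTest records a failing subtest as a (_SubTest, traceback) pair, i.e. the wrapper object rather than the parent test, so downstream code filtering on isinstance(test, ModelTest) never saw CTE failures. Recording the parent test via addFailure/addError puts the ModelTest itself into failures/errors, at the cost of it possibly appearing more than once, which get_fail_and_error_tests then deduplicates by test_name. A stdlib-only sketch of the difference, with a hypothetical CteTest case standing in for ModelTest:

import unittest


class CteTest(unittest.TestCase):
    def test_ctes(self) -> None:
        # One subtest per CTE, mirroring how ModelTest checks CTE outputs.
        with self.subTest(cte="model_cte"):
            self.assertEqual(2, 1)


class RoutingResult(unittest.TestResult):
    def addSubTest(self, test, subtest, err):
        # Mirrors the override above: record the parent test case itself so
        # isinstance-based filters and summaries still see the failure.
        if err is None:
            return
        if issubclass(err[0], test.failureException):
            self.addFailure(test, err)
        else:
            self.addError(test, err)


def failing_test_types(result: unittest.TestResult) -> list:
    suite = unittest.TestLoader().loadTestsFromTestCase(CteTest)
    suite.run(result)
    return [type(test).__name__ for test, _ in result.failures]


print(failing_test_types(unittest.TestResult()))  # ['_SubTest'] -> filtered out before this commit
print(failing_test_types(RoutingResult()))        # ['CteTest']  -> captured now
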
tests/core/test_test.py

Lines changed: 104 additions & 3 deletions
@@ -6,7 +6,7 @@
 from pathlib import Path
 import unittest
 from unittest.mock import call, patch
-from shutil import copyfile, rmtree
+from shutil import rmtree

 import pandas as pd  # noqa: TID253
 import pytest
@@ -87,6 +87,7 @@ def _check_successful_or_raise(
     assert result is not None
     if not result.wasSuccessful():
         error_or_failure_traceback = (result.errors or result.failures)[0][1]
+        print(error_or_failure_traceback)
         if expected_msg:
             assert expected_msg in error_or_failure_traceback
         else:
@@ -2316,6 +2317,13 @@ def test_test_with_resolve_template_macro(tmp_path: Path):

 @use_terminal_console
 def test_test_output(tmp_path: Path) -> None:
+    def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None:
+        with open(test_file, "r") as file:
+            filedata = file.read()
+
+        with open(new_test_file, "w") as file:
+            file.write(filedata.replace("test_example_full_model", f"test_{index}"))
+
     init_example_project(tmp_path, engine_type="duckdb")

     original_test_file = tmp_path / "tests" / "test_full_model.yaml"
@@ -2407,8 +2415,8 @@ def test_test_output(tmp_path: Path) -> None:

     # Case 3: Assert that concurrent execution is working properly
     for i in range(50):
-        copyfile(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml")
-        copyfile(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml")
+        copy_test_file(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml", i)
+        copy_test_file(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml", i)

     with capture_output() as captured_output:
         context.test()
@@ -3327,3 +3335,96 @@ def execute(context: ExecutionContext, **kwargs: t.Any) -> pd.DataFrame:
         context=context,
     )
     _check_successful_or_raise(test_default_vars.run())
+
+
+@use_terminal_console
+def test_cte_failure(tmp_path: Path) -> None:
+    models_dir = tmp_path / "models"
+    models_dir.mkdir()
+    (models_dir / "foo.sql").write_text(
+        """
+        MODEL (
+            name test.foo,
+            kind full
+        );
+
+        with model_cte as (
+            SELECT 1 AS id
+        )
+        SELECT id FROM model_cte
+        """
+    )
+
+    config = Config(
+        default_connection=DuckDBConnectionConfig(),
+        model_defaults=ModelDefaultsConfig(dialect="duckdb"),
+    )
+    context = Context(paths=tmp_path, config=config)
+
+    expected_cte_failure_output = """Data mismatch (CTE "model_cte")
+┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+┃   Row    ┃      id: Expected       ┃     id: Actual      ┃
+┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
+│    0     │            2            │          1          │
+└──────────┴─────────────────────────┴─────────────────────┘"""
+
+    expected_query_failure_output = """Data mismatch
+┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+┃   Row    ┃      id: Expected       ┃     id: Actual      ┃
+┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
+│    0     │            2            │          1          │
+└──────────┴─────────────────────────┴─────────────────────┘"""

+    # Case 1: Ensure that a single CTE failure is reported correctly
+    tests_dir = tmp_path / "tests"
+    tests_dir.mkdir()
+    (tests_dir / "test_foo.yaml").write_text(
+        """
+        test_foo:
+          model: test.foo
+          outputs:
+            ctes:
+              model_cte:
+                rows:
+                  - id: 2
+            query:
+              - id: 1
+        """
+    )
+
+    with capture_output() as captured_output:
+        context.test()
+
+    output = captured_output.stdout
+
+    assert expected_cte_failure_output in output
+    assert expected_query_failure_output not in output
+
+    assert "Ran 1 tests" in output
+    assert "Failed tests (1)" in output
+
+    # Case 2: Ensure that both CTE and query failures are reported correctly
+    (tests_dir / "test_foo.yaml").write_text(
+        """
+        test_foo:
+          model: test.foo
+          outputs:
+            ctes:
+              model_cte:
+                rows:
+                  - id: 2
+            query:
+              - id: 2
+        """
+    )
+
+    with capture_output() as captured_output:
+        context.test()
+
+    output = captured_output.stdout
+
+    assert expected_cte_failure_output in output
+    assert expected_query_failure_output in output
+
+    assert "Ran 1 tests" in output
+    assert "Failed tests (1)" in output

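To try the new behaviour outside the test suite, test_cte_failure above can be condensed into a small hand-run script. The import paths below are assumptions based on sqlmesh's public modules rather than anything shown in this diff, so treat this as a sketch:

# Adapted from test_cte_failure above; import locations are assumed, not taken from the diff.
import tempfile
from pathlib import Path

from sqlmesh.core.config import Config, DuckDBConnectionConfig, ModelDefaultsConfig
from sqlmesh.core.context import Context

project = Path(tempfile.mkdtemp())
(project / "models").mkdir()
(project / "models" / "foo.sql").write_text(
    """
    MODEL (name test.foo, kind full);

    WITH model_cte AS (SELECT 1 AS id)
    SELECT id FROM model_cte
    """
)

(project / "tests").mkdir()
(project / "tests" / "test_foo.yaml").write_text(
    """
    test_foo:
      model: test.foo
      outputs:
        ctes:
          model_cte:
            rows:
              - id: 2
        query:
          - id: 1
    """
)

context = Context(
    paths=project,
    config=Config(
        default_connection=DuckDBConnectionConfig(),
        model_defaults=ModelDefaultsConfig(dialect="duckdb"),
    ),
)

# With this commit, the failure summary lists test_foo once and the mismatch
# table is titled: Data mismatch (CTE "model_cte")
context.test()
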