Skip to content

Commit cc5f9d7

Browse files
committed
Fix yet another bug
1 parent 7eb9222 commit cc5f9d7

File tree

2 files changed

+56
-26
lines changed

2 files changed

+56
-26
lines changed

sqlmesh/core/model/common.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from pathlib import Path
66

77
from astor import to_source
8-
from collections import defaultdict
98
from difflib import get_close_matches
109
from sqlglot import exp
1110
from sqlglot.helper import ensure_list
@@ -67,9 +66,9 @@ def make_python_env(
6766
# id(expr) -> false: expr appears under the AST of a macro function whose metadata status we don't yet know
6867
expr_under_metadata_macro_func: t.Dict[int, bool] = {}
6968

70-
# For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
71-
# v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> "foo"
72-
macro_funcs_by_used_var: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
69+
# For @m1(@m2(@x), @y), we'd get x -> m1 and y -> m1
70+
outermost_macro_func_ancestor_by_var: t.Dict[str, str] = {}
71+
visited_macro_funcs: t.Set[int] = set()
7372

7473
def _is_metadata_var(
7574
name: str, expression: exp.Expression, appears_in_metadata_expression: bool
@@ -131,13 +130,13 @@ def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool:
131130
used_variables[var_name] = _is_metadata_var(
132131
name, macro_func_or_var, is_metadata
133132
)
134-
else:
135-
var_refs, _expr_under_metadata_macro_func = (
133+
elif id(macro_func_or_var) not in visited_macro_funcs:
134+
var_refs, _expr_under_metadata_macro_func, _visited_macro_funcs = (
136135
_extract_macro_func_variable_references(macro_func_or_var, is_metadata)
137136
)
138137
expr_under_metadata_macro_func.update(_expr_under_metadata_macro_func)
139-
for var_ref in var_refs:
140-
macro_funcs_by_used_var[var_ref].add(name)
138+
visited_macro_funcs.update(_visited_macro_funcs)
139+
outermost_macro_func_ancestor_by_var |= {var_ref: name for var_ref in var_refs}
141140
elif macro_func_or_var.__class__ is d.MacroVar:
142141
name = macro_func_or_var.name.lower()
143142
if name in macros:
@@ -180,28 +179,22 @@ def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool:
180179
blueprint_variables=blueprint_variables,
181180
dialect=dialect,
182181
strict_resolution=strict_resolution,
183-
macro_funcs_by_used_var=macro_funcs_by_used_var,
182+
outermost_macro_func_ancestor_by_var=outermost_macro_func_ancestor_by_var,
184183
)
185184

186185

187186
def _extract_macro_func_variable_references(
188187
macro_func: exp.Expression,
189188
is_metadata: bool,
190-
) -> t.Tuple[t.Set[str], t.Dict[int, bool]]:
189+
) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]:
191190
references = set()
191+
visited_macro_funcs = set()
192192
expr_under_metadata_macro_func = {}
193193

194-
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
195-
# they will be handled in a separate call of _extract_macro_func_variable_references.
196-
def _prune_nested_macro_func(expression: exp.Expression) -> bool:
197-
return (
198-
type(expression) is d.MacroFunc
199-
and expression is not macro_func
200-
and expression.this.name.lower() not in (c.VAR, c.BLUEPRINT_VAR)
201-
)
202-
203-
for n in macro_func.walk(prune=_prune_nested_macro_func):
194+
for n in macro_func.walk():
204195
if type(n) is d.MacroFunc:
196+
visited_macro_funcs.add(id(n))
197+
205198
this = n.this
206199
args = this.expressions
207200

@@ -218,7 +211,7 @@ def _prune_nested_macro_func(expression: exp.Expression) -> bool:
218211
)
219212
expr_under_metadata_macro_func[id(n)] = is_metadata
220213

221-
return (references, expr_under_metadata_macro_func)
214+
return (references, expr_under_metadata_macro_func, visited_macro_funcs)
222215

223216

224217
def _add_variables_to_python_env(
@@ -228,7 +221,7 @@ def _add_variables_to_python_env(
228221
strict_resolution: bool = True,
229222
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
230223
dialect: DialectType = None,
231-
macro_funcs_by_used_var: t.Optional[t.DefaultDict[str, t.Set[str]]] = None,
224+
outermost_macro_func_ancestor_by_var: t.Optional[t.Dict[str, str]] = None,
232225
) -> t.Dict[str, Executable]:
233226
_, python_used_variables = parse_dependencies(
234227
python_env,
@@ -244,13 +237,13 @@ def _add_variables_to_python_env(
244237
# - They are only referenced in metadata-only contexts, such as `audits (...)`, virtual statements, etc
245238
# - They are only referenced in metadata-only macros, either as their arguments or within their definitions
246239
metadata_used_variables = set()
247-
for used_var, macro_names in (macro_funcs_by_used_var or {}).items():
240+
for used_var, outermost_macro_func in (outermost_macro_func_ancestor_by_var or {}).items():
248241
used_var_is_metadata = used_variables.get(used_var)
249242
if used_var_is_metadata is False:
250243
continue
251244

252-
if used_var_is_metadata or all(
253-
name in python_env and python_env[name].is_metadata for name in macro_names
245+
if used_var_is_metadata or (
246+
outermost_macro_func in python_env and python_env[outermost_macro_func].is_metadata
254247
):
255248
metadata_used_variables.add(used_var)
256249

tests/core/test_model.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10185,6 +10185,43 @@ def m2_references_v_non_metadata(evaluator):
1018510185
assert python_env.get(c.SQLMESH_VARS) == Executable.value({"v": 1})
1018610186

1018710187

10188+
def test_only_top_level_macro_func_impacts_var_descendant_metadata_status(tmp_path: Path) -> None:
10189+
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)
10190+
10191+
test_model = tmp_path / "models/test_model.sql"
10192+
test_model.parent.mkdir(parents=True, exist_ok=True)
10193+
test_model.write_text(
10194+
"MODEL (name test_model, kind FULL); @m1_metadata(@m2_non_metadata(@v)); SELECT 1 AS c;"
10195+
)
10196+
10197+
macro_code = """
10198+
from sqlmesh import macro
10199+
10200+
@macro(metadata_only=True)
10201+
def m1_metadata(evaluator, *args):
10202+
return None
10203+
10204+
@macro()
10205+
def m2_non_metadata(evaluator, *args):
10206+
return None"""
10207+
10208+
test_macros = tmp_path / "macros/test_macros.py"
10209+
test_macros.parent.mkdir(parents=True, exist_ok=True)
10210+
test_macros.write_text(macro_code)
10211+
10212+
ctx = Context(
10213+
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}),
10214+
paths=tmp_path,
10215+
)
10216+
model = ctx.get_model("test_model")
10217+
10218+
python_env = model.python_env
10219+
10220+
assert len(python_env) == 3
10221+
assert set(python_env) > {"m1_metadata", "m2_non_metadata"}
10222+
assert python_env.get(c.SQLMESH_VARS_METADATA) == Executable.value({"v": 1}, is_metadata=True)
10223+
10224+
1018810225
def test_non_metadata_object_takes_precedence_over_metadata_only_object(tmp_path: Path) -> None:
1018910226
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)
1019010227

@@ -11242,7 +11279,7 @@ def test_render_query_optimize_query_false(assert_exp_eq, sushi_context):
1124211279
("@M(@SQL('@v1'))", {"v1"}),
1124311280
("@M(@'@{v1}_foo')", {"v1"}),
1124411281
("@M1(@VAR('v1'))", {"v1"}),
11245-
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v3"}),
11282+
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v2", "v3"}),
1124611283
("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}),
1124711284
],
1124811285
)

0 commit comments

Comments
 (0)