Skip to content

Commit 06c31f9

Browse files
committed
Fix macro func variable extraction & add tests
1 parent 11400fb commit 06c31f9

File tree

2 files changed

+33
-8
lines changed

2 files changed

+33
-8
lines changed

sqlmesh/core/model/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,22 @@ def make_python_env(
160160
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
161161
references = set()
162162

163-
for n in macro_func.walk():
164-
if n is macro_func:
165-
continue
163+
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
164+
# they will be handled in a separate call of _extract_macro_func_variable_references.
165+
def _prune_nested_macro_func(expression: exp.Expression) -> bool:
166+
return (
167+
type(n) is d.MacroFunc
168+
and n is not macro_func
169+
and n.this.name.lower() not in (c.VAR, c.BLUEPRINT_VAR)
170+
)
166171

167-
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168-
# they will be handled in a separate call of _extract_macro_func_variable_references.
169-
if isinstance(n, d.MacroFunc):
172+
for n in macro_func.walk(prune=_prune_nested_macro_func):
173+
if type(n) is d.MacroFunc:
170174
this = n.this
171-
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and this.expressions:
172-
references.add(this.expressions[0].this.lower())
175+
args = this.expressions
176+
177+
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string:
178+
references.add(args[0].this.lower())
173179
elif isinstance(n, d.MacroVar):
174180
references.add(n.name.lower())
175181
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:

tests/core/test_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10835,3 +10835,22 @@ def test_datetime_without_timezone_variable_redshift() -> None:
1083510835
model.render_query_or_raise().sql("redshift")
1083610836
== '''SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS "test_time_col"'''
1083710837
)
10838+
10839+
10840+
@pytest.mark.parametrize(
10841+
"macro_func, variables",
10842+
[
10843+
("@M(@v1)", {"v1"}),
10844+
("@M(@{v1})", {"v1"}),
10845+
("@M(@SQL('@v1'))", {"v1"}),
10846+
("@M(@'@{v1}_foo')", {"v1"}),
10847+
("@M1(@VAR('v1'))", {"v1"}),
10848+
("@M1(@v1, @M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v3"}),
10849+
("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}),
10850+
],
10851+
)
10852+
def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None:
10853+
from sqlmesh.core.model.common import _extract_macro_func_variable_references
10854+
10855+
macro_func_ast = parse_one(macro_func)
10856+
assert _extract_macro_func_variable_references(macro_func_ast) == variables

0 commit comments

Comments
 (0)