Skip to content

Commit 142dbc2

Browse files
committed
Refactor: address new bug, improve testing coverage
1 parent a8ee3a6 commit 142dbc2

File tree

2 files changed

+131
-38
lines changed

2 files changed

+131
-38
lines changed

sqlmesh/core/model/common.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,47 @@ def make_python_env(
5555
blueprint_variables = blueprint_variables or {}
5656

5757
used_macros: t.Dict[str, t.Tuple[MacroCallable, bool]] = {}
58-
used_variables = dict.fromkeys(referenced_variables or set(), False) # var -> is_metadata
58+
59+
# var -> True: var is metadata-only
60+
# var -> False: var is not metadata-only
61+
# var -> None: cannot determine whether var is metadata-only yet, need to walk macros first
62+
used_variables: t.Dict[str, t.Optional[bool]] = dict.fromkeys(
63+
referenced_variables or set(), False
64+
)
65+
66+
# id(expr) -> True: expr appears under the AST of a metadata-only macro function
67+
# id(expr) -> False: expr appears under the AST of a macro function whose metadata status we don't yet know
68+
expr_under_metadata_macro_func: t.Dict[int, bool] = {}
5969

6070
# For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
6171
# v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> {"foo"}
6272
macro_funcs_by_used_var: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
6373

74+
def _is_metadata_var(
75+
name: str, expression: exp.Expression, appears_in_metadata_expression: bool
76+
) -> t.Optional[bool]:
# Compute the updated metadata-only status of variable `name` after seeing one more
# reference to it at `expression`. The result follows the `used_variables` convention:
#   True  -> metadata-only so far
#   False -> not metadata-only
#   None  -> undecided; the reference sits under a macro call whose status is not known yet
77+
# Status recorded so far; a variable that has no entry yet starts out as True.
is_metadata_so_far = used_variables.get(name, True)
78+
# A False verdict is final: later references never flip it back.
if is_metadata_so_far is False:
79+
return False
80+
81+
# True / False as recorded in `expr_under_metadata_macro_func` for this exact node,
# or None when the node has no entry in that mapping.
appears_under_metadata_macro_func = expr_under_metadata_macro_func.get(id(expression))
82+
# Stays metadata-only when the status so far is truthy and this reference is either in
# a metadata expression or recorded with a True flag in the mapping above.
if is_metadata_so_far and (
83+
appears_in_metadata_expression or appears_under_metadata_macro_func
84+
):
85+
return True
86+
87+
# The node is recorded with a False flag (macro call of unknown status): stay undecided.
if appears_under_metadata_macro_func is False:
88+
return None
89+
90+
# Every remaining case is reported as not metadata-only.
return False
91+
92+
def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool:
# Compute the metadata-only flag to store for macro `name` in `used_macros`.
93+
# The macro was recorded before: it remains metadata-only only if the stored flag
# (second tuple element) and the current reference are both metadata.
if name in used_macros:
94+
is_metadata_so_far = used_macros[name][1]
95+
return is_metadata_so_far and appears_in_metadata_expression
96+
97+
# First time this macro is seen: the flag is that of the current reference.
return appears_in_metadata_expression
98+
6499
expressions = ensure_list(expressions)
65100
for expression_metadata in expressions:
66101
if isinstance(expression_metadata, tuple):
@@ -77,11 +112,8 @@ def make_python_env(
77112
if name not in macros:
78113
continue
79114

80-
# If this macro has been seen before as a non-metadata macro, prioritize that
81-
used_macros[name] = (
82-
macros[name],
83-
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
84-
)
115+
used_macros[name] = (macros[name], _is_metadata_macro(name, is_metadata))
116+
85117
if name in (c.VAR, c.BLUEPRINT_VAR):
86118
args = macro_func_or_var.this.expressions
87119
if len(args) < 1:
@@ -96,20 +128,22 @@ def make_python_env(
96128
)
97129

98130
var_name = args[0].this.lower()
99-
used_variables[var_name] = used_variables.get(var_name, True) and is_metadata
131+
used_variables[var_name] = _is_metadata_var(
132+
name, macro_func_or_var, is_metadata
133+
)
100134
else:
101-
for var_ref in _extract_macro_func_variable_references(macro_func_or_var):
135+
var_refs, _expr_under_metadata_macro_func = (
136+
_extract_macro_func_variable_references(macro_func_or_var, is_metadata)
137+
)
138+
expr_under_metadata_macro_func.update(_expr_under_metadata_macro_func)
139+
for var_ref in var_refs:
102140
macro_funcs_by_used_var[var_ref].add(name)
103141
elif macro_func_or_var.__class__ is d.MacroVar:
104142
name = macro_func_or_var.name.lower()
105143
if name in macros:
106-
# If this macro has been seen before as a non-metadata macro, prioritize that
107-
used_macros[name] = (
108-
macros[name],
109-
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
110-
)
144+
used_macros[name] = (macros[name], _is_metadata_macro(name, is_metadata))
111145
elif name in variables or name in blueprint_variables:
112-
used_variables[name] = used_variables.get(name, True) and is_metadata
146+
used_variables[name] = _is_metadata_var(name, macro_func_or_var, is_metadata)
113147
elif (
114148
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
115149
) and "@" in macro_func_or_var.name:
@@ -118,8 +152,8 @@ def make_python_env(
118152
):
119153
var_name = braced_identifier or identifier
120154
if var_name in variables or var_name in blueprint_variables:
121-
used_variables[var_name] = (
122-
used_variables.get(var_name, True) and is_metadata
155+
used_variables[var_name] = _is_metadata_var(
156+
var_name, macro_func_or_var, is_metadata
123157
)
124158

125159
for macro_ref in jinja_macro_references or set():
@@ -150,8 +184,12 @@ def make_python_env(
150184
)
151185

152186

153-
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
187+
def _extract_macro_func_variable_references(
188+
macro_func: exp.Expression,
189+
is_metadata: bool,
190+
) -> t.Tuple[t.Set[str], t.Dict[int, bool]]:
154191
references = set()
192+
expr_under_metadata_macro_func = {}
155193

156194
# Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
157195
# they will be handled in a separate call of _extract_macro_func_variable_references.
@@ -169,20 +207,23 @@ def _prune_nested_macro_func(expression: exp.Expression) -> bool:
169207

170208
if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string:
171209
references.add(args[0].this.lower())
210+
expr_under_metadata_macro_func[id(n)] = is_metadata
172211
elif isinstance(n, d.MacroVar):
173212
references.add(n.name.lower())
213+
expr_under_metadata_macro_func[id(n)] = is_metadata
174214
elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:
175215
references.update(
176216
(braced_identifier or identifier).lower()
177217
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(n.name)
178218
)
219+
expr_under_metadata_macro_func[id(n)] = is_metadata
179220

180-
return references
221+
return (references, expr_under_metadata_macro_func)
181222

182223

183224
def _add_variables_to_python_env(
184225
python_env: t.Dict[str, Executable],
185-
used_variables: t.Dict[str, bool],
226+
used_variables: t.Dict[str, t.Optional[bool]],
186227
variables: t.Optional[t.Dict[str, t.Any]],
187228
strict_resolution: bool = True,
188229
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
@@ -197,14 +238,18 @@ def _add_variables_to_python_env(
197238
blueprint_variables=blueprint_variables,
198239
)
199240
for var_name, is_metadata in python_used_variables.items():
200-
used_variables[var_name] = used_variables.get(var_name, True) and is_metadata
241+
used_variables[var_name] = is_metadata and used_variables.get(var_name)
201242

202243
# Variables are treated as metadata when:
203244
# - They are only referenced in metadata-only contexts, such as `audits (...)`, virtual statements, etc
204245
# - They are only referenced in metadata-only macros, either as their arguments or within their definitions
205246
metadata_used_variables = set()
206247
for used_var, macro_names in (macro_funcs_by_used_var or {}).items():
207-
if used_variables.get(used_var) or all(
248+
used_var_is_metadata = used_variables.get(used_var)
249+
if used_var_is_metadata is False:
250+
continue
251+
252+
if used_var_is_metadata or all(
208253
name in python_env and python_env[name].is_metadata for name in macro_names
209254
):
210255
metadata_used_variables.add(used_var)

tests/core/test_model.py

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10050,30 +10050,37 @@ def test_vars_are_taken_into_account_when_propagating_metadata_status(tmp_path:
1005010050
test_model.parent.mkdir(parents=True, exist_ok=True)
1005110051
test_model.write_text(
1005210052
"MODEL (name test_model, kind FULL, blueprints ((v4 := 4, v5 := 5)));"
10053-
"@m1_with_var();" # metadata macro, references v1 internally => v1 metadata
10054-
"@m2_without_var(@v2, @v3);" # metadata macro => v2 metadata, v3 metadata
10055-
"@m3_without_var(@v3);" # non-metadata macro, references v4 => v3, v4 are not metadata
10053+
"@m1_metadata_references_v1();" # metadata macro, references v1 internally => v1 metadata
10054+
"@m2_metadata_does_not_reference_var(@v2, @v3);" # metadata macro => v2 metadata, v3 metadata
10055+
"@m3_non_metadata_references_v4(@v3);" # non-metadata macro, references v4 => v3, v4 are not metadata
1005610056
"SELECT 1 AS c;"
10057+
"@m2_metadata_does_not_reference_var(@v6);" # metadata macro => v6 is metadata
10058+
"@m4_non_metadata_references_v6();" # non-metadata macro, references v6 => v6 is not metadata
1005710059
"ON_VIRTUAL_UPDATE_BEGIN;"
10058-
"@m3_without_var(@v5);" # non-metadata macro, metadata context => v5 metadata
10060+
"@m3_non_metadata_references_v4(@v5);" # non-metadata macro, metadata expression => v5 metadata
1005910061
"ON_VIRTUAL_UPDATE_END;"
1006010062
)
1006110063

1006210064
macro_code = """
1006310065
from sqlmesh import macro
1006410066
1006510067
@macro(metadata_only=True)
10066-
def m1_with_var(evaluator):
10068+
def m1_metadata_references_v1(evaluator):
1006710069
evaluator.var("v1")
1006810070
return None
1006910071
1007010072
@macro(metadata_only=True)
10071-
def m2_without_var(evaluator, *args):
10073+
def m2_metadata_does_not_reference_var(evaluator, *args):
1007210074
return None
1007310075
1007410076
@macro()
10075-
def m3_without_var(evaluator, *args):
10077+
def m3_non_metadata_references_v4(evaluator, *args):
1007610078
evaluator.var("v4")
10079+
return None
10080+
10081+
@macro()
10082+
def m4_non_metadata_references_v6(evaluator):
10083+
evaluator.var("v6")
1007710084
return None"""
1007810085

1007910086
test_macros = tmp_path / "macros/test_macros.py"
@@ -10083,23 +10090,24 @@ def m3_without_var(evaluator, *args):
1008310090
ctx = Context(
1008410091
config=Config(
1008510092
model_defaults=ModelDefaultsConfig(dialect="duckdb"),
10086-
variables={"v1": 1, "v2": 2, "v3": 3},
10093+
variables={"v1": 1, "v2": 2, "v3": 3, "v6": 6},
1008710094
),
1008810095
paths=tmp_path,
1008910096
)
1009010097
model = ctx.get_model("test_model")
1009110098

1009210099
python_env = model.python_env
1009310100

10094-
assert len(python_env) == 7
10095-
assert "m1_with_var" in python_env
10096-
assert "m2_without_var" in python_env
10097-
assert "m3_without_var" in python_env
10101+
assert len(python_env) == 8
10102+
assert "m1_metadata_references_v1" in python_env
10103+
assert "m2_metadata_does_not_reference_var" in python_env
10104+
assert "m3_non_metadata_references_v4" in python_env
10105+
assert "m4_non_metadata_references_v6" in python_env
1009810106

1009910107
variables = python_env.get(c.SQLMESH_VARS)
1010010108
metadata_variables = python_env.get(c.SQLMESH_VARS_METADATA)
1010110109

10102-
assert variables == Executable.value({"v1": 1, "v3": 3})
10110+
assert variables == Executable.value({"v1": 1, "v3": 3, "v6": 6})
1010310111
assert metadata_variables == Executable.value({"v2": 2}, is_metadata=True)
1010410112

1010510113
blueprint_variables = python_env.get(c.SQLMESH_BLUEPRINT_VARS)
@@ -10115,28 +10123,68 @@ def m3_without_var(evaluator, *args):
1011510123
assert macro_evaluator.locals == {
1011610124
"runtime_stage": "loading",
1011710125
"default_catalog": None,
10118-
c.SQLMESH_VARS: {"v1": 1, "v3": 3},
10126+
c.SQLMESH_VARS: {"v1": 1, "v3": 3, "v6": 6},
1011910127
c.SQLMESH_VARS_METADATA: {"v2": 2},
1012010128
c.SQLMESH_BLUEPRINT_VARS: {"v4": exp.Literal.number("4")},
1012110129
c.SQLMESH_BLUEPRINT_VARS_METADATA: {"v5": exp.Literal.number("5")},
1012210130
}
1012310131
assert macro_evaluator.var("v1") == 1
1012410132
assert macro_evaluator.var("v2") == 2
1012510133
assert macro_evaluator.var("v3") == 3
10134+
assert macro_evaluator.var("v6") == 6
1012610135
assert macro_evaluator.blueprint_var("v4") == exp.Literal.number("4")
1012710136
assert macro_evaluator.blueprint_var("v5") == exp.Literal.number("5")
1012810137

1012910138
query_with_vars = macro_evaluator.transform(
10130-
parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3]))
10139+
parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3, 6]))
1013110140
)
10132-
assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3"
10141+
assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6"
1013310142

1013410143
query_with_blueprint_vars = macro_evaluator.transform(
1013510144
parse_one("SELECT " + ", ".join(f"@v{var}, @BLUEPRINT_VAR('v{var}')" for var in [4, 5]))
1013610145
)
1013710146
assert t.cast(exp.Expression, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5"
1013810147

1013910148

10149+
def test_variable_mentioned_in_both_metadata_and_non_metadata_macro(tmp_path: Path) -> None:
# Scenario: variable `v` is read inside two macros, one declared with
# `metadata_only=True` (called as a standalone statement) and one plain macro
# (called in the model's SELECT). Neither call passes `v` as an argument.
10150+
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)
10151+
10152+
test_model = tmp_path / "models/test_model.sql"
10153+
test_model.parent.mkdir(parents=True, exist_ok=True)
10154+
test_model.write_text(
10155+
"MODEL (name test_model, kind FULL); @m1_references_v_metadata(); SELECT @m2_references_v_non_metadata() AS c;"
10156+
)
10157+
10158+
macro_code = """
10159+
from sqlmesh import macro
10160+
10161+
@macro(metadata_only=True)
10162+
def m1_references_v_metadata(evaluator):
10163+
evaluator.var("v")
10164+
return None
10165+
10166+
@macro()
10167+
def m2_references_v_non_metadata(evaluator):
10168+
evaluator.var("v")
10169+
return None"""
10170+
10171+
test_macros = tmp_path / "macros/test_macros.py"
10172+
test_macros.parent.mkdir(parents=True, exist_ok=True)
10173+
test_macros.write_text(macro_code)
10174+
10175+
ctx = Context(
10176+
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}),
10177+
paths=tmp_path,
10178+
)
10179+
model = ctx.get_model("test_model")
10180+
10181+
python_env = model.python_env
10182+
10183+
# Expect exactly three entries: the two macros plus c.SQLMESH_VARS, whose value holds
# `v` as a plain Executable.value (built without is_metadata=True).
assert len(python_env) == 3
10184+
assert set(python_env) > {"m1_references_v_metadata", "m2_references_v_non_metadata"}
10185+
assert python_env.get(c.SQLMESH_VARS) == Executable.value({"v": 1})
10186+
10187+
1014010188
def test_non_metadata_object_takes_precedence_over_metadata_only_object(tmp_path: Path) -> None:
1014110189
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY)
1014210190

@@ -11167,4 +11215,4 @@ def test_extract_macro_func_variable_references(macro_func: str, variables: t.Se
1116711215
from sqlmesh.core.model.common import _extract_macro_func_variable_references
1116811216

1116911217
macro_func_ast = parse_one(macro_func)
11170-
assert _extract_macro_func_variable_references(macro_func_ast) == variables
11218+
assert _extract_macro_func_variable_references(macro_func_ast, True)[0] == variables

0 commit comments

Comments
 (0)