Skip to content

Commit 8d8fc06

Browse files
committed
Fix!: mark vars referenced in metadata macros as metadata
1 parent 537a311 commit 8d8fc06

File tree

7 files changed

+203
-40
lines changed

7 files changed

+203
-40
lines changed

sqlmesh/core/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@
8080
DEFAULT_SCHEMA = "default"
8181

8282
SQLMESH_VARS = "__sqlmesh__vars__"
83+
SQLMESH_VARS_METADATA = "__sqlmesh__vars__metadata__"
8384
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
85+
SQLMESH_BLUEPRINT_VARS_METADATA = "__sqlmesh__blueprint__vars__metadata__"
8486

8587
VAR = "var"
8688
BLUEPRINT_VAR = "blueprint_var"

sqlmesh/core/macros.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ def __init__(
210210
self.macros[normalize_macro_name(k)] = self.env[k]
211211
elif v.is_value:
212212
value = self.env[k]
213-
if k in (c.SQLMESH_VARS, c.SQLMESH_BLUEPRINT_VARS):
213+
if k in (
214+
c.SQLMESH_VARS,
215+
c.SQLMESH_VARS_METADATA,
216+
c.SQLMESH_BLUEPRINT_VARS,
217+
c.SQLMESH_BLUEPRINT_VARS_METADATA,
218+
):
214219
value = {
215220
var_name: (
216221
self.parse_one(var_value.sql)
@@ -528,17 +533,25 @@ def views(self) -> t.List[str]:
528533

529534
def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Returns the value of the specified variable, or the default value if it doesn't exist.

    The name is lowercased before the lookup. Both the regular and the metadata-only
    variable mappings stored in the evaluator's locals are consulted.

    Args:
        var_name: The name of the variable to look up.
        default: The value returned when the variable is in neither mapping.

    Returns:
        The variable's value, or `default` if it isn't defined.
    """
    key = var_name.lower()
    # Test for the key in each mapping instead of picking a mapping by truthiness:
    # a non-empty regular mapping must not hide a variable that is only present in
    # the metadata mapping. The regular mapping keeps precedence, as before.
    for scope in (c.SQLMESH_VARS, c.SQLMESH_VARS_METADATA):
        scoped_vars = self.locals.get(scope) or {}
        if key in scoped_vars:
            return scoped_vars[key]
    return default
532539

533540
def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]:
    """Returns the value of the specified blueprint variable, or the default value if it doesn't exist.

    The name is lowercased before the lookup. Both the regular and the metadata-only
    blueprint variable mappings stored in the evaluator's locals are consulted.

    Args:
        var_name: The name of the blueprint variable to look up.
        default: The value returned when the variable is in neither mapping.

    Returns:
        The blueprint variable's value, or `default` if it isn't defined.
    """
    key = var_name.lower()
    # Test for the key in each mapping instead of picking a mapping by truthiness:
    # a non-empty regular mapping must not hide a blueprint variable that is only
    # present in the metadata mapping. The regular mapping keeps precedence, as before.
    for scope in (c.SQLMESH_BLUEPRINT_VARS, c.SQLMESH_BLUEPRINT_VARS_METADATA):
        scoped_vars = self.locals.get(scope) or {}
        if key in scoped_vars:
            return scoped_vars[key]
    return default
536547

537548
@property
def variables(self) -> t.Dict[str, t.Any]:
    """Returns a single mapping that combines every variable mapping held in the locals.

    The mappings are applied in order (regular, metadata, blueprint, blueprint metadata),
    so a name present in several of them resolves to the value from the last one.
    """
    combined: t.Dict[str, t.Any] = {}
    for scope in (
        c.SQLMESH_VARS,
        c.SQLMESH_VARS_METADATA,
        c.SQLMESH_BLUEPRINT_VARS,
        c.SQLMESH_BLUEPRINT_VARS_METADATA,
    ):
        combined.update(self.locals.get(scope, {}))
    return combined
543556

544557
def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any:

sqlmesh/core/model/common.py

Lines changed: 120 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66

77
from astor import to_source
8+
from collections import defaultdict
89
from difflib import get_close_matches
910
from sqlglot import exp
1011
from sqlglot.helper import ensure_list
@@ -28,7 +29,7 @@
2829
from sqlmesh.utils import registry_decorator
2930
from sqlmesh.utils.jinja import MacroReference
3031

31-
MacroCallable = registry_decorator
32+
MacroCallable = t.Union[Executable, registry_decorator]
3233

3334

3435
def make_python_env(
@@ -48,13 +49,17 @@ def make_python_env(
4849
dialect: DialectType = None,
4950
) -> t.Dict[str, Executable]:
5051
python_env = {} if python_env is None else python_env
51-
variables = variables or {}
5252
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
53-
used_macros: t.Dict[
54-
str,
55-
t.Tuple[t.Union[Executable | MacroCallable], t.Optional[bool]],
56-
] = {}
57-
used_variables = (used_variables or set()).copy()
53+
54+
variables = variables or {}
55+
blueprint_variables = blueprint_variables or {}
56+
57+
used_macros: t.Dict[str, t.Tuple[MacroCallable, t.Optional[bool]]] = {}
58+
used_variable_referenced_in_metadata_expression = dict.fromkeys(used_variables or set(), False)
59+
60+
# For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
61+
# v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> {"foo"}
62+
macro_funcs_by_used_var: t.DefaultDict[str, t.Set[str]] = defaultdict(set)
5863

5964
expressions = ensure_list(expressions)
6065
for expression_metadata in expressions:
@@ -77,16 +82,27 @@ def make_python_env(
7782
macros[name],
7883
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
7984
)
80-
if name == c.VAR:
85+
if name in (c.VAR, c.BLUEPRINT_VAR):
8186
args = macro_func_or_var.this.expressions
8287
if len(args) < 1:
83-
raise_config_error("Macro VAR requires at least one argument", path)
88+
raise_config_error(
89+
f"Macro {name.upper()} requires at least one argument", path
90+
)
91+
8492
if not args[0].is_string:
8593
raise_config_error(
8694
f"The variable name must be a string literal, '{args[0].sql()}' was given instead",
8795
path,
8896
)
89-
used_variables.add(args[0].this.lower())
97+
98+
var_name = args[0].this.lower()
99+
used_variable_referenced_in_metadata_expression[var_name] = (
100+
used_variable_referenced_in_metadata_expression.get(var_name, True)
101+
and bool(is_metadata)
102+
)
103+
else:
104+
for var_ref in _extract_macro_func_variable_references(macro_func_or_var):
105+
macro_funcs_by_used_var[var_ref].add(name)
90106
elif macro_func_or_var.__class__ is d.MacroVar:
91107
name = macro_func_or_var.name.lower()
92108
if name in macros:
@@ -95,17 +111,23 @@ def make_python_env(
95111
macros[name],
96112
used_macros.get(name, (None, is_metadata))[1] and is_metadata,
97113
)
98-
elif name in variables:
99-
used_variables.add(name)
114+
elif name in variables or name in blueprint_variables:
115+
used_variable_referenced_in_metadata_expression[name] = (
116+
used_variable_referenced_in_metadata_expression.get(name, True)
117+
and bool(is_metadata)
118+
)
100119
elif (
101120
isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL))
102121
) and "@" in macro_func_or_var.name:
103122
for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(
104123
macro_func_or_var.name
105124
):
106125
var_name = braced_identifier or identifier
107-
if var_name in variables:
108-
used_variables.add(var_name)
126+
if var_name in variables or var_name in blueprint_variables:
127+
used_variable_referenced_in_metadata_expression[var_name] = (
128+
used_variable_referenced_in_metadata_expression.get(var_name, True)
129+
and bool(is_metadata)
130+
)
109131

110132
for macro_ref in jinja_macro_references or set():
111133
if macro_ref.package is None and macro_ref.name in macros:
@@ -126,41 +148,97 @@ def make_python_env(
126148
python_env.update(serialize_env(env, path=module_path))
127149
return _add_variables_to_python_env(
128150
python_env,
129-
used_variables,
151+
used_variable_referenced_in_metadata_expression,
130152
variables,
131153
blueprint_variables=blueprint_variables,
132154
dialect=dialect,
133155
strict_resolution=strict_resolution,
156+
macro_funcs_by_used_var=macro_funcs_by_used_var,
134157
)
135158

136159

160+
def _extract_macro_func_variable_references(macro_func: exp.Expression) -> t.Set[str]:
    """Collects the lowercased names of the variables referenced inside a macro function call.

    Three kinds of references are recognized among the nodes below `macro_func`:
    the first argument of @VAR() / @BLUEPRINT_VAR() calls, macro variables, and
    `@`-templated names embedded in identifiers, string replacements and SQL macros.

    Args:
        macro_func: The macro function expression whose arguments are inspected.

    Returns:
        The set of referenced variable names, lowercased.
    """
    references: t.Set[str] = set()

    for n in macro_func.walk():
        if n is macro_func:
            continue

        # Among nested MacroFunc nodes only @VAR() and @BLUEPRINT_VAR() contribute a
        # reference here; other macro functions are handled in a separate call of
        # _extract_macro_func_variable_references.
        # NOTE(review): walk() is called without pruning, so nodes nested inside other
        # macro functions still appear to be visited by this loop - confirm that is intended.
        if isinstance(n, d.MacroFunc):
            this = n.this
            if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR):
                args = this.expressions
                # Only string literals carry a variable name. Skipping anything else
                # avoids calling .lower() on a non-string payload, leaving the
                # "must be a string literal" validation to report the problem.
                if args and args[0].is_string:
                    references.add(args[0].this.lower())
        elif isinstance(n, d.MacroVar):
            references.add(n.name.lower())
        elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name:
            references.update(
                (braced_identifier or identifier).lower()
                for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(n.name)
            )

    return references
182+
183+
137184
def _add_variables_to_python_env(
    python_env: t.Dict[str, Executable],
    used_variable_referenced_in_metadata_expression: t.Dict[str, bool],
    variables: t.Optional[t.Dict[str, t.Any]],
    strict_resolution: bool = True,
    blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
    dialect: DialectType = None,
    macro_funcs_by_used_var: t.Optional[t.DefaultDict[str, t.Set[str]]] = None,
) -> t.Dict[str, Executable]:
    """Stores the used (blueprint) variables in the python environment, split by metadata status.

    Args:
        python_env: The python environment to update; it is modified in place and returned.
        used_variable_referenced_in_metadata_expression: Maps each used variable name to a
            flag telling whether it was only referenced in metadata expressions. This mapping
            is updated in place with the references found in the python environment.
        variables: The available variables.
        strict_resolution: Forwarded to `parse_dependencies`.
        blueprint_variables: The available blueprint variables.
        dialect: The dialect used to serialize blueprint variables that are expressions.
        macro_funcs_by_used_var: Maps a variable name to the names of the macro functions
            that reference it in their arguments.

    Returns:
        The updated python environment.
    """
    _, python_used_variable_referenced_in_metadata_expression = parse_dependencies(
        python_env,
        None,
        strict_resolution=strict_resolution,
        variables=variables,
        blueprint_variables=blueprint_variables,
    )
    # A variable keeps its metadata flag only if every reference seen so far is metadata-only.
    for var_name, is_metadata in python_used_variable_referenced_in_metadata_expression.items():
        used_variable_referenced_in_metadata_expression[var_name] = (
            used_variable_referenced_in_metadata_expression.get(var_name, True) and is_metadata
        )

    # A variable passed to macro functions counts as metadata when it is already flagged
    # as such, or when every macro function referencing it is a metadata executable.
    metadata_used_variables = set()
    for used_var, macro_names in (macro_funcs_by_used_var or {}).items():
        if used_variable_referenced_in_metadata_expression.get(used_var) or all(
            name in python_env and python_env[name].is_metadata for name in macro_names
        ):
            metadata_used_variables.add(used_var)

    used_variables = set(used_variable_referenced_in_metadata_expression)
    non_metadata_used_variables = used_variables - metadata_used_variables

    metadata_variables = {
        k: v for k, v in (variables or {}).items() if k in metadata_used_variables
    }
    variables = {k: v for k, v in (variables or {}).items() if k in non_metadata_used_variables}

    if variables:
        python_env[c.SQLMESH_VARS] = Executable.value(variables)
    if metadata_variables:
        python_env[c.SQLMESH_VARS_METADATA] = Executable.value(metadata_variables, is_metadata=True)

    if blueprint_variables:
        metadata_blueprint_variables = {
            k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
            for k, v in blueprint_variables.items()
            if k in metadata_used_variables
        }
        blueprint_variables = {
            k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
            for k, v in blueprint_variables.items()
            if k in non_metadata_used_variables
        }
        if blueprint_variables:
            python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(blueprint_variables)
        if metadata_blueprint_variables:
            # Store the metadata subset here. `blueprint_variables` was rebound above to
            # the non-metadata subset, so using it would drop the metadata values.
            python_env[c.SQLMESH_BLUEPRINT_VARS_METADATA] = Executable.value(
                metadata_blueprint_variables, is_metadata=True
            )

    return python_env
166244

@@ -171,7 +249,7 @@ def parse_dependencies(
171249
strict_resolution: bool = True,
172250
variables: t.Optional[t.Dict[str, t.Any]] = None,
173251
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
174-
) -> t.Tuple[t.Set[str], t.Set[str]]:
252+
) -> t.Tuple[t.Set[str], t.Dict[str, bool]]:
175253
"""
176254
Parses the source of a model function and finds upstream table dependencies
177255
and referenced variables based on calls to context / evaluator.
@@ -185,7 +263,8 @@ def parse_dependencies(
185263
blueprint_variables: The blueprint variables available to the python environment.
186264
187265
Returns:
188-
A tuple containing the set of upstream table dependencies and the set of referenced variables.
266+
A tuple containing the set of upstream table dependencies and a mapping of
267+
the referenced variables associated with their metadata status.
189268
"""
190269

191270
class VariableResolutionContext:
@@ -203,12 +282,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona
203282
local_env = dict.fromkeys(("context", "evaluator"), VariableResolutionContext)
204283

205284
depends_on = set()
206-
used_variables = set()
285+
used_variable_referenced_in_metadata_expression: t.Dict[str, bool] = {}
207286

208287
for executable in python_env.values():
209288
if not executable.is_definition:
210289
continue
290+
291+
is_metadata = executable.is_metadata
211292
for node in ast.walk(ast.parse(executable.payload)):
293+
next_variables = set()
294+
212295
if isinstance(node, ast.Call):
213296
func = node.func
214297
if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name):
@@ -239,26 +322,35 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
239322

240323
if func.value.id == "context" and func.attr in ("table", "resolve_table"):
241324
depends_on.add(get_first_arg("model_name"))
242-
elif func.value.id in ("context", "evaluator") and func.attr == c.VAR:
243-
used_variables.add(get_first_arg("var_name").lower())
325+
elif func.value.id in ("context", "evaluator") and func.attr in (
326+
c.VAR,
327+
c.BLUEPRINT_VAR,
328+
):
329+
next_variables.add(get_first_arg("var_name").lower())
244330
elif (
245331
isinstance(node, ast.Attribute)
246332
and isinstance(node.value, ast.Name)
247333
and node.value.id in ("context", "evaluator")
248334
and node.attr == c.GATEWAY
249335
):
250336
# Check whether the gateway attribute is referenced.
251-
used_variables.add(c.GATEWAY)
337+
next_variables.add(c.GATEWAY)
252338
elif isinstance(node, ast.FunctionDef) and node.name == entrypoint:
253-
used_variables.update(
339+
next_variables.update(
254340
[
255341
arg.arg
256342
for arg in [*node.args.args, *node.args.kwonlyargs]
257343
if arg.arg != "context"
258344
]
259345
)
260346

261-
return depends_on, used_variables
347+
for var_name in next_variables:
348+
used_variable_referenced_in_metadata_expression[var_name] = (
349+
used_variable_referenced_in_metadata_expression.get(var_name, True)
350+
and bool(is_metadata)
351+
)
352+
353+
return depends_on, used_variable_referenced_in_metadata_expression
262354

263355

264356
def validate_extra_and_required_fields(

sqlmesh/core/model/definition.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1775,11 +1775,15 @@ def render(
17751775
execution_time = to_datetime(execution_time or c.EPOCH)
17761776

17771777
variables = env.get(c.SQLMESH_VARS, {})
1778+
variables.update(env.get(c.SQLMESH_VARS_METADATA, {}))
17781779
variables.update(kwargs.pop("variables", {}))
17791780

17801781
blueprint_variables = {
17811782
k: d.parse_one(v.sql, dialect=self.dialect) if isinstance(v, SqlValue) else v
1782-
for k, v in env.get(c.SQLMESH_BLUEPRINT_VARS, {}).items()
1783+
for k, v in {
1784+
**env.get(c.SQLMESH_BLUEPRINT_VARS, {}),
1785+
**env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}),
1786+
}.items()
17831787
}
17841788
try:
17851789
kwargs = {

sqlmesh/core/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def _resolve_table(table: str | exp.Table) -> str:
231231

232232
if variables:
233233
macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables)
234+
macro_evaluator.locals.setdefault(c.SQLMESH_VARS_METADATA, {})
234235

235236
for definition in self._macro_definitions:
236237
try:

sqlmesh/utils/jinja.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def extract_macro_references_and_variables(
229229
)
230230

231231
for call_name, node in extract_call_names(jinja_str):
232-
if call_name[0] == c.VAR:
232+
if call_name[0] in (c.VAR, c.BLUEPRINT_VAR):
233233
assert isinstance(node, nodes.Call)
234234
args = [jinja_call_arg_name(arg) for arg in node.args]
235235
if args and args[0]:

0 commit comments

Comments
 (0)