5
5
from pathlib import Path
6
6
7
7
from astor import to_source
8
+ from collections import defaultdict
8
9
from difflib import get_close_matches
9
10
from sqlglot import exp
10
11
from sqlglot .helper import ensure_list
28
29
from sqlmesh .utils import registry_decorator
29
30
from sqlmesh .utils .jinja import MacroReference
30
31
31
- MacroCallable = registry_decorator
32
+ MacroCallable = t . Union [ Executable , registry_decorator ]
32
33
33
34
34
35
def make_python_env (
@@ -48,13 +49,17 @@ def make_python_env(
48
49
dialect : DialectType = None ,
49
50
) -> t .Dict [str , Executable ]:
50
51
python_env = {} if python_env is None else python_env
51
- variables = variables or {}
52
52
env : t .Dict [str , t .Tuple [t .Any , t .Optional [bool ]]] = {}
53
- used_macros : t .Dict [
54
- str ,
55
- t .Tuple [t .Union [Executable | MacroCallable ], t .Optional [bool ]],
56
- ] = {}
57
- used_variables = (used_variables or set ()).copy ()
53
+
54
+ variables = variables or {}
55
+ blueprint_variables = blueprint_variables or {}
56
+
57
+ used_macros : t .Dict [str , t .Tuple [MacroCallable , t .Optional [bool ]]] = {}
58
+ used_variable_referenced_in_metadata_expression = dict .fromkeys (used_variables or set (), False )
59
+
60
+ # For an expression like @foo(@v1, @bar(@v1, @v2), @v3), the following mapping would be:
61
+ # v1 -> {"foo", "bar"}, v2 -> {"bar"}, v3 -> "foo"
62
+ macro_funcs_by_used_var : t .DefaultDict [str , t .Set [str ]] = defaultdict (set )
58
63
59
64
expressions = ensure_list (expressions )
60
65
for expression_metadata in expressions :
@@ -77,16 +82,27 @@ def make_python_env(
77
82
macros [name ],
78
83
used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
79
84
)
80
- if name == c .VAR :
85
+ if name in ( c .VAR , c . BLUEPRINT_VAR ) :
81
86
args = macro_func_or_var .this .expressions
82
87
if len (args ) < 1 :
83
- raise_config_error ("Macro VAR requires at least one argument" , path )
88
+ raise_config_error (
89
+ f"Macro { name .upper ()} requires at least one argument" , path
90
+ )
91
+
84
92
if not args [0 ].is_string :
85
93
raise_config_error (
86
94
f"The variable name must be a string literal, '{ args [0 ].sql ()} ' was given instead" ,
87
95
path ,
88
96
)
89
- used_variables .add (args [0 ].this .lower ())
97
+
98
+ var_name = args [0 ].this .lower ()
99
+ used_variable_referenced_in_metadata_expression [var_name ] = (
100
+ used_variable_referenced_in_metadata_expression .get (var_name , True )
101
+ and bool (is_metadata )
102
+ )
103
+ else :
104
+ for var_ref in _extract_macro_func_variable_references (macro_func_or_var ):
105
+ macro_funcs_by_used_var [var_ref ].add (name )
90
106
elif macro_func_or_var .__class__ is d .MacroVar :
91
107
name = macro_func_or_var .name .lower ()
92
108
if name in macros :
@@ -95,17 +111,23 @@ def make_python_env(
95
111
macros [name ],
96
112
used_macros .get (name , (None , is_metadata ))[1 ] and is_metadata ,
97
113
)
98
- elif name in variables :
99
- used_variables .add (name )
114
+ elif name in variables or name in blueprint_variables :
115
+ used_variable_referenced_in_metadata_expression [name ] = (
116
+ used_variable_referenced_in_metadata_expression .get (name , True )
117
+ and bool (is_metadata )
118
+ )
100
119
elif (
101
120
isinstance (macro_func_or_var , (exp .Identifier , d .MacroStrReplace , d .MacroSQL ))
102
121
) and "@" in macro_func_or_var .name :
103
122
for _ , identifier , braced_identifier , _ in MacroStrTemplate .pattern .findall (
104
123
macro_func_or_var .name
105
124
):
106
125
var_name = braced_identifier or identifier
107
- if var_name in variables :
108
- used_variables .add (var_name )
126
+ if var_name in variables or var_name in blueprint_variables :
127
+ used_variable_referenced_in_metadata_expression [var_name ] = (
128
+ used_variable_referenced_in_metadata_expression .get (var_name , True )
129
+ and bool (is_metadata )
130
+ )
109
131
110
132
for macro_ref in jinja_macro_references or set ():
111
133
if macro_ref .package is None and macro_ref .name in macros :
@@ -126,41 +148,97 @@ def make_python_env(
126
148
python_env .update (serialize_env (env , path = module_path ))
127
149
return _add_variables_to_python_env (
128
150
python_env ,
129
- used_variables ,
151
+ used_variable_referenced_in_metadata_expression ,
130
152
variables ,
131
153
blueprint_variables = blueprint_variables ,
132
154
dialect = dialect ,
133
155
strict_resolution = strict_resolution ,
156
+ macro_funcs_by_used_var = macro_funcs_by_used_var ,
134
157
)
135
158
136
159
160
+ def _extract_macro_func_variable_references (macro_func : exp .Expression ) -> t .Set [str ]:
161
+ references = set ()
162
+
163
+ for n in macro_func .walk ():
164
+ if n is macro_func :
165
+ continue
166
+
167
+ # Don't descend into nested MacroFunc nodes besides @VAR() and @BLUEPRINT_VAR(), because
168
+ # they will be handled in a separate call of _extract_macro_func_variable_references.
169
+ if isinstance (n , d .MacroFunc ):
170
+ this = n .this
171
+ if this .name .lower () in (c .VAR , c .BLUEPRINT_VAR ) and this .expressions :
172
+ references .add (this .expressions [0 ].this .lower ())
173
+ elif isinstance (n , d .MacroVar ):
174
+ references .add (n .name .lower ())
175
+ elif isinstance (n , (exp .Identifier , d .MacroStrReplace , d .MacroSQL )) and "@" in n .name :
176
+ references .update (
177
+ (braced_identifier or identifier ).lower ()
178
+ for _ , identifier , braced_identifier , _ in MacroStrTemplate .pattern .findall (n .name )
179
+ )
180
+
181
+ return references
182
+
183
+
137
184
def _add_variables_to_python_env (
138
185
python_env : t .Dict [str , Executable ],
139
- used_variables : t .Optional [ t . Set [ str ] ],
186
+ used_variable_referenced_in_metadata_expression : t .Dict [ str , bool ],
140
187
variables : t .Optional [t .Dict [str , t .Any ]],
141
188
strict_resolution : bool = True ,
142
189
blueprint_variables : t .Optional [t .Dict [str , t .Any ]] = None ,
143
190
dialect : DialectType = None ,
191
+ macro_funcs_by_used_var : t .Optional [t .DefaultDict [str , t .Set [str ]]] = None ,
144
192
) -> t .Dict [str , Executable ]:
145
- _ , python_used_variables = parse_dependencies (
193
+ _ , python_used_variable_referenced_in_metadata_expression = parse_dependencies (
146
194
python_env ,
147
195
None ,
148
196
strict_resolution = strict_resolution ,
149
197
variables = variables ,
150
198
blueprint_variables = blueprint_variables ,
151
199
)
152
- used_variables = (used_variables or set ()) | python_used_variables
200
+ for var_name , is_metadata in python_used_variable_referenced_in_metadata_expression .items ():
201
+ used_variable_referenced_in_metadata_expression [var_name ] = (
202
+ used_variable_referenced_in_metadata_expression .get (var_name , True ) and is_metadata
203
+ )
204
+
205
+ metadata_used_variables = set ()
206
+ for used_var , macro_names in (macro_funcs_by_used_var or {}).items ():
207
+ if used_variable_referenced_in_metadata_expression .get (used_var ) or all (
208
+ name in python_env and python_env [name ].is_metadata for name in macro_names
209
+ ):
210
+ metadata_used_variables .add (used_var )
211
+
212
+ used_variables = set (used_variable_referenced_in_metadata_expression )
213
+ non_metadata_used_variables = used_variables - metadata_used_variables
214
+
215
+ metadata_variables = {
216
+ k : v for k , v in (variables or {}).items () if k in metadata_used_variables
217
+ }
218
+ variables = {k : v for k , v in (variables or {}).items () if k in non_metadata_used_variables }
153
219
154
- variables = {k : v for k , v in (variables or {}).items () if k in used_variables }
155
220
if variables :
156
221
python_env [c .SQLMESH_VARS ] = Executable .value (variables )
222
+ if metadata_variables :
223
+ python_env [c .SQLMESH_VARS_METADATA ] = Executable .value (metadata_variables , is_metadata = True )
157
224
158
225
if blueprint_variables :
226
+ metadata_blueprint_variables = {
227
+ k : SqlValue (sql = v .sql (dialect = dialect )) if isinstance (v , exp .Expression ) else v
228
+ for k , v in blueprint_variables .items ()
229
+ if k in metadata_used_variables
230
+ }
159
231
blueprint_variables = {
160
232
k : SqlValue (sql = v .sql (dialect = dialect )) if isinstance (v , exp .Expression ) else v
161
233
for k , v in blueprint_variables .items ()
234
+ if k in non_metadata_used_variables
162
235
}
163
- python_env [c .SQLMESH_BLUEPRINT_VARS ] = Executable .value (blueprint_variables )
236
+ if blueprint_variables :
237
+ python_env [c .SQLMESH_BLUEPRINT_VARS ] = Executable .value (blueprint_variables )
238
+ if metadata_blueprint_variables :
239
+ python_env [c .SQLMESH_BLUEPRINT_VARS_METADATA ] = Executable .value (
240
+ blueprint_variables , is_metadata = True
241
+ )
164
242
165
243
return python_env
166
244
@@ -171,7 +249,7 @@ def parse_dependencies(
171
249
strict_resolution : bool = True ,
172
250
variables : t .Optional [t .Dict [str , t .Any ]] = None ,
173
251
blueprint_variables : t .Optional [t .Dict [str , t .Any ]] = None ,
174
- ) -> t .Tuple [t .Set [str ], t .Set [str ]]:
252
+ ) -> t .Tuple [t .Set [str ], t .Dict [str , bool ]]:
175
253
"""
176
254
Parses the source of a model function and finds upstream table dependencies
177
255
and referenced variables based on calls to context / evaluator.
@@ -185,7 +263,8 @@ def parse_dependencies(
185
263
blueprint_variables: The blueprint variables available to the python environment.
186
264
187
265
Returns:
188
- A tuple containing the set of upstream table dependencies and the set of referenced variables.
266
+ A tuple containing the set of upstream table dependencies and a mapping of
267
+ the referenced variables associated with their metadata status.
189
268
"""
190
269
191
270
class VariableResolutionContext :
@@ -203,12 +282,16 @@ def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optiona
203
282
local_env = dict .fromkeys (("context" , "evaluator" ), VariableResolutionContext )
204
283
205
284
depends_on = set ()
206
- used_variables = set ()
285
+ used_variable_referenced_in_metadata_expression : t . Dict [ str , bool ] = {}
207
286
208
287
for executable in python_env .values ():
209
288
if not executable .is_definition :
210
289
continue
290
+
291
+ is_metadata = executable .is_metadata
211
292
for node in ast .walk (ast .parse (executable .payload )):
293
+ next_variables = set ()
294
+
212
295
if isinstance (node , ast .Call ):
213
296
func = node .func
214
297
if not isinstance (func , ast .Attribute ) or not isinstance (func .value , ast .Name ):
@@ -239,26 +322,35 @@ def get_first_arg(keyword_arg_name: str) -> t.Any:
239
322
240
323
if func .value .id == "context" and func .attr in ("table" , "resolve_table" ):
241
324
depends_on .add (get_first_arg ("model_name" ))
242
- elif func .value .id in ("context" , "evaluator" ) and func .attr == c .VAR :
243
- used_variables .add (get_first_arg ("var_name" ).lower ())
325
+ elif func .value .id in ("context" , "evaluator" ) and func .attr in (
326
+ c .VAR ,
327
+ c .BLUEPRINT_VAR ,
328
+ ):
329
+ next_variables .add (get_first_arg ("var_name" ).lower ())
244
330
elif (
245
331
isinstance (node , ast .Attribute )
246
332
and isinstance (node .value , ast .Name )
247
333
and node .value .id in ("context" , "evaluator" )
248
334
and node .attr == c .GATEWAY
249
335
):
250
336
# Check whether the gateway attribute is referenced.
251
- used_variables .add (c .GATEWAY )
337
+ next_variables .add (c .GATEWAY )
252
338
elif isinstance (node , ast .FunctionDef ) and node .name == entrypoint :
253
- used_variables .update (
339
+ next_variables .update (
254
340
[
255
341
arg .arg
256
342
for arg in [* node .args .args , * node .args .kwonlyargs ]
257
343
if arg .arg != "context"
258
344
]
259
345
)
260
346
261
- return depends_on , used_variables
347
+ for var_name in next_variables :
348
+ used_variable_referenced_in_metadata_expression [var_name ] = (
349
+ used_variable_referenced_in_metadata_expression .get (var_name , True )
350
+ and bool (is_metadata )
351
+ )
352
+
353
+ return depends_on , used_variable_referenced_in_metadata_expression
262
354
263
355
264
356
def validate_extra_and_required_fields (
0 commit comments