Skip to content

Commit 04d199b

Browse files
authored
Fix!: Inconsistent behaviour when a macro returns a list containing a single array vs multiple arrays (#5037)
1 parent 8f2947b commit 04d199b

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

sqlmesh/core/macros.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,37 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] |
361361
return None
362362

363363
if isinstance(result, (tuple, list)):
364-
return [self.parse_one(item) for item in result if item is not None]
365-
return self.parse_one(result)
364+
result = [self.parse_one(item) for item in result if item is not None]
365+
366+
if (
367+
len(result) == 1
368+
and isinstance(result[0], (exp.Array, exp.Tuple))
369+
and node.find_ancestor(MacroFunc)
370+
):
371+
"""
372+
if:
373+
- the output of evaluating this node is being passed as an argument to another macro function
374+
- and that output is something that _norm_var_arg_lambda() will unpack into varargs
375+
> (a list containing a single item of type exp.Tuple/exp.Array)
376+
then we will get inconsistent behaviour depending on if this node emits a list with a single item vs multiple items.
377+
378+
In the first case, emitting a list containing a single array item will cause that array to get unpacked and its *members* passed to the calling macro
379+
In the second case, emitting a list containing multiple array items will cause each item to get passed as-is to the calling macro
380+
381+
To prevent this inconsistency, we wrap this node output in an exp.Array so that _norm_var_arg_lambda() can "unpack" that into the
382+
actual argument we want to pass to the parent macro function
383+
384+
Note we only do this for evaluation results that get passed as an argument to another macro, because when the final
385+
result is given to something like SELECT, we still want that to be unpacked into a list of items like:
386+
- SELECT ARRAY(1), ARRAY(2)
387+
rather than a single item like:
388+
- SELECT ARRAY(ARRAY(1), ARRAY(2))
389+
"""
390+
result = [exp.Array(expressions=result)]
391+
else:
392+
result = self.parse_one(result)
393+
394+
return result
366395

367396
def eval_expression(self, node: t.Any) -> t.Any:
368397
"""Converts a SQLGlot expression into executable Python code and evals it.

tests/core/test_macros.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,24 @@ def test_ast_correctness(macro_evaluator):
363363
"SELECT column LIKE a OR column LIKE b OR column LIKE c",
364364
{},
365365
),
366+
("SELECT @REDUCE([1], (x, y) -> x + y)", "SELECT 1", {}),
367+
("SELECT @REDUCE([1, 2], (x, y) -> x + y)", "SELECT 1 + 2", {}),
368+
("SELECT @REDUCE([[1]], (x, y) -> x + y)", "SELECT ARRAY(1)", {}),
369+
("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", "SELECT ARRAY(1, 2)", {}),
366370
(
367371
"""select @EACH([a, b, c], x -> column like x AS @SQL('@{x}_y', 'Identifier')), @x""",
368372
"SELECT column LIKE a AS a_y, column LIKE b AS b_y, column LIKE c AS c_y, '3'",
369373
{"x": "3"},
370374
),
375+
("SELECT @EACH([1], a -> [@a])", "SELECT ARRAY(1)", {}),
376+
("SELECT @EACH([1, 2], a -> [@a])", "SELECT ARRAY(1), ARRAY(2)", {}),
377+
("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1)", {}),
378+
(
379+
"SELECT @REDUCE(@EACH([1, 2], a -> [@a]), (x, y) -> x + y)",
380+
"SELECT ARRAY(1) + ARRAY(2)",
381+
{},
382+
),
383+
("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}),
371384
(
372385
"""@WITH(@do_with) all_cities as (select * from city) select all_cities""",
373386
"WITH all_cities AS (SELECT * FROM city) SELECT all_cities",

0 commit comments

Comments
 (0)