Skip to content

Commit 4fa9992

Browse files
authored
Feat: correctly handle the generation of VALUES expressions using macros (#4975)
1 parent 9b26320 commit 4fa9992

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

sqlmesh/core/dialect.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,20 @@ def _parse_limit(
419419
return macro
420420

421421

422+
def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]:
423+
wrapped = self._match(TokenType.L_PAREN, advance=False)
424+
425+
# The base _parse_value method always constructs a Tuple instance. This is problematic when
426+
# generating values with a macro function, because it's impossible to tell whether the user's
427+
# intention was to construct a row or a column with the VALUES expression. To avoid this, we
428+
# amend the AST such that the Tuple is replaced by the macro function call itself.
429+
expr = self.__parse_value() # type: ignore
430+
if expr and not wrapped and isinstance(seq_get(expr.expressions, 0), MacroFunc):
431+
return expr.expressions[0]
432+
433+
return expr
434+
435+
422436
def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]:
423437
return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser()
424438

@@ -1063,6 +1077,7 @@ def extend_sqlglot() -> None:
10631077
_override(Parser, _parse_with)
10641078
_override(Parser, _parse_having)
10651079
_override(Parser, _parse_limit)
1080+
_override(Parser, _parse_value)
10661081
_override(Parser, _parse_lambda)
10671082
_override(Parser, _parse_types)
10681083
_override(Parser, _parse_if)

tests/core/test_macros.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,26 @@ def test_ast_correctness(macro_evaluator):
575575
"SELECT 3",
576576
{},
577577
),
578+
(
579+
"SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v)) ) AS v",
580+
"SELECT * FROM (VALUES (1), (2), (3)) AS v",
581+
{},
582+
),
583+
(
584+
"SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v))) ) AS v",
585+
"SELECT * FROM (VALUES ((1), (2), (3))) AS v",
586+
{},
587+
),
588+
(
589+
"SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v, @EVAL(@v + 1))) ) AS v",
590+
"SELECT * FROM (VALUES (1, 2), (2, 3), (3, 4)) AS v",
591+
{},
592+
),
593+
(
594+
"SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v, @EVAL(@v + 1)))) ) AS v",
595+
"SELECT * FROM (VALUES ((1, 2), (2, 3), (3, 4))) AS v",
596+
{},
597+
),
578598
],
579599
)
580600
def test_macro_functions(macro_evaluator: MacroEvaluator, assert_exp_eq, sql, expected, args):

0 commit comments

Comments
 (0)