Skip to content

Commit 3bbb819

Browse files
authored
Fix!: normalize blueprint variables (#5045)
1 parent b448d1c commit 3bbb819

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

sqlmesh/core/model/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _add_variables_to_python_env(
157157

158158
if blueprint_variables:
159159
blueprint_variables = {
160-
k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
160+
k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v
161161
for k, v in blueprint_variables.items()
162162
}
163163
python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value(
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""
2+
Normalizes blueprint variables, so Customer_Field is stored as customer_field in the `python_env`:
3+
4+
MODEL (
5+
...
6+
blueprints (
7+
Customer_Field := 1
8+
)
9+
);
10+
11+
SELECT
12+
@customer_field AS col
13+
"""
14+
15+
import json
16+
import logging
17+
from dataclasses import dataclass
18+
19+
from sqlglot import exp
20+
from sqlmesh.core.console import get_console
21+
from sqlmesh.utils.migration import index_text_type, blob_text_type
22+
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__"
28+
29+
30+
# Make sure `SqlValue` is defined so that it can be used by the `eval` call in the migration
31+
@dataclass
class SqlValue:
    """Wrapper around a SQL string that was generated from a SQLGlot AST.

    The dataclass-generated ``repr`` (``SqlValue(sql='...')``) is valid Python
    source for reconstructing the instance, provided the name ``SqlValue`` is
    in scope where that text is evaluated.
    """

    # The rendered SQL text.
    sql: str
36+
37+
38+
def migrate(state_sync, **kwargs):  # type: ignore
    """Lowercase the blueprint variable names stored in each snapshot's `python_env`.

    Every row of the `_snapshots` table (schema-qualified when `state_sync.schema`
    is set) is read and its `snapshot` JSON parsed. If the node's `python_env`
    has an entry under `SQLMESH_BLUEPRINT_VARS`, that entry's payload is
    evaluated and each key that is not already lowercase is renamed to its
    lowercase form. When the lowercase name is already present, the key is left
    untouched and a warning is logged instead.

    If at least one snapshot was changed, all rows are deleted from the table
    and every fetched row is re-inserted, with the `snapshot` column
    re-serialized from the parsed JSON.

    Args:
        state_sync: Object exposing `engine_adapter` and `schema`.
        **kwargs: Accepted but not used by this migration.
    """
    import pandas as pd

    engine_adapter = state_sync.engine_adapter
    schema = state_sync.schema
    snapshots_table = f"{schema}._snapshots" if schema else "_snapshots"

    query = exp.select(
        "name",
        "identifier",
        "version",
        "snapshot",
        "kind_name",
        "updated_ts",
        "unpaused_ts",
        "ttl_ms",
        "unrestorable",
    ).from_(snapshots_table)

    any_snapshot_changed = False
    rows_to_write = []

    for row in engine_adapter.fetchall(query, quote_identifiers=True):
        (
            name,
            identifier,
            version,
            snapshot,
            kind_name,
            updated_ts,
            unpaused_ts,
            ttl_ms,
            unrestorable,
        ) = row

        parsed_snapshot = json.loads(snapshot)
        node = parsed_snapshot["node"]
        python_env = node.get("python_env") or {}

        executable = python_env.get(SQLMESH_BLUEPRINT_VARS)
        if executable:
            # NOTE(review): `eval` executes whatever text is stored in the payload; the
            # payload may reference `SqlValue`, which is why this module defines it.
            blueprint_vars = eval(executable["payload"])
            renamed_any = False

            # Iterate over a copy, since the mapping is mutated while renaming.
            for var, value in list(blueprint_vars.items()):
                lowered = var.lower()
                if var == lowered:
                    continue

                if lowered in blueprint_vars:
                    # Renaming would overwrite an existing entry, so only warn.
                    get_console().log_warning(
                        "SQLMesh is unable to fully migrate the state database, because the "
                        f"model '{node['name']}' contains two blueprint variables ('{var}' and "
                        f"'{lowered}') that resolve to the same value ('{lowered}'). "
                        "This may result in unexpected changes being reported by the next "
                        "`sqlmesh plan` command. If this happens, consider renaming either variable, "
                        "so that the lowercase version of their names are different."
                    )
                    continue

                # Drop the old key and append the lowercase one at the end of the mapping.
                blueprint_vars.pop(var)
                blueprint_vars[lowered] = value
                renamed_any = True

            if renamed_any:
                any_snapshot_changed = True
                executable["payload"] = repr(blueprint_vars)

        # Every row is collected, because the table is emptied before re-inserting.
        rows_to_write.append(
            {
                "name": name,
                "identifier": identifier,
                "version": version,
                "snapshot": json.dumps(parsed_snapshot),
                "kind_name": kind_name,
                "updated_ts": updated_ts,
                "unpaused_ts": unpaused_ts,
                "ttl_ms": ttl_ms,
                "unrestorable": unrestorable,
            }
        )

    if not (any_snapshot_changed and rows_to_write):
        return

    engine_adapter.delete_from(snapshots_table, "TRUE")

    index_type = index_text_type(engine_adapter.dialect)
    blob_type = blob_text_type(engine_adapter.dialect)

    snapshots_df = pd.DataFrame(rows_to_write)
    column_type_names = (
        ("name", index_type),
        ("identifier", index_type),
        ("version", index_type),
        ("snapshot", blob_type),
        ("kind_name", "text"),
        ("updated_ts", "bigint"),
        ("unpaused_ts", "bigint"),
        ("ttl_ms", "bigint"),
        ("unrestorable", "boolean"),
    )
    engine_adapter.insert_append(
        snapshots_table,
        snapshots_df,
        columns_to_types={
            column: exp.DataType.build(type_name) for column, type_name in column_type_names
        },
    )

tests/core/test_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9396,14 +9396,14 @@ def entrypoint(evaluator):
93969396
MODEL (
93979397
name @{customer}.my_table,
93989398
blueprints (
9399-
(customer := customer1, customer_field := 'bar'),
9400-
(customer := customer2, customer_field := qux),
9399+
(customer := customer1, Customer_Field := 'bar'),
9400+
(customer := customer2, Customer_Field := qux),
94019401
),
94029402
kind FULL
94039403
);
94049404
94059405
SELECT
9406-
@customer_field AS foo,
9406+
@customer_FIELD AS foo,
94079407
@{customer_field} AS foo2,
94089408
@BLUEPRINT_VAR('customer_field') AS foo3,
94099409
FROM @{customer}.my_source

0 commit comments

Comments
 (0)