Skip to content

Commit d1dc55c

Browse files
authored
feat: ability to find a model block (#4933)
1 parent 81a9c0e commit d1dc55c

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

sqlmesh/core/linter/helpers.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from sqlmesh.core.linter.rule import Position, Range
44
from sqlmesh.utils.pydantic import PydanticModel
5+
from sqlglot import tokenize, TokenType
56
import typing as t
67

78

@@ -113,3 +114,41 @@ def read_range_from_file(file: Path, text_range: Range) -> str:
113114
result.append(line[start_char:end_char])
114115

115116
return "".join(result)
117+
118+
119+
def get_range_of_model_block(
120+
sql: str,
121+
dialect: str,
122+
) -> t.Optional[Range]:
123+
"""
124+
Get the range of the model block in an SQL file.
125+
"""
126+
tokens = tokenize(sql, dialect=dialect)
127+
128+
# Find start of the model block
129+
start = next(
130+
(t for t in tokens if t.token_type is TokenType.VAR and t.text.upper() == "MODEL"),
131+
None,
132+
)
133+
end = next((t for t in tokens if t.token_type is TokenType.SEMICOLON), None)
134+
135+
if start is None or end is None:
136+
return None
137+
138+
start_position = TokenPositionDetails(
139+
line=start.line,
140+
col=start.col,
141+
start=start.start,
142+
end=start.end,
143+
)
144+
end_position = TokenPositionDetails(
145+
line=end.line,
146+
col=end.col,
147+
start=end.start,
148+
end=end.end,
149+
)
150+
151+
splitlines = sql.splitlines()
152+
return Range(
153+
start=start_position.to_range(splitlines).start, end=end_position.to_range(splitlines).end
154+
)

sqlmesh/core/linter/rules/builtin.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlglot.expressions import Star
88
from sqlglot.helper import subclasses
99

10-
from sqlmesh.core.linter.helpers import TokenPositionDetails
10+
from sqlmesh.core.linter.helpers import TokenPositionDetails, get_range_of_model_block
1111
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit
1212
from sqlmesh.core.linter.definition import RuleSet
1313
from sqlmesh.core.model import Model, SqlModel
@@ -93,7 +93,21 @@ class NoMissingAudits(Rule):
9393
"""Model `audits` must be configured to test data quality."""
9494

9595
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
96-
return self.violation() if not model.audits and not model.kind.is_symbolic else None
96+
if model.audits or model.kind.is_symbolic:
97+
return None
98+
if model._path is None or not str(model._path).endswith(".sql"):
99+
return self.violation()
100+
101+
try:
102+
with open(model._path, "r", encoding="utf-8") as file:
103+
content = file.read()
104+
105+
range = get_range_of_model_block(content, model.dialect)
106+
if range:
107+
return self.violation(violation_range=range)
108+
return self.violation()
109+
except Exception:
110+
return self.violation()
97111

98112

99113
BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))

tests/core/linter/test_helpers.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from sqlmesh import Context
2+
from sqlmesh.core.linter.helpers import read_range_from_file, get_range_of_model_block
3+
from sqlmesh.core.model import SqlModel
4+
5+
6+
def test_get_position_of_model_block():
7+
context = Context(paths=["examples/sushi"])
8+
9+
sql_models = [
10+
model
11+
for model in context.models.values()
12+
if isinstance(model, SqlModel)
13+
and model._path is not None
14+
and str(model._path).endswith(".sql")
15+
]
16+
assert len(sql_models) > 0
17+
18+
for model in sql_models:
19+
dialect = model.dialect
20+
assert dialect is not None
21+
22+
path = model._path
23+
assert path is not None
24+
25+
with open(path, "r", encoding="utf-8") as file:
26+
content = file.read()
27+
28+
as_lines = content.splitlines()
29+
30+
range = get_range_of_model_block(content, dialect)
31+
assert range is not None
32+
33+
# Check that the range starts with MODEL and ends with ;
34+
read_range = read_range_from_file(path, range)
35+
assert read_range.startswith("MODEL")
36+
assert read_range.endswith(";")

0 commit comments

Comments
 (0)