Skip to content

Commit 0109a3a

Browse files
Feat(vscode): Add the Table Diff view in the extension (#4917)
1 parent 943e496 commit 0109a3a

35 files changed

+3264
-34
lines changed

sqlmesh/lsp/api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
CustomMethodRequestBaseClass,
1414
CustomMethodResponseBaseClass,
1515
)
16-
from web.server.models import LineageColumn, Model
16+
from web.server.models import LineageColumn, Model, TableDiff
1717

1818
API_FEATURE = "sqlmesh/api"
1919

@@ -25,7 +25,7 @@ class ApiRequest(CustomMethodRequestBaseClass):
2525
"""
2626

2727
requestId: str
28-
url: str
28+
endpoint: str
2929
method: t.Optional[str] = "GET"
3030
params: t.Optional[t.Dict[str, t.Any]] = None
3131
body: t.Optional[t.Dict[str, t.Any]] = None
@@ -74,3 +74,11 @@ class ApiResponseGetColumnLineage(BaseAPIResponse):
7474
"""
7575

7676
data: t.Dict[str, t.Dict[str, LineageColumn]]
77+
78+
79+
class ApiResponseGetTableDiff(BaseAPIResponse):
80+
"""
81+
Response from the SQLMesh API for the get_table_diff endpoint.
82+
"""
83+
84+
data: t.Optional[TableDiff]

sqlmesh/lsp/custom.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,17 @@ class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
158158
pass
159159

160160

161+
GET_ENVIRONMENTS_FEATURE = "sqlmesh/get_environments"
162+
163+
164+
class GetEnvironmentsRequest(CustomMethodRequestBaseClass):
165+
"""
166+
Request to get all environments in the current project.
167+
"""
168+
169+
pass
170+
171+
161172
class TestEntry(PydanticModel):
162173
"""
163174
An entry representing a test in the workspace.
@@ -194,3 +205,53 @@ class RunTestRequest(CustomMethodRequestBaseClass):
194205
class RunTestResponse(CustomMethodResponseBaseClass):
195206
success: bool
196207
error_message: t.Optional[str] = None
208+
209+
210+
class EnvironmentInfo(PydanticModel):
211+
"""
212+
Information about an environment.
213+
"""
214+
215+
name: str
216+
snapshots: t.List[str]
217+
start_at: str
218+
plan_id: str
219+
220+
221+
class GetEnvironmentsResponse(CustomMethodResponseBaseClass):
222+
"""
223+
Response containing all environments in the current project.
224+
"""
225+
226+
environments: t.Dict[str, EnvironmentInfo]
227+
pinned_environments: t.Set[str]
228+
default_target_environment: str
229+
230+
231+
GET_MODELS_FEATURE = "sqlmesh/get_models"
232+
233+
234+
class GetModelsRequest(CustomMethodRequestBaseClass):
235+
"""
236+
Request to get all models available for table diff.
237+
"""
238+
239+
pass
240+
241+
242+
class ModelInfo(PydanticModel):
243+
"""
244+
Information about a model for table diff.
245+
"""
246+
247+
name: str
248+
fqn: str
249+
description: t.Optional[str] = None
250+
251+
252+
class GetModelsResponse(CustomMethodResponseBaseClass):
253+
"""
254+
Response containing all models available for table diff.
255+
"""
256+
257+
models: t.List[ModelInfo]

sqlmesh/lsp/main.py

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
WorkspaceInlayHintRefreshRequest,
1515
)
1616
from pygls.server import LanguageServer
17+
from sqlglot import exp
1718
from sqlmesh._version import __version__
1819
from sqlmesh.core.context import Context
20+
from sqlmesh.utils.date import to_timestamp
1921
from sqlmesh.lsp.api import (
2022
API_FEATURE,
2123
ApiRequest,
2224
ApiResponseGetColumnLineage,
2325
ApiResponseGetLineage,
2426
ApiResponseGetModels,
27+
ApiResponseGetTableDiff,
2528
)
2629

2730
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
@@ -36,6 +39,8 @@
3639
RENDER_MODEL_FEATURE,
3740
SUPPORTED_METHODS_FEATURE,
3841
FORMAT_PROJECT_FEATURE,
42+
GET_ENVIRONMENTS_FEATURE,
43+
GET_MODELS_FEATURE,
3944
AllModelsRequest,
4045
AllModelsResponse,
4146
AllModelsForRenderRequest,
@@ -57,6 +62,12 @@
5762
RUN_TEST_FEATURE,
5863
RunTestRequest,
5964
RunTestResponse,
65+
GetEnvironmentsRequest,
66+
GetEnvironmentsResponse,
67+
EnvironmentInfo,
68+
GetModelsRequest,
69+
GetModelsResponse,
70+
ModelInfo,
6071
)
6172
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
6273
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
@@ -74,9 +85,12 @@
7485
from sqlmesh.utils.pydantic import PydanticModel
7586
from web.server.api.endpoints.lineage import column_lineage, model_lineage
7687
from web.server.api.endpoints.models import get_models
88+
from web.server.api.endpoints.table_diff import _process_sample_data
7789
from typing import Union
7890
from dataclasses import dataclass, field
7991

92+
from web.server.models import RowDiff, SchemaDiff, TableDiff
93+
8094

8195
class InitializationOptions(PydanticModel):
8296
"""Initialization options for the SQLMesh Language Server, that
@@ -154,6 +168,8 @@ def __init__(
154168
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
155169
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
156170
RUN_TEST_FEATURE: self._run_test,
171+
GET_ENVIRONMENTS_FEATURE: self._custom_get_environments,
172+
GET_MODELS_FEATURE: self._custom_get_models,
157173
}
158174

159175
# Register LSP features (e.g., formatting, hover, etc.)
@@ -246,13 +262,71 @@ def _custom_format_project(
246262
ls.log_trace(f"Error formatting project: {e}")
247263
return FormatProjectResponse()
248264

265+
def _custom_get_environments(
266+
self, ls: LanguageServer, params: GetEnvironmentsRequest
267+
) -> GetEnvironmentsResponse:
268+
"""Get all environments in the current project."""
269+
try:
270+
context = self._context_get_or_load()
271+
environments = {}
272+
273+
# Get environments from state
274+
for env in context.context.state_reader.get_environments():
275+
environments[env.name] = EnvironmentInfo(
276+
name=env.name,
277+
snapshots=[s.fingerprint.to_identifier() for s in env.snapshots],
278+
start_at=str(to_timestamp(env.start_at)),
279+
plan_id=env.plan_id or "",
280+
)
281+
282+
return GetEnvironmentsResponse(
283+
environments=environments,
284+
pinned_environments=context.context.config.pinned_environments,
285+
default_target_environment=context.context.config.default_target_environment,
286+
)
287+
except Exception as e:
288+
ls.log_trace(f"Error getting environments: {e}")
289+
return GetEnvironmentsResponse(
290+
response_error=str(e),
291+
environments={},
292+
pinned_environments=set(),
293+
default_target_environment="",
294+
)
295+
296+
def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> GetModelsResponse:
297+
"""Get all models available for table diff."""
298+
try:
299+
context = self._context_get_or_load()
300+
models = [
301+
ModelInfo(
302+
name=model.name,
303+
fqn=model.fqn,
304+
description=model.description,
305+
)
306+
for model in context.context.models.values()
307+
# Filter for models that are suitable for table diff
308+
if model._path is not None # Has a file path
309+
]
310+
return GetModelsResponse(models=models)
311+
except Exception as e:
312+
ls.log_trace(f"Error getting table diff models: {e}")
313+
return GetModelsResponse(
314+
response_error=str(e),
315+
models=[],
316+
)
317+
249318
def _custom_api(
250319
self, ls: LanguageServer, request: ApiRequest
251-
) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]:
320+
) -> t.Union[
321+
ApiResponseGetModels,
322+
ApiResponseGetColumnLineage,
323+
ApiResponseGetLineage,
324+
ApiResponseGetTableDiff,
325+
]:
252326
ls.log_trace(f"API request: {request}")
253327
context = self._context_get_or_load()
254328

255-
parsed_url = urllib.parse.urlparse(request.url)
329+
parsed_url = urllib.parse.urlparse(request.endpoint)
256330
path_parts = parsed_url.path.strip("/").split("/")
257331

258332
if request.method == "GET":
@@ -280,7 +354,76 @@ def _custom_api(
280354
)
281355
return ApiResponseGetColumnLineage(data=column_lineage_response)
282356

283-
raise NotImplementedError(f"API request not implemented: {request.url}")
357+
if path_parts[:2] == ["api", "table_diff"]:
358+
import numpy as np
359+
360+
# /api/table_diff
361+
params = request.params
362+
table_diff_result: t.Optional[TableDiff] = None
363+
if params := request.params:
364+
source = getattr(params, "source", "") if params else ""
365+
target = getattr(params, "target", "") if params else ""
366+
on = getattr(params, "on", None) if params else None
367+
model_or_snapshot = (
368+
getattr(params, "model_or_snapshot", None) if params else None
369+
)
370+
where = getattr(params, "where", None) if params else None
371+
temp_schema = getattr(params, "temp_schema", None) if params else None
372+
limit = getattr(params, "limit", 20) if params else 20
373+
374+
table_diffs = context.context.table_diff(
375+
source=source,
376+
target=target,
377+
on=exp.condition(on) if on else None,
378+
select_models={model_or_snapshot} if model_or_snapshot else None,
379+
where=where,
380+
limit=limit,
381+
show=False,
382+
)
383+
384+
if table_diffs:
385+
diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs
386+
387+
_schema_diff = diff.schema_diff()
388+
_row_diff = diff.row_diff(temp_schema=temp_schema)
389+
schema_diff = SchemaDiff(
390+
source=_schema_diff.source,
391+
target=_schema_diff.target,
392+
source_schema=_schema_diff.source_schema,
393+
target_schema=_schema_diff.target_schema,
394+
added=_schema_diff.added,
395+
removed=_schema_diff.removed,
396+
modified=_schema_diff.modified,
397+
)
398+
399+
# create a readable column-centric sample data structure
400+
processed_sample_data = _process_sample_data(_row_diff, source, target)
401+
402+
row_diff = RowDiff(
403+
source=_row_diff.source,
404+
target=_row_diff.target,
405+
stats=_row_diff.stats,
406+
sample=_row_diff.sample.replace({np.nan: None}).to_dict(),
407+
joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(),
408+
s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(),
409+
t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(),
410+
column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(),
411+
source_count=_row_diff.source_count,
412+
target_count=_row_diff.target_count,
413+
count_pct_change=_row_diff.count_pct_change,
414+
decimals=getattr(_row_diff, "decimals", 3),
415+
processed_sample_data=processed_sample_data,
416+
)
417+
418+
s_index, t_index, _ = diff.key_columns
419+
table_diff_result = TableDiff(
420+
schema_diff=schema_diff,
421+
row_diff=row_diff,
422+
on=[(s.name, t.name) for s, t in zip(s_index, t_index)],
423+
)
424+
return ApiResponseGetTableDiff(data=table_diff_result)
425+
426+
raise NotImplementedError(f"API request not implemented: {request.endpoint}")
284427

285428
def _custom_supported_methods(
286429
self, ls: LanguageServer, params: SupportedMethodsRequest

vscode/bus/src/callbacks.ts

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,56 @@ export type RPCMethods = {
5151
}
5252
api_query: {
5353
params: {
54-
url: string
54+
endpoint: string
5555
method: string
5656
params: any
5757
body: any
5858
}
5959
result: any
6060
}
61+
get_selected_model: {
62+
params: {}
63+
result: {
64+
selectedModel?: any
65+
}
66+
}
67+
get_all_models: {
68+
params: {}
69+
result: {
70+
ok: boolean
71+
models?: any[]
72+
error?: string
73+
}
74+
}
75+
set_selected_model: {
76+
params: {
77+
model: any
78+
}
79+
result: {
80+
ok: boolean
81+
selectedModel?: any
82+
}
83+
}
84+
get_environments: {
85+
params: {}
86+
result: {
87+
ok: boolean
88+
environments?: Record<string, any>
89+
error?: string
90+
}
91+
}
92+
run_table_diff: {
93+
params: {
94+
sourceModel: string
95+
sourceEnvironment: string
96+
targetEnvironment: string
97+
}
98+
result: {
99+
ok: boolean
100+
data?: any
101+
error?: string
102+
}
103+
}
61104
} & RPCMethodsShape
62105

63106
export type RPCRequest = {
Lines changed: 3 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)