Skip to content

Commit 784f84f

Browse files
Feat(vscode): Add the Table Diff View in the extension
1 parent e4ea4c8 commit 784f84f

35 files changed

+3249
-33
lines changed

sqlmesh/lsp/api.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
CustomMethodRequestBaseClass,
1414
CustomMethodResponseBaseClass,
1515
)
16-
from web.server.models import LineageColumn, Model
16+
from web.server.models import LineageColumn, Model, TableDiff
1717

1818
API_FEATURE = "sqlmesh/api"
1919

@@ -25,7 +25,7 @@ class ApiRequest(CustomMethodRequestBaseClass):
2525
"""
2626

2727
requestId: str
28-
url: str
28+
endpoint: str
2929
method: t.Optional[str] = "GET"
3030
params: t.Optional[t.Dict[str, t.Any]] = None
3131
body: t.Optional[t.Dict[str, t.Any]] = None
@@ -74,3 +74,11 @@ class ApiResponseGetColumnLineage(BaseAPIResponse):
7474
"""
7575

7676
data: t.Dict[str, t.Dict[str, LineageColumn]]
77+
78+
79+
class ApiResponseGetTableDiff(BaseAPIResponse):
    """API response wrapper for the SQLMesh get_table_diff endpoint."""

    # Declared Optional, so the payload may be None instead of a TableDiff.
    data: t.Optional[TableDiff]

sqlmesh/lsp/custom.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,18 @@ class FormatProjectResponse(CustomMethodResponseBaseClass):
149149

150150
LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests"
151151

152-
153152
class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
    """Parameterless request asking for every test in the current project."""
157+
158+
GET_ENVIRONMENTS_FEATURE = "sqlmesh/get_environments"
159+
160+
class GetEnvironmentsRequest(CustomMethodRequestBaseClass):
    """Parameterless request asking for every environment in the current project."""
159166

@@ -194,3 +201,52 @@ class RunTestRequest(CustomMethodRequestBaseClass):
194201
class RunTestResponse(CustomMethodResponseBaseClass):
    """Result of a run-test request: a success flag and an optional error message."""

    success: bool
    # Defaults to None when no message is supplied.
    error_message: t.Optional[str] = None
204+
205+
class EnvironmentInfo(PydanticModel):
    """Description of one environment, with every value carried as a string."""

    # Environment name.
    name: str
    # One string entry per snapshot in the environment.
    snapshots: t.List[str]
    # Start of the environment, stored as a string.
    start_at: str
    # Plan identifier, stored as a string.
    plan_id: str
214+
215+
216+
class GetEnvironmentsResponse(CustomMethodResponseBaseClass):
    """Answer to a get-environments request for the current project."""

    # Environment details keyed by a string (presumably the environment name).
    environments: t.Dict[str, EnvironmentInfo]
    # Names of the environments that are pinned.
    pinned_environments: t.Set[str]
    # Name of the default target environment.
    default_target_environment: str
224+
225+
226+
GET_MODELS_FEATURE = "sqlmesh/get_models"
227+
228+
229+
class GetModelsRequest(CustomMethodRequestBaseClass):
    """Parameterless request asking for the models available for table diff."""
235+
236+
237+
class ModelInfo(PydanticModel):
    """Identifying details of a model offered for table diff."""

    # Model name.
    name: str
    # The model's fqn (presumably its fully qualified name).
    fqn: str
    # Optional description; None when not supplied.
    description: t.Optional[str] = None
245+
246+
247+
class GetModelsResponse(CustomMethodResponseBaseClass):
    """Answer to a get-models request: the models available for table diff."""

    # One entry per available model.
    models: t.List[ModelInfo]

sqlmesh/lsp/main.py

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
WorkspaceInlayHintRefreshRequest,
1515
)
1616
from pygls.server import LanguageServer
17+
from sqlglot import exp
1718
from sqlmesh._version import __version__
1819
from sqlmesh.core.context import Context
20+
from sqlmesh.utils.date import to_timestamp
1921
from sqlmesh.lsp.api import (
2022
API_FEATURE,
2123
ApiRequest,
2224
ApiResponseGetColumnLineage,
2325
ApiResponseGetLineage,
2426
ApiResponseGetModels,
27+
ApiResponseGetTableDiff,
2528
)
2629

2730
from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
@@ -36,6 +39,8 @@
3639
RENDER_MODEL_FEATURE,
3740
SUPPORTED_METHODS_FEATURE,
3841
FORMAT_PROJECT_FEATURE,
42+
GET_ENVIRONMENTS_FEATURE,
43+
GET_MODELS_FEATURE,
3944
AllModelsRequest,
4045
AllModelsResponse,
4146
AllModelsForRenderRequest,
@@ -57,6 +62,12 @@
5762
RUN_TEST_FEATURE,
5863
RunTestRequest,
5964
RunTestResponse,
65+
GetEnvironmentsRequest,
66+
GetEnvironmentsResponse,
67+
EnvironmentInfo,
68+
GetModelsRequest,
69+
GetModelsResponse,
70+
ModelInfo,
6071
)
6172
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
6273
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
@@ -73,9 +84,12 @@
7384
from sqlmesh.utils.lineage import ExternalModelReference
7485
from web.server.api.endpoints.lineage import column_lineage, model_lineage
7586
from web.server.api.endpoints.models import get_models
87+
from web.server.api.endpoints.table_diff import _process_sample_data
7688
from typing import Union
7789
from dataclasses import dataclass, field
7890

91+
from web.server.models import RowDiff, SchemaDiff, TableDiff
92+
7993

8094
@dataclass
8195
class NoContext:
@@ -141,6 +155,8 @@ def __init__(
141155
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
142156
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
143157
RUN_TEST_FEATURE: self._run_test,
158+
GET_ENVIRONMENTS_FEATURE: self._custom_get_environments,
159+
GET_MODELS_FEATURE: self._custom_get_models,
144160
}
145161

146162
# Register LSP features (e.g., formatting, hover, etc.)
@@ -233,13 +249,71 @@ def _custom_format_project(
233249
ls.log_trace(f"Error formatting project: {e}")
234250
return FormatProjectResponse()
235251

252+
def _custom_get_environments(
    self, ls: LanguageServer, params: GetEnvironmentsRequest
) -> GetEnvironmentsResponse:
    """Return every environment found in state, plus environment-related config.

    Any exception is logged via ``ls.log_trace`` and converted into an empty
    response whose ``response_error`` holds the exception text.
    """
    try:
        loaded = self._context_get_or_load()

        # Build one EnvironmentInfo per environment reported by the state
        # reader, keyed by the environment's name.
        env_infos = {
            environment.name: EnvironmentInfo(
                name=environment.name,
                snapshots=[
                    snapshot.fingerprint.to_identifier() for snapshot in environment.snapshots
                ],
                start_at=str(to_timestamp(environment.start_at)),
                # A missing plan id is reported as an empty string.
                plan_id=environment.plan_id or "",
            )
            for environment in loaded.context.state_reader.get_environments()
        }

        config = loaded.context.config
        return GetEnvironmentsResponse(
            environments=env_infos,
            pinned_environments=config.pinned_environments,
            default_target_environment=config.default_target_environment,
        )
    except Exception as error:
        ls.log_trace(f"Error getting environments: {error}")
        return GetEnvironmentsResponse(
            response_error=str(error),
            environments={},
            pinned_environments=set(),
            default_target_environment="",
        )
282+
283+
def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> GetModelsResponse:
    """Return the models that can be offered for table diff.

    Any exception is logged via ``ls.log_trace`` and converted into a response
    with an empty model list and ``response_error`` set to the exception text.
    """
    try:
        loaded = self._context_get_or_load()
        available = []
        for candidate in loaded.context.models.values():
            # Only models that have a file path are suitable for table diff.
            if candidate._path is None:
                continue
            available.append(
                ModelInfo(
                    name=candidate.name,
                    fqn=candidate.fqn,
                    description=candidate.description,
                )
            )
        return GetModelsResponse(models=available)
    except Exception as error:
        ls.log_trace(f"Error getting table diff models: {error}")
        return GetModelsResponse(
            response_error=str(error),
            models=[],
        )
304+
236305
def _custom_api(
237306
self, ls: LanguageServer, request: ApiRequest
238-
) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]:
307+
) -> t.Union[
308+
ApiResponseGetModels,
309+
ApiResponseGetColumnLineage,
310+
ApiResponseGetLineage,
311+
ApiResponseGetTableDiff,
312+
]:
239313
ls.log_trace(f"API request: {request}")
240314
context = self._context_get_or_load()
241315

242-
parsed_url = urllib.parse.urlparse(request.url)
316+
parsed_url = urllib.parse.urlparse(request.endpoint)
243317
path_parts = parsed_url.path.strip("/").split("/")
244318

245319
if request.method == "GET":
@@ -267,7 +341,76 @@ def _custom_api(
267341
)
268342
return ApiResponseGetColumnLineage(data=column_lineage_response)
269343

270-
raise NotImplementedError(f"API request not implemented: {request.url}")
344+
if path_parts[:2] == ["api", "table_diff"]:
345+
import numpy as np
346+
347+
# /api/table_diff
348+
params = request.params
349+
table_diff_result: t.Optional[TableDiff] = None
350+
if params := request.params:
351+
source = getattr(params, "source", "") if params else ""
352+
target = getattr(params, "target", "") if params else ""
353+
on = getattr(params, "on", None) if params else None
354+
model_or_snapshot = (
355+
getattr(params, "model_or_snapshot", None) if params else None
356+
)
357+
where = getattr(params, "where", None) if params else None
358+
temp_schema = getattr(params, "temp_schema", None) if params else None
359+
limit = getattr(params, "limit", 20) if params else 20
360+
361+
table_diffs = context.context.table_diff(
362+
source=source,
363+
target=target,
364+
on=exp.condition(on) if on else None,
365+
select_models={model_or_snapshot} if model_or_snapshot else None,
366+
where=where,
367+
limit=limit,
368+
show=False,
369+
)
370+
371+
if table_diffs:
372+
diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs
373+
374+
_schema_diff = diff.schema_diff()
375+
_row_diff = diff.row_diff(temp_schema=temp_schema)
376+
schema_diff = SchemaDiff(
377+
source=_schema_diff.source,
378+
target=_schema_diff.target,
379+
source_schema=_schema_diff.source_schema,
380+
target_schema=_schema_diff.target_schema,
381+
added=_schema_diff.added,
382+
removed=_schema_diff.removed,
383+
modified=_schema_diff.modified,
384+
)
385+
386+
# create a readable column-centric sample data structure
387+
processed_sample_data = _process_sample_data(_row_diff, source, target)
388+
389+
row_diff = RowDiff(
390+
source=_row_diff.source,
391+
target=_row_diff.target,
392+
stats=_row_diff.stats,
393+
sample=_row_diff.sample.replace({np.nan: None}).to_dict(),
394+
joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(),
395+
s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(),
396+
t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(),
397+
column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(),
398+
source_count=_row_diff.source_count,
399+
target_count=_row_diff.target_count,
400+
count_pct_change=_row_diff.count_pct_change,
401+
decimals=getattr(_row_diff, "decimals", 3),
402+
processed_sample_data=processed_sample_data,
403+
)
404+
405+
s_index, t_index, _ = diff.key_columns
406+
table_diff_result = TableDiff(
407+
schema_diff=schema_diff,
408+
row_diff=row_diff,
409+
on=[(s.name, t.name) for s, t in zip(s_index, t_index)],
410+
)
411+
return ApiResponseGetTableDiff(data=table_diff_result)
412+
413+
raise NotImplementedError(f"API request not implemented: {request.endpoint}")
271414

272415
def _custom_supported_methods(
273416
self, ls: LanguageServer, params: SupportedMethodsRequest

vscode/bus/src/callbacks.ts

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,56 @@ export type RPCMethods = {
5151
}
5252
api_query: {
5353
params: {
54-
url: string
54+
endpoint: string
5555
method: string
5656
params: any
5757
body: any
5858
}
5959
result: any
6060
}
61+
get_selected_model: {
62+
params: {}
63+
result: {
64+
selectedModel?: any
65+
}
66+
}
67+
get_all_models: {
68+
params: {}
69+
result: {
70+
ok: boolean
71+
models?: any[]
72+
error?: string
73+
}
74+
}
75+
set_selected_model: {
76+
params: {
77+
model: any
78+
}
79+
result: {
80+
ok: boolean
81+
selectedModel?: any
82+
}
83+
}
84+
get_environments: {
85+
params: {}
86+
result: {
87+
ok: boolean
88+
environments?: Record<string, any>
89+
error?: string
90+
}
91+
}
92+
run_table_diff: {
93+
params: {
94+
sourceModel: string
95+
sourceEnvironment: string
96+
targetEnvironment: string
97+
}
98+
result: {
99+
ok: boolean
100+
data?: any
101+
error?: string
102+
}
103+
}
61104
} & RPCMethodsShape
62105

63106
export type RPCRequest = {
Lines changed: 3 additions & 0 deletions
Loading

0 commit comments

Comments
 (0)