|
14 | 14 | WorkspaceInlayHintRefreshRequest,
|
15 | 15 | )
|
16 | 16 | from pygls.server import LanguageServer
|
| 17 | +from sqlglot import exp |
17 | 18 | from sqlmesh._version import __version__
|
18 | 19 | from sqlmesh.core.context import Context
|
| 20 | +from sqlmesh.utils.date import to_timestamp |
19 | 21 | from sqlmesh.lsp.api import (
|
20 | 22 | API_FEATURE,
|
21 | 23 | ApiRequest,
|
22 | 24 | ApiResponseGetColumnLineage,
|
23 | 25 | ApiResponseGetLineage,
|
24 | 26 | ApiResponseGetModels,
|
| 27 | + ApiResponseGetTableDiff, |
25 | 28 | )
|
26 | 29 |
|
27 | 30 | from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
|
|
36 | 39 | RENDER_MODEL_FEATURE,
|
37 | 40 | SUPPORTED_METHODS_FEATURE,
|
38 | 41 | FORMAT_PROJECT_FEATURE,
|
| 42 | + GET_ENVIRONMENTS_FEATURE, |
| 43 | + GET_MODELS_FEATURE, |
39 | 44 | AllModelsRequest,
|
40 | 45 | AllModelsResponse,
|
41 | 46 | AllModelsForRenderRequest,
|
|
57 | 62 | RUN_TEST_FEATURE,
|
58 | 63 | RunTestRequest,
|
59 | 64 | RunTestResponse,
|
| 65 | + GetEnvironmentsRequest, |
| 66 | + GetEnvironmentsResponse, |
| 67 | + EnvironmentInfo, |
| 68 | + GetModelsRequest, |
| 69 | + GetModelsResponse, |
| 70 | + ModelInfo, |
60 | 71 | )
|
61 | 72 | from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
|
62 | 73 | from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
|
|
74 | 85 | from sqlmesh.utils.pydantic import PydanticModel
|
75 | 86 | from web.server.api.endpoints.lineage import column_lineage, model_lineage
|
76 | 87 | from web.server.api.endpoints.models import get_models
|
| 88 | +from web.server.api.endpoints.table_diff import _process_sample_data |
77 | 89 | from typing import Union
|
78 | 90 | from dataclasses import dataclass, field
|
79 | 91 |
|
| 92 | +from web.server.models import RowDiff, SchemaDiff, TableDiff |
| 93 | + |
80 | 94 |
|
81 | 95 | class InitializationOptions(PydanticModel):
|
82 | 96 | """Initialization options for the SQLMesh Language Server, that
|
@@ -154,6 +168,8 @@ def __init__(
|
154 | 168 | LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
|
155 | 169 | LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
|
156 | 170 | RUN_TEST_FEATURE: self._run_test,
|
| 171 | + GET_ENVIRONMENTS_FEATURE: self._custom_get_environments, |
| 172 | + GET_MODELS_FEATURE: self._custom_get_models, |
157 | 173 | }
|
158 | 174 |
|
159 | 175 | # Register LSP features (e.g., formatting, hover, etc.)
|
@@ -246,13 +262,71 @@ def _custom_format_project(
|
246 | 262 | ls.log_trace(f"Error formatting project: {e}")
|
247 | 263 | return FormatProjectResponse()
|
248 | 264 |
|
| 265 | + def _custom_get_environments( |
| 266 | + self, ls: LanguageServer, params: GetEnvironmentsRequest |
| 267 | + ) -> GetEnvironmentsResponse: |
| 268 | + """Get all environments in the current project.""" |
| 269 | + try: |
| 270 | + context = self._context_get_or_load() |
| 271 | + environments = {} |
| 272 | + |
| 273 | + # Get environments from state |
| 274 | + for env in context.context.state_reader.get_environments(): |
| 275 | + environments[env.name] = EnvironmentInfo( |
| 276 | + name=env.name, |
| 277 | + snapshots=[s.fingerprint.to_identifier() for s in env.snapshots], |
| 278 | + start_at=str(to_timestamp(env.start_at)), |
| 279 | + plan_id=env.plan_id or "", |
| 280 | + ) |
| 281 | + |
| 282 | + return GetEnvironmentsResponse( |
| 283 | + environments=environments, |
| 284 | + pinned_environments=context.context.config.pinned_environments, |
| 285 | + default_target_environment=context.context.config.default_target_environment, |
| 286 | + ) |
| 287 | + except Exception as e: |
| 288 | + ls.log_trace(f"Error getting environments: {e}") |
| 289 | + return GetEnvironmentsResponse( |
| 290 | + response_error=str(e), |
| 291 | + environments={}, |
| 292 | + pinned_environments=set(), |
| 293 | + default_target_environment="", |
| 294 | + ) |
| 295 | + |
| 296 | + def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> GetModelsResponse: |
| 297 | + """Get all models available for table diff.""" |
| 298 | + try: |
| 299 | + context = self._context_get_or_load() |
| 300 | + models = [ |
| 301 | + ModelInfo( |
| 302 | + name=model.name, |
| 303 | + fqn=model.fqn, |
| 304 | + description=model.description, |
| 305 | + ) |
| 306 | + for model in context.context.models.values() |
| 307 | + # Filter for models that are suitable for table diff |
| 308 | + if model._path is not None # Has a file path |
| 309 | + ] |
| 310 | + return GetModelsResponse(models=models) |
| 311 | + except Exception as e: |
| 312 | + ls.log_trace(f"Error getting table diff models: {e}") |
| 313 | + return GetModelsResponse( |
| 314 | + response_error=str(e), |
| 315 | + models=[], |
| 316 | + ) |
| 317 | + |
249 | 318 | def _custom_api(
|
250 | 319 | self, ls: LanguageServer, request: ApiRequest
|
251 |
| - ) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]: |
| 320 | + ) -> t.Union[ |
| 321 | + ApiResponseGetModels, |
| 322 | + ApiResponseGetColumnLineage, |
| 323 | + ApiResponseGetLineage, |
| 324 | + ApiResponseGetTableDiff, |
| 325 | + ]: |
252 | 326 | ls.log_trace(f"API request: {request}")
|
253 | 327 | context = self._context_get_or_load()
|
254 | 328 |
|
255 |
| - parsed_url = urllib.parse.urlparse(request.url) |
| 329 | + parsed_url = urllib.parse.urlparse(request.endpoint) |
256 | 330 | path_parts = parsed_url.path.strip("/").split("/")
|
257 | 331 |
|
258 | 332 | if request.method == "GET":
|
@@ -280,7 +354,76 @@ def _custom_api(
|
280 | 354 | )
|
281 | 355 | return ApiResponseGetColumnLineage(data=column_lineage_response)
|
282 | 356 |
|
283 |
| - raise NotImplementedError(f"API request not implemented: {request.url}") |
| 357 | + if path_parts[:2] == ["api", "table_diff"]: |
| 358 | + import numpy as np |
| 359 | + |
| 360 | + # /api/table_diff |
| 361 | + params = request.params |
| 362 | + table_diff_result: t.Optional[TableDiff] = None |
| 363 | + if params := request.params: |
| 364 | + source = getattr(params, "source", "") if params else "" |
| 365 | + target = getattr(params, "target", "") if params else "" |
| 366 | + on = getattr(params, "on", None) if params else None |
| 367 | + model_or_snapshot = ( |
| 368 | + getattr(params, "model_or_snapshot", None) if params else None |
| 369 | + ) |
| 370 | + where = getattr(params, "where", None) if params else None |
| 371 | + temp_schema = getattr(params, "temp_schema", None) if params else None |
| 372 | + limit = getattr(params, "limit", 20) if params else 20 |
| 373 | + |
| 374 | + table_diffs = context.context.table_diff( |
| 375 | + source=source, |
| 376 | + target=target, |
| 377 | + on=exp.condition(on) if on else None, |
| 378 | + select_models={model_or_snapshot} if model_or_snapshot else None, |
| 379 | + where=where, |
| 380 | + limit=limit, |
| 381 | + show=False, |
| 382 | + ) |
| 383 | + |
| 384 | + if table_diffs: |
| 385 | + diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs |
| 386 | + |
| 387 | + _schema_diff = diff.schema_diff() |
| 388 | + _row_diff = diff.row_diff(temp_schema=temp_schema) |
| 389 | + schema_diff = SchemaDiff( |
| 390 | + source=_schema_diff.source, |
| 391 | + target=_schema_diff.target, |
| 392 | + source_schema=_schema_diff.source_schema, |
| 393 | + target_schema=_schema_diff.target_schema, |
| 394 | + added=_schema_diff.added, |
| 395 | + removed=_schema_diff.removed, |
| 396 | + modified=_schema_diff.modified, |
| 397 | + ) |
| 398 | + |
| 399 | + # create a readable column-centric sample data structure |
| 400 | + processed_sample_data = _process_sample_data(_row_diff, source, target) |
| 401 | + |
| 402 | + row_diff = RowDiff( |
| 403 | + source=_row_diff.source, |
| 404 | + target=_row_diff.target, |
| 405 | + stats=_row_diff.stats, |
| 406 | + sample=_row_diff.sample.replace({np.nan: None}).to_dict(), |
| 407 | + joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(), |
| 408 | + s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(), |
| 409 | + t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(), |
| 410 | + column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(), |
| 411 | + source_count=_row_diff.source_count, |
| 412 | + target_count=_row_diff.target_count, |
| 413 | + count_pct_change=_row_diff.count_pct_change, |
| 414 | + decimals=getattr(_row_diff, "decimals", 3), |
| 415 | + processed_sample_data=processed_sample_data, |
| 416 | + ) |
| 417 | + |
| 418 | + s_index, t_index, _ = diff.key_columns |
| 419 | + table_diff_result = TableDiff( |
| 420 | + schema_diff=schema_diff, |
| 421 | + row_diff=row_diff, |
| 422 | + on=[(s.name, t.name) for s, t in zip(s_index, t_index)], |
| 423 | + ) |
| 424 | + return ApiResponseGetTableDiff(data=table_diff_result) |
| 425 | + |
| 426 | + raise NotImplementedError(f"API request not implemented: {request.endpoint}") |
284 | 427 |
|
285 | 428 | def _custom_supported_methods(
|
286 | 429 | self, ls: LanguageServer, params: SupportedMethodsRequest
|
|
0 commit comments