|
14 | 14 | WorkspaceInlayHintRefreshRequest,
|
15 | 15 | )
|
16 | 16 | from pygls.server import LanguageServer
|
| 17 | +from sqlglot import exp |
17 | 18 | from sqlmesh._version import __version__
|
18 | 19 | from sqlmesh.core.context import Context
|
| 20 | +from sqlmesh.utils.date import to_timestamp |
19 | 21 | from sqlmesh.lsp.api import (
|
20 | 22 | API_FEATURE,
|
21 | 23 | ApiRequest,
|
22 | 24 | ApiResponseGetColumnLineage,
|
23 | 25 | ApiResponseGetLineage,
|
24 | 26 | ApiResponseGetModels,
|
| 27 | + ApiResponseGetTableDiff, |
25 | 28 | )
|
26 | 29 |
|
27 | 30 | from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
|
|
36 | 39 | RENDER_MODEL_FEATURE,
|
37 | 40 | SUPPORTED_METHODS_FEATURE,
|
38 | 41 | FORMAT_PROJECT_FEATURE,
|
| 42 | + GET_ENVIRONMENTS_FEATURE, |
| 43 | + GET_MODELS_FEATURE, |
39 | 44 | AllModelsRequest,
|
40 | 45 | AllModelsResponse,
|
41 | 46 | AllModelsForRenderRequest,
|
|
57 | 62 | RUN_TEST_FEATURE,
|
58 | 63 | RunTestRequest,
|
59 | 64 | RunTestResponse,
|
| 65 | + GetEnvironmentsRequest, |
| 66 | + GetEnvironmentsResponse, |
| 67 | + EnvironmentInfo, |
| 68 | + GetModelsRequest, |
| 69 | + GetModelsResponse, |
| 70 | + ModelInfo, |
60 | 71 | )
|
61 | 72 | from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
|
62 | 73 | from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
|
|
73 | 84 | from sqlmesh.utils.lineage import ExternalModelReference
|
74 | 85 | from web.server.api.endpoints.lineage import column_lineage, model_lineage
|
75 | 86 | from web.server.api.endpoints.models import get_models
|
| 87 | +from web.server.api.endpoints.table_diff import _process_sample_data |
76 | 88 | from typing import Union
|
77 | 89 | from dataclasses import dataclass, field
|
78 | 90 |
|
| 91 | +from web.server.models import RowDiff, SchemaDiff, TableDiff |
| 92 | + |
79 | 93 |
|
80 | 94 | @dataclass
|
81 | 95 | class NoContext:
|
@@ -141,6 +155,8 @@ def __init__(
|
141 | 155 | LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
|
142 | 156 | LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
|
143 | 157 | RUN_TEST_FEATURE: self._run_test,
|
| 158 | + GET_ENVIRONMENTS_FEATURE: self._custom_get_environments, |
| 159 | + GET_MODELS_FEATURE: self._custom_get_models, |
144 | 160 | }
|
145 | 161 |
|
146 | 162 | # Register LSP features (e.g., formatting, hover, etc.)
|
@@ -233,13 +249,71 @@ def _custom_format_project(
|
233 | 249 | ls.log_trace(f"Error formatting project: {e}")
|
234 | 250 | return FormatProjectResponse()
|
235 | 251 |
|
| 252 | + def _custom_get_environments( |
| 253 | + self, ls: LanguageServer, params: GetEnvironmentsRequest |
| 254 | + ) -> GetEnvironmentsResponse: |
| 255 | + """Get all environments in the current project.""" |
| 256 | + try: |
| 257 | + context = self._context_get_or_load() |
| 258 | + environments = {} |
| 259 | + |
| 260 | + # Get environments from state |
| 261 | + for env in context.context.state_reader.get_environments(): |
| 262 | + environments[env.name] = EnvironmentInfo( |
| 263 | + name=env.name, |
| 264 | + snapshots=[s.fingerprint.to_identifier() for s in env.snapshots], |
| 265 | + start_at=str(to_timestamp(env.start_at)), |
| 266 | + plan_id=env.plan_id or "", |
| 267 | + ) |
| 268 | + |
| 269 | + return GetEnvironmentsResponse( |
| 270 | + environments=environments, |
| 271 | + pinned_environments=context.context.config.pinned_environments, |
| 272 | + default_target_environment=context.context.config.default_target_environment, |
| 273 | + ) |
| 274 | + except Exception as e: |
| 275 | + ls.log_trace(f"Error getting environments: {e}") |
| 276 | + return GetEnvironmentsResponse( |
| 277 | + response_error=str(e), |
| 278 | + environments={}, |
| 279 | + pinned_environments=set(), |
| 280 | + default_target_environment="", |
| 281 | + ) |
| 282 | + |
| 283 | + def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> GetModelsResponse: |
| 284 | + """Get all models available for table diff.""" |
| 285 | + try: |
| 286 | + context = self._context_get_or_load() |
| 287 | + models = [ |
| 288 | + ModelInfo( |
| 289 | + name=model.name, |
| 290 | + fqn=model.fqn, |
| 291 | + description=model.description, |
| 292 | + ) |
| 293 | + for model in context.context.models.values() |
| 294 | + # Filter for models that are suitable for table diff |
| 295 | + if model._path is not None # Has a file path |
| 296 | + ] |
| 297 | + return GetModelsResponse(models=models) |
| 298 | + except Exception as e: |
| 299 | + ls.log_trace(f"Error getting table diff models: {e}") |
| 300 | + return GetModelsResponse( |
| 301 | + response_error=str(e), |
| 302 | + models=[], |
| 303 | + ) |
| 304 | + |
236 | 305 | def _custom_api(
|
237 | 306 | self, ls: LanguageServer, request: ApiRequest
|
238 |
| - ) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]: |
| 307 | + ) -> t.Union[ |
| 308 | + ApiResponseGetModels, |
| 309 | + ApiResponseGetColumnLineage, |
| 310 | + ApiResponseGetLineage, |
| 311 | + ApiResponseGetTableDiff, |
| 312 | + ]: |
239 | 313 | ls.log_trace(f"API request: {request}")
|
240 | 314 | context = self._context_get_or_load()
|
241 | 315 |
|
242 |
| - parsed_url = urllib.parse.urlparse(request.url) |
| 316 | + parsed_url = urllib.parse.urlparse(request.endpoint) |
243 | 317 | path_parts = parsed_url.path.strip("/").split("/")
|
244 | 318 |
|
245 | 319 | if request.method == "GET":
|
@@ -267,7 +341,76 @@ def _custom_api(
|
267 | 341 | )
|
268 | 342 | return ApiResponseGetColumnLineage(data=column_lineage_response)
|
269 | 343 |
|
270 |
| - raise NotImplementedError(f"API request not implemented: {request.url}") |
| 344 | + if path_parts[:2] == ["api", "table_diff"]: |
| 345 | + import numpy as np |
| 346 | + |
| 347 | + # /api/table_diff |
| 348 | + params = request.params |
| 349 | + table_diff_result: t.Optional[TableDiff] = None |
| 350 | + if params := request.params: |
| 351 | + source = getattr(params, "source", "") if params else "" |
| 352 | + target = getattr(params, "target", "") if params else "" |
| 353 | + on = getattr(params, "on", None) if params else None |
| 354 | + model_or_snapshot = ( |
| 355 | + getattr(params, "model_or_snapshot", None) if params else None |
| 356 | + ) |
| 357 | + where = getattr(params, "where", None) if params else None |
| 358 | + temp_schema = getattr(params, "temp_schema", None) if params else None |
| 359 | + limit = getattr(params, "limit", 20) if params else 20 |
| 360 | + |
| 361 | + table_diffs = context.context.table_diff( |
| 362 | + source=source, |
| 363 | + target=target, |
| 364 | + on=exp.condition(on) if on else None, |
| 365 | + select_models={model_or_snapshot} if model_or_snapshot else None, |
| 366 | + where=where, |
| 367 | + limit=limit, |
| 368 | + show=False, |
| 369 | + ) |
| 370 | + |
| 371 | + if table_diffs: |
| 372 | + diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs |
| 373 | + |
| 374 | + _schema_diff = diff.schema_diff() |
| 375 | + _row_diff = diff.row_diff(temp_schema=temp_schema) |
| 376 | + schema_diff = SchemaDiff( |
| 377 | + source=_schema_diff.source, |
| 378 | + target=_schema_diff.target, |
| 379 | + source_schema=_schema_diff.source_schema, |
| 380 | + target_schema=_schema_diff.target_schema, |
| 381 | + added=_schema_diff.added, |
| 382 | + removed=_schema_diff.removed, |
| 383 | + modified=_schema_diff.modified, |
| 384 | + ) |
| 385 | + |
| 386 | + # create a readable column-centric sample data structure |
| 387 | + processed_sample_data = _process_sample_data(_row_diff, source, target) |
| 388 | + |
| 389 | + row_diff = RowDiff( |
| 390 | + source=_row_diff.source, |
| 391 | + target=_row_diff.target, |
| 392 | + stats=_row_diff.stats, |
| 393 | + sample=_row_diff.sample.replace({np.nan: None}).to_dict(), |
| 394 | + joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(), |
| 395 | + s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(), |
| 396 | + t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(), |
| 397 | + column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(), |
| 398 | + source_count=_row_diff.source_count, |
| 399 | + target_count=_row_diff.target_count, |
| 400 | + count_pct_change=_row_diff.count_pct_change, |
| 401 | + decimals=getattr(_row_diff, "decimals", 3), |
| 402 | + processed_sample_data=processed_sample_data, |
| 403 | + ) |
| 404 | + |
| 405 | + s_index, t_index, _ = diff.key_columns |
| 406 | + table_diff_result = TableDiff( |
| 407 | + schema_diff=schema_diff, |
| 408 | + row_diff=row_diff, |
| 409 | + on=[(s.name, t.name) for s, t in zip(s_index, t_index)], |
| 410 | + ) |
| 411 | + return ApiResponseGetTableDiff(data=table_diff_result) |
| 412 | + |
| 413 | + raise NotImplementedError(f"API request not implemented: {request.endpoint}") |
271 | 414 |
|
272 | 415 | def _custom_supported_methods(
|
273 | 416 | self, ls: LanguageServer, params: SupportedMethodsRequest
|
|
0 commit comments