Skip to content

Feat(vscode): Add the Table Diff view in the extension #4917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions sqlmesh/lsp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
CustomMethodRequestBaseClass,
CustomMethodResponseBaseClass,
)
from web.server.models import LineageColumn, Model
from web.server.models import LineageColumn, Model, TableDiff

API_FEATURE = "sqlmesh/api"

Expand All @@ -25,7 +25,7 @@ class ApiRequest(CustomMethodRequestBaseClass):
"""

requestId: str
url: str
endpoint: str
method: t.Optional[str] = "GET"
params: t.Optional[t.Dict[str, t.Any]] = None
body: t.Optional[t.Dict[str, t.Any]] = None
Expand Down Expand Up @@ -74,3 +74,11 @@ class ApiResponseGetColumnLineage(BaseAPIResponse):
"""

data: t.Dict[str, t.Dict[str, LineageColumn]]


class ApiResponseGetTableDiff(BaseAPIResponse):
    """
    Payload returned by the SQLMesh API for the get_table_diff endpoint.
    """

    # The table diff result; the type allows None when there is no diff to report.
    data: t.Optional[TableDiff]
61 changes: 61 additions & 0 deletions sqlmesh/lsp/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,17 @@ class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass):
pass


GET_ENVIRONMENTS_FEATURE = "sqlmesh/get_environments"


class GetEnvironmentsRequest(CustomMethodRequestBaseClass):
    """
    Request for the list of all environments in the current project.

    Declares no fields of its own.
    """


class TestEntry(PydanticModel):
"""
An entry representing a test in the workspace.
Expand Down Expand Up @@ -194,3 +205,53 @@ class RunTestRequest(CustomMethodRequestBaseClass):
class RunTestResponse(CustomMethodResponseBaseClass):
    """
    Outcome of a run-test request.
    """

    # Whether the run succeeded.
    success: bool
    # Optional error text; defaults to None.
    error_message: t.Optional[str] = None


class EnvironmentInfo(PydanticModel):
    """
    Details describing a single environment.
    """

    # Environment name.
    name: str
    # One string identifier per snapshot in the environment.
    snapshots: t.List[str]
    # Environment start, carried as a string.
    start_at: str
    # Identifier of the associated plan, carried as a string.
    plan_id: str


class GetEnvironmentsResponse(CustomMethodResponseBaseClass):
    """
    Response listing all environments in the current project.
    """

    # Environment details keyed by a string (presumably the environment name).
    environments: t.Dict[str, EnvironmentInfo]
    # Names of the pinned environments.
    pinned_environments: t.Set[str]
    # Name of the default target environment.
    default_target_environment: str


GET_MODELS_FEATURE = "sqlmesh/get_models"


class GetModelsRequest(CustomMethodRequestBaseClass):
    """
    Request for the list of models available for table diff.

    Declares no fields of its own.
    """


class ModelInfo(PydanticModel):
    """
    Summary of a model offered for table diff.
    """

    # Model name.
    name: str
    # Model fqn (presumably the fully qualified name).
    fqn: str
    # Optional description; defaults to None.
    description: t.Optional[str] = None


class GetModelsResponse(CustomMethodResponseBaseClass):
    """
    Response listing the models available for table diff.
    """

    # One entry per available model.
    models: t.List[ModelInfo]
149 changes: 146 additions & 3 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
WorkspaceInlayHintRefreshRequest,
)
from pygls.server import LanguageServer
from sqlglot import exp
from sqlmesh._version import __version__
from sqlmesh.core.context import Context
from sqlmesh.utils.date import to_timestamp
from sqlmesh.lsp.api import (
API_FEATURE,
ApiRequest,
ApiResponseGetColumnLineage,
ApiResponseGetLineage,
ApiResponseGetModels,
ApiResponseGetTableDiff,
)

from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS
Expand All @@ -36,6 +39,8 @@
RENDER_MODEL_FEATURE,
SUPPORTED_METHODS_FEATURE,
FORMAT_PROJECT_FEATURE,
GET_ENVIRONMENTS_FEATURE,
GET_MODELS_FEATURE,
AllModelsRequest,
AllModelsResponse,
AllModelsForRenderRequest,
Expand All @@ -57,6 +62,12 @@
RUN_TEST_FEATURE,
RunTestRequest,
RunTestResponse,
GetEnvironmentsRequest,
GetEnvironmentsResponse,
EnvironmentInfo,
GetModelsRequest,
GetModelsResponse,
ModelInfo,
)
from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic
from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position
Expand All @@ -74,9 +85,12 @@
from sqlmesh.utils.pydantic import PydanticModel
from web.server.api.endpoints.lineage import column_lineage, model_lineage
from web.server.api.endpoints.models import get_models
from web.server.api.endpoints.table_diff import _process_sample_data
from typing import Union
from dataclasses import dataclass, field

from web.server.models import RowDiff, SchemaDiff, TableDiff


class InitializationOptions(PydanticModel):
"""Initialization options for the SQLMesh Language Server, that
Expand Down Expand Up @@ -154,6 +168,8 @@ def __init__(
LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests,
LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests,
RUN_TEST_FEATURE: self._run_test,
GET_ENVIRONMENTS_FEATURE: self._custom_get_environments,
GET_MODELS_FEATURE: self._custom_get_models,
}

# Register LSP features (e.g., formatting, hover, etc.)
Expand Down Expand Up @@ -246,13 +262,71 @@ def _custom_format_project(
ls.log_trace(f"Error formatting project: {e}")
return FormatProjectResponse()

def _custom_get_environments(
    self, ls: LanguageServer, params: GetEnvironmentsRequest
) -> GetEnvironmentsResponse:
    """Collect every environment known to the project's state reader.

    Args:
        ls: Language server, used here only to trace-log failures.
        params: The incoming request; it is not read.

    Returns:
        A response mapping environment names to their details, along with the
        configured pinned environments and default target environment. If
        anything raises, the error text is returned in ``response_error`` with
        empty values for the remaining fields.
    """
    try:
        project = self._context_get_or_load().context

        # Keyed by name, so a repeated name keeps the last environment seen.
        environments = {
            env.name: EnvironmentInfo(
                name=env.name,
                snapshots=[s.fingerprint.to_identifier() for s in env.snapshots],
                start_at=str(to_timestamp(env.start_at)),
                # A falsy plan_id is sent as an empty string.
                plan_id=env.plan_id or "",
            )
            for env in project.state_reader.get_environments()
        }

        config = project.config
        return GetEnvironmentsResponse(
            environments=environments,
            pinned_environments=config.pinned_environments,
            default_target_environment=config.default_target_environment,
        )
    except Exception as e:
        # Best effort: report the failure in the response instead of raising.
        ls.log_trace(f"Error getting environments: {e}")
        return GetEnvironmentsResponse(
            response_error=str(e),
            environments={},
            pinned_environments=set(),
            default_target_environment="",
        )

def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> GetModelsResponse:
    """List the models that can be offered for table diff.

    Args:
        ls: Language server, used here only to trace-log failures.
        params: The incoming request; it is not read.

    Returns:
        A response with one entry per model that has a file path. If anything
        raises, the error text is returned in ``response_error`` with an empty
        model list.
    """
    try:
        lsp_context = self._context_get_or_load()
        models = []
        for model in lsp_context.context.models.values():
            # Only models backed by a file path are offered.
            if model._path is None:
                continue
            models.append(
                ModelInfo(
                    name=model.name,
                    fqn=model.fqn,
                    description=model.description,
                )
            )
        return GetModelsResponse(models=models)
    except Exception as e:
        # Best effort: report the failure in the response instead of raising.
        ls.log_trace(f"Error getting table diff models: {e}")
        return GetModelsResponse(
            response_error=str(e),
            models=[],
        )

def _custom_api(
self, ls: LanguageServer, request: ApiRequest
) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]:
) -> t.Union[
ApiResponseGetModels,
ApiResponseGetColumnLineage,
ApiResponseGetLineage,
ApiResponseGetTableDiff,
]:
ls.log_trace(f"API request: {request}")
context = self._context_get_or_load()

parsed_url = urllib.parse.urlparse(request.url)
parsed_url = urllib.parse.urlparse(request.endpoint)
path_parts = parsed_url.path.strip("/").split("/")

if request.method == "GET":
Expand Down Expand Up @@ -280,7 +354,76 @@ def _custom_api(
)
return ApiResponseGetColumnLineage(data=column_lineage_response)

raise NotImplementedError(f"API request not implemented: {request.url}")
if path_parts[:2] == ["api", "table_diff"]:
import numpy as np

# /api/table_diff
params = request.params
table_diff_result: t.Optional[TableDiff] = None
if params := request.params:
source = getattr(params, "source", "") if params else ""
target = getattr(params, "target", "") if params else ""
on = getattr(params, "on", None) if params else None
model_or_snapshot = (
getattr(params, "model_or_snapshot", None) if params else None
)
where = getattr(params, "where", None) if params else None
temp_schema = getattr(params, "temp_schema", None) if params else None
limit = getattr(params, "limit", 20) if params else 20

table_diffs = context.context.table_diff(
source=source,
target=target,
on=exp.condition(on) if on else None,
select_models={model_or_snapshot} if model_or_snapshot else None,
where=where,
limit=limit,
show=False,
)

if table_diffs:
diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs

_schema_diff = diff.schema_diff()
_row_diff = diff.row_diff(temp_schema=temp_schema)
schema_diff = SchemaDiff(
source=_schema_diff.source,
target=_schema_diff.target,
source_schema=_schema_diff.source_schema,
target_schema=_schema_diff.target_schema,
added=_schema_diff.added,
removed=_schema_diff.removed,
modified=_schema_diff.modified,
)

# create a readable column-centric sample data structure
processed_sample_data = _process_sample_data(_row_diff, source, target)

row_diff = RowDiff(
source=_row_diff.source,
target=_row_diff.target,
stats=_row_diff.stats,
sample=_row_diff.sample.replace({np.nan: None}).to_dict(),
joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(),
s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(),
t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(),
column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(),
source_count=_row_diff.source_count,
target_count=_row_diff.target_count,
count_pct_change=_row_diff.count_pct_change,
decimals=getattr(_row_diff, "decimals", 3),
processed_sample_data=processed_sample_data,
)

s_index, t_index, _ = diff.key_columns
table_diff_result = TableDiff(
schema_diff=schema_diff,
row_diff=row_diff,
on=[(s.name, t.name) for s, t in zip(s_index, t_index)],
)
return ApiResponseGetTableDiff(data=table_diff_result)

raise NotImplementedError(f"API request not implemented: {request.endpoint}")

def _custom_supported_methods(
self, ls: LanguageServer, params: SupportedMethodsRequest
Expand Down
45 changes: 44 additions & 1 deletion vscode/bus/src/callbacks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,56 @@ export type RPCMethods = {
}
api_query: {
params: {
url: string
endpoint: string
method: string
params: any
body: any
}
result: any
}
get_selected_model: {
params: {}
result: {
selectedModel?: any
}
}
get_all_models: {
params: {}
result: {
ok: boolean
models?: any[]
error?: string
}
}
set_selected_model: {
params: {
model: any
}
result: {
ok: boolean
selectedModel?: any
}
}
get_environments: {
params: {}
result: {
ok: boolean
environments?: Record<string, any>
error?: string
}
}
run_table_diff: {
params: {
sourceModel: string
sourceEnvironment: string
targetEnvironment: string
}
result: {
ok: boolean
data?: any
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't really have `any` here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, revised all of these `any` types.

error?: string
}
}
} & RPCMethodsShape

export type RPCRequest = {
Expand Down
3 changes: 3 additions & 0 deletions vscode/extension/assets/images/diff.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading